My Python code uses the slim library to train a classification model in TensorFlow:
<pre lang='python'>
with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(weight_decay=0.001)):
    logits, _ = mobilenet_v2.mobilenet(images, NUM_CLASSES)

cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
cross_entropy = tf.reduce_mean(cross_entropy)
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = tf.contrib.slim.learning.create_train_op(cross_entropy, opt, global_step=global_step)
...
sess.run(train_op)
</pre>
It works fine. However, no matter what value ‘weight_decay’ takes, the training accuracy easily climbs above 90%. It seems ‘weight_decay’ just doesn’t work.
To find out why, I reviewed the TensorFlow source of ‘tf.losses.sparse_softmax_cross_entropy()’:
<pre lang='python'>
# tensorflow/python/ops/losses/losses_impl.py
@tf_export("losses.sparse_softmax_cross_entropy")
def sparse_softmax_cross_entropy(
    labels, logits, weights=1.0, scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  ...
  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                      (logits, labels, weights)) as scope:
    # As documented above in Args, labels contain class IDs and logits contains
    # 1 probability per class ID, so we expect rank(logits) - rank(labels) == 1;
    # therefore, expected_rank_diff=1.
    labels, logits, weights = _remove_squeezable_dimensions(
        labels, logits, weights, expected_rank_diff=1)
    losses = nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name="xentropy")
    return compute_weighted_loss(
        losses, weights, scope, loss_collection, reduction=reduction)
</pre>
‘losses.sparse_softmax_cross_entropy()’ simply calls ‘tf.nn.sparse_softmax_cross_entropy_with_logits()’ and hands the result to ‘compute_weighted_loss()’. Let’s look into the implementation of ‘compute_weighted_loss()’:
<pre lang='python'>
# tensorflow/python/ops/losses/losses_impl.py
@tf_export("losses.compute_weighted_loss")
def compute_weighted_loss(
    losses, weights=1.0, scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  ...
  loss = math_ops.cast(loss, input_dtype)
  util.add_loss(loss, loss_collection)
  return loss
</pre>

What is the secret in ‘util.add_loss()’?

<pre lang='python'>
# tensorflow/python/ops/losses/util.py
@tf_export("losses.add_loss")
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  ...
  if loss_collection:
    ops.add_to_collection(loss_collection, loss)
</pre>
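So every loss built through ‘tf.losses’ is quietly registered in a graph collection. As a quick sanity check (a minimal sketch, assuming the ‘labels’ and ‘logits’ tensors from the snippet above), the cross-entropy tensor can be fetched back from that collection:

<pre lang='python'>
# The value returned by tf.losses.sparse_softmax_cross_entropy(...) is also
# registered in the GraphKeys.LOSSES collection of the default graph:
xent = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
print(tf.losses.get_losses())                  # a list containing the cross-entropy tensor
print(tf.get_collection(tf.GraphKeys.LOSSES))  # the same list, read directly from the collection
</pre>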
So the loss computed by ‘losses.sparse_softmax_cross_entropy()’ ends up in the ‘GraphKeys.LOSSES’ collection. Then where do the weight-decay penalties on the parameters go? Are they added to the same collection? Let’s check. All layers written with ‘tf.layers’ or ‘tf.contrib.slim’ inherit from ‘class Layer’, and a layer calls ‘add_loss()’ with the regularization term when it calls ‘add_variable()’. Let’s check ‘add_loss()’ of the base class ‘Layer’:
<pre lang='python'>
@tf_export('layers.Layer')
class Layer(checkpointable.CheckpointableBase):
  ...
  def add_loss(self, losses, inputs=None):
    ...
    _add_elements_to_collection(losses, ops.GraphKeys.REGULARIZATION_LOSSES)
</pre>
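To confirm this on the model above, one can inspect the collection right after the graph is built (a minimal sketch; the variable name and the tensor names in the comments are only illustrative):

<pre lang='python'>
# Run right after building mobilenet_v2 under training_scope(weight_decay=0.001):
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
print(len(reg_losses))   # one L2 term per regularized weight tensor
print(reg_losses[0])     # e.g. a '.../Regularizer/l2_regularizer'-style tensor
</pre>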
Here is the catch: the regularization loss from the weights is not added to ‘GraphKeys.LOSSES’ but to ‘GraphKeys.REGULARIZATION_LOSSES’. Then how can we get all the losses at training time? After grepping for ‘REGULARIZATION_LOSSES’ in the whole TensorFlow code base, ‘get_total_loss()’ turns up:
<pre lang='python'>
# tensorflow/python/ops/losses/util.py
@tf_export("losses.get_total_loss")
def get_total_loss(add_regularization_losses=True, name="total_loss"):
  ...
  losses = get_losses()
  if add_regularization_losses:
    losses += get_regularization_losses()
  return math_ops.add_n(losses, name=name)
</pre>
That is the secret of losses in ‘tf.layers’ and ‘tf.contrib.slim’: we should use ‘get_total_loss()’ to fetch the model loss and the regularization loss together!
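In other words, ‘get_total_loss()’ is just the sum of the two collections; the following hand-written version (a sketch, the variable names are mine) computes the same thing:

<pre lang='python'>
# Equivalent manual computation of the total training loss:
model_losses = tf.get_collection(tf.GraphKeys.LOSSES)                 # e.g. the cross-entropy
reg_losses   = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)  # the weight-decay terms
total_loss   = tf.add_n(model_losses + reg_losses, name='total_loss')
</pre>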
After changing my code:
<pre lang='python'>
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
cross_entropy = tf.reduce_mean(cross_entropy)
global_step = tf.contrib.framework.get_or_create_global_step()
loss = tf.contrib.slim.losses.get_total_loss()
train_op = tf.contrib.slim.learning.create_train_op(loss, opt, global_step=global_step)
...
sess.run(train_op)
</pre>
‘weight_decay’ works as expected now: the training accuracy no longer climbs to a high value so easily.
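To see the effect directly, it also helps to log the two components separately during training; a small sketch using TensorBoard summaries (the summary names are illustrative):

<pre lang='python'>
# Track cross-entropy and regularization separately, so the contribution
# of weight_decay is visible in TensorBoard:
reg_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), name='reg_loss')
tf.summary.scalar('loss/cross_entropy', cross_entropy)
tf.summary.scalar('loss/regularization', reg_loss)
tf.summary.scalar('loss/total', loss)
</pre>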