API - Cost

To keep TensorLayer simple, we minimize the number of cost functions as much as we can, and we encourage you to use TensorFlow's own functions. For example, you can implement L1, L2 and sum regularization with tf.nn.l2_loss, tf.contrib.layers.l1_regularizer, tf.contrib.layers.l2_regularizer and tf.contrib.layers.sum_regularizer; see the TensorFlow API.
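For instance, here is a minimal sketch of combining these regularizers (this assumes TensorFlow 1.x with tf.contrib available; base_cost and w are placeholders for your own loss and weight tensor):

# combine an L1 and an L2 penalty into a single regularizer function
reg_fn = tf.contrib.layers.sum_regularizer(
    [tf.contrib.layers.l1_regularizer(1e-5),
     tf.contrib.layers.l2_regularizer(1e-4)])
cost = base_cost + reg_fn(w)  # w: a weight tensor of your network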

Your cost function

TensorLayer provides a simple way to create your own cost function. Take the MLP below as an example.

x = tf.placeholder(tf.float32, shape=[None, 784], name='x')  # input images, matching param 0 below
network = tl.layers.InputLayer(x, name='input')
network = tl.layers.DropoutLayer(network, keep=0.8, name='drop1')
network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu1')
network = tl.layers.DropoutLayer(network, keep=0.5, name='drop2')
network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu2')
network = tl.layers.DropoutLayer(network, keep=0.5, name='drop3')
network = tl.layers.DenseLayer(network, n_units=10, act=tf.identity, name='output')

The network parameters will be [W1, b1, W2, b2, W_out, b_out], so you can apply L2 regularization to the weight matrices of the first two layers as follows.

cost = tl.cost.cross_entropy(y, y_)
cost = cost + tf.contrib.layers.l2_regularizer(0.001)(network.all_params[0]) \
            + tf.contrib.layers.l2_regularizer(0.001)(network.all_params[2])

Besides, TensorLayer provides an easy way to get all variables by a given name, so you can also apply L2 regularization to selected weights as follows.

l2 = 0
for w in tl.layers.get_variables_with_name('W_conv2d', train_only=True, printable=False):
    l2 += tf.contrib.layers.l2_regularizer(1e-4)(w)
cost = tl.cost.cross_entropy(y, y_) + l2

Regularization of Weights

After initializing the variables, the information about the network parameters can be displayed by using network.print_params().

tl.layers.initialize_global_variables(sess)
network.print_params()
param 0: (784, 800) (mean: -0.000000, median: 0.000004 std: 0.035524)
param 1: (800,) (mean: 0.000000, median: 0.000000 std: 0.000000)
param 2: (800, 800) (mean: 0.000029, median: 0.000031 std: 0.035378)
param 3: (800,) (mean: 0.000000, median: 0.000000 std: 0.000000)
param 4: (800, 10) (mean: 0.000673, median: 0.000763 std: 0.049373)
param 5: (10,) (mean: 0.000000, median: 0.000000 std: 0.000000)
num of params: 1276810

The output of the network is network.outputs, from which the cross-entropy can be defined as shown below. To regularize the weights, note that network.all_params contains all parameters of the network; in this case network.all_params = [W1, b1, W2, b2, W_out, b_out], corresponding to param 0 to param 5 shown by network.print_params(). Max-norm regularization on W1 and W2 can then be performed as follows.

max_norm = 0
for w in tl.layers.get_variables_with_name('W', train_only=True, printable=False):
    max_norm += tl.cost.maxnorm_regularizer(1)(w)
cost = tl.cost.cross_entropy(y, y_) + max_norm

In addition, all of TensorFlow's regularizers, such as tf.contrib.layers.l2_regularizer, can be used with TensorLayer.

Regularization of Activation outputs

The instance method network.print_layers() prints the outputs of all layers in order. To regularize the activation outputs, you can use network.all_layers, which contains the outputs of every layer. For example, if you want to apply an L1 penalty to the activations of the first hidden layer, simply add tf.contrib.layers.l1_regularizer(lambda_l1)(network.all_layers[1]) to the cost function, as in the sketch after the layer list below.

network.print_layers()
layer 0: Tensor("dropout/mul_1:0", shape=(?, 784), dtype=float32)
layer 1: Tensor("Relu:0", shape=(?, 800), dtype=float32)
layer 2: Tensor("dropout_1/mul_1:0", shape=(?, 800), dtype=float32)
layer 3: Tensor("Relu_1:0", shape=(?, 800), dtype=float32)
layer 4: Tensor("dropout_2/mul_1:0", shape=(?, 800), dtype=float32)
layer 5: Tensor("add_2:0", shape=(?, 10), dtype=float32)
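For example, an L1 penalty on the activations of the first hidden layer (layer 1 above) could be added as follows; this is a sketch and lambda_l1 is a hyper-parameter you choose:

lambda_l1 = 0.001  # strength of the activation penalty (your choice)
cost = tl.cost.cross_entropy(y, y_)
cost = cost + tf.contrib.layers.l1_regularizer(lambda_l1)(network.all_layers[1])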

The tensorlayer.cost module provides the following cost and regularization functions:

  • cross_entropy(output, target[, name]) – Softmax cross-entropy operation; returns the TensorFlow expression of cross-entropy for two distributions, applying softmax internally.
  • sigmoid_cross_entropy(output, target[, name]) – Sigmoid cross-entropy operation, see tf.nn.sigmoid_cross_entropy_with_logits.
  • binary_cross_entropy(output, target[, …]) – Binary cross-entropy operation.
  • mean_squared_error(output, target[, …]) – Return the TensorFlow expression of the mean-squared error (L2) of two batches of data.
  • normalized_mean_square_error(output, target) – Return the TensorFlow expression of the normalized mean-squared error of two distributions.
  • absolute_difference_error(output, target[, …]) – Return the TensorFlow expression of the absolute difference error (L1) of two batches of data.
  • dice_coe(output, target[, loss_type, axis, …]) – Soft dice (Sørensen or Jaccard) coefficient for comparing the similarity of two batches of data, usually used for binary image segmentation.
  • dice_hard_coe(output, target[, threshold, …]) – Non-differentiable Sørensen–Dice coefficient for comparing the similarity of two batches of data, usually used for binary image segmentation.
  • iou_coe(output, target[, threshold, axis, …]) – Non-differentiable Intersection over Union (IoU) for comparing the similarity of two batches of data, usually used for evaluating binary image segmentation.
  • cross_entropy_seq(logits, target_seqs[, …]) – Returns the expression of cross-entropy of two sequences, applying softmax internally.
  • cross_entropy_seq_with_mask(logits, …[, …]) – Returns the expression of masked cross-entropy of two sequences, applying softmax internally.
  • cosine_similarity(v1, v2) – Cosine similarity in [-1, 1].
  • li_regularizer(scale[, scope]) – Li regularization removes the neurons of the previous layer.
  • lo_regularizer(scale) – Lo regularization removes the neurons of the current layer.
  • maxnorm_regularizer([scale]) – Max-norm regularization; returns a function that can be used to apply max-norm regularization to weights.
  • maxnorm_o_regularizer(scale) – Max-norm output regularization removes the neurons of the current layer.
  • maxnorm_i_regularizer(scale) – Max-norm input regularization removes the neurons of the previous layer.

Softmax cross entropy

tensorlayer.cost.cross_entropy(output, target, name=None)[source]

Softmax cross-entropy operation. Returns the TensorFlow expression of cross-entropy for two distributions, applying softmax internally. See tf.nn.sparse_softmax_cross_entropy_with_logits.

Parameters:
  • output (Tensor) – A batch of logits with shape [batch_size, num_of_classes].
  • target (Tensor) – A batch of class indices with shape [batch_size, ].
  • name (string) – Name of this loss.

Examples

>>> ce = tl.cost.cross_entropy(y_logits, y_target_ids, 'my_loss')


Sigmoid cross entropy

tensorlayer.cost.sigmoid_cross_entropy(output, target, name=None)[source]

Sigmoid cross-entropy operation, see tf.nn.sigmoid_cross_entropy_with_logits.

Parameters:
  • output (Tensor) – A batch of logits with shape [batch_size, num_of_classes].
  • target (Tensor) – The target values, with the same shape as output.
  • name (string) – Name of this loss.
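A brief usage sketch (the placeholder names and shapes below are illustrative assumptions, not from the original docs):

>>> logits = tf.placeholder(tf.float32, shape=[None, 10])  # raw network outputs
>>> labels = tf.placeholder(tf.float32, shape=[None, 10])  # binary targets, same shape as logits
>>> loss = tl.cost.sigmoid_cross_entropy(logits, labels, name='sigmoid_loss')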

Binary cross entropy

tensorlayer.cost.binary_cross_entropy(output, target, epsilon=1e-08, name='bce_loss')[source]

Binary cross entropy operation.

Parameters:
  • output (Tensor) – Tensor with type of float32 or float64.
  • target (Tensor) – The target distribution, format the same with output.
  • epsilon (float) – A small value added to the output to avoid computing log(0).
  • name (str) – An optional name to attach to this function.
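A brief usage sketch (assumes output contains probabilities in (0, 1), e.g. produced by a sigmoid; the names net and y_target are illustrative):

>>> out = tf.nn.sigmoid(net.outputs)  # probabilities in (0, 1)
>>> loss = tl.cost.binary_cross_entropy(out, y_target)  # y_target: same shape as out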


Mean squared error (L2)

tensorlayer.cost.mean_squared_error(output, target, is_mean=False, name='mean_squared_error')[source]

Return the TensorFlow expression of the mean-squared error (L2) of two batches of data.

Parameters:
  • output (Tensor) – 2D, 3D or 4D tensor i.e. [batch_size, n_feature], [batch_size, height, width] or [batch_size, height, width, channel].
  • target (Tensor) – The target distribution, format the same with output.
  • is_mean (boolean) –
    Whether to compute the mean or the sum for each example.
    • If True, use tf.reduce_mean to compute the loss for each example.
    • If False (default), use tf.reduce_sum.
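A brief usage sketch for a regression network (the names net and y_ are illustrative):

>>> y_ = tf.placeholder(tf.float32, shape=[None, 1], name='target')
>>> mse = tl.cost.mean_squared_error(net.outputs, y_, is_mean=True)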


Normalized mean square error

tensorlayer.cost.normalized_mean_square_error(output, target)[source]

Return the TensorFlow expression of normalized mean-square-error of two distributions.

Parameters:
  • output (Tensor) – 2D, 3D or 4D tensor i.e. [batch_size, n_feature], [batch_size, height, width] or [batch_size, height, width, channel].
  • target (Tensor) – The target distribution, format the same with output.
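This is often used as an evaluation metric rather than a training loss; a brief sketch (the names net and y_ are illustrative):

>>> nmse = tl.cost.normalized_mean_square_error(net.outputs, y_)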

Absolute difference error (L1)

tensorlayer.cost.absolute_difference_error(output, target, is_mean=False)[source]

Return the TensorFlow expression of the absolute difference error (L1) of two batches of data.

Parameters:
  • output (Tensor) – 2D, 3D or 4D tensor i.e. [batch_size, n_feature], [batch_size, height, width] or [batch_size, height, width, channel].
  • target (Tensor) – The target distribution, format the same with output.
  • is_mean (boolean) –
    Whether to compute the mean or the sum for each example.
    • If True, use tf.reduce_mean to compute the loss for each example.
    • If False (default), use tf.reduce_sum.
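A brief sketch using it as an L1 reconstruction loss (the names reconstruction and x_target are illustrative):

>>> l1_loss = tl.cost.absolute_difference_error(reconstruction, x_target, is_mean=True)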

Dice coefficient

tensorlayer.cost.dice_coe(output, target, loss_type='jaccard', axis=(1, 2, 3), smooth=1e-05)[source]

Soft dice (Sørensen or Jaccard) coefficient for comparing the similarity of two batches of data, usually used for binary image segmentation, i.e. the labels are binary. The coefficient is between 0 and 1, where 1 means a perfect match.

Parameters:
  • output (Tensor) – A distribution with shape: [batch_size, ….], (any dimensions).
  • target (Tensor) – The target distribution, format the same with output.
  • loss_type (str) – jaccard or sorensen, default is jaccard.
  • axis (tuple of int) – All dimensions are reduced, default [1,2,3].
  • smooth (float) –
    This small value will be added to the numerator and the denominator.
    • If both output and target are empty, it makes sure the dice is 1.
    • If either output or target is empty (all pixels are background), dice = smooth / (small_value + smooth); if smooth is very small, the dice is close to 0 (even if the predicted values are all below the threshold), so a higher smooth gives a higher dice in this case.

Examples

>>> outputs = tl.act.pixel_wise_softmax(network.outputs)
>>> dice_loss = 1 - tl.cost.dice_coe(outputs, y_)


Hard Dice coefficient

tensorlayer.cost.dice_hard_coe(output, target, threshold=0.5, axis=(1, 2, 3), smooth=1e-05)[source]

Non-differentiable Sørensen–Dice coefficient for comparing the similarity of two batches of data, usually used for binary image segmentation, i.e. the labels are binary. The coefficient is between 0 and 1, where 1 means a perfect match.

Parameters:
  • output (tensor) – A distribution with shape: [batch_size, ….], (any dimensions).
  • target (tensor) – The target distribution, format the same with output.
  • threshold (float) – The threshold above which a value is counted as true (1).
  • axis (tuple of integer) – All dimensions are reduced, default (1,2,3).
  • smooth (float) – This small value will be added to the numerator and denominator, see dice_coe.


IOU coefficient

tensorlayer.cost.iou_coe(output, target, threshold=0.5, axis=(1, 2, 3), smooth=1e-05)[source]

Non-differentiable Intersection over Union (IoU) for comparing the similarity of two batches of data, usually used for evaluating binary image segmentation. The coefficient is between 0 and 1, where 1 means a perfect match.

Parameters:
  • output (tensor) – A batch of distribution with shape: [batch_size, ….], (any dimensions).
  • target (tensor) – The target distribution, format the same with output.
  • threshold (float) – The threshold above which a value is counted as true (1).
  • axis (tuple of integer) – All dimensions are reduced, default (1,2,3).
  • smooth (float) – This small value will be added to the numerator and denominator, see dice_coe.

Notes

  • IoU cannot be used as a training loss; people usually use the dice coefficient for training, and IoU and hard dice for evaluation, as sketched below.
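Following that note, a common pattern is to train with the soft dice loss and monitor IoU and hard dice during evaluation; a sketch, reusing outputs and y_ from the dice_coe example above:

>>> outputs = tl.act.pixel_wise_softmax(network.outputs)
>>> dice_loss = 1 - tl.cost.dice_coe(outputs, y_)                  # training loss
>>> iou = tl.cost.iou_coe(outputs, y_, threshold=0.5)              # evaluation only
>>> hard_dice = tl.cost.dice_hard_coe(outputs, y_, threshold=0.5)  # evaluation only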

Cross entropy for sequence

tensorlayer.cost.cross_entropy_seq(logits, target_seqs, batch_size=None)[source]

Returns the expression of cross-entropy of two sequences, applying softmax internally. Normally used for fixed-length RNN outputs; see the PTB example.

Parameters:
  • logits (Tensor) – 2D tensor with shape of [batch_size * n_steps, n_classes].
  • target_seqs (Tensor) – The target sequence, a 2D tensor of shape [batch_size, n_steps]. If the number of steps is dynamic, please use tl.cost.cross_entropy_seq_with_mask instead.
  • batch_size (None or int.) –
    Whether to divide the cost by batch size.
    • If integer, the return cost will be divided by batch_size.
    • If None (default), the return cost will not be divided by anything.

Examples

>>> # See the PTB example for more details:
>>> # https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_ptb_lstm_state_is_tuple.py
>>> input_data = tf.placeholder(tf.int32, [batch_size, n_steps])
>>> targets = tf.placeholder(tf.int32, [batch_size, n_steps])
>>> # build the network
>>> print(net.outputs)
... (batch_size * n_steps, n_classes)
>>> cost = tl.cost.cross_entropy_seq(net.outputs, targets)

Cross entropy with mask for sequence

tensorlayer.cost.cross_entropy_seq_with_mask(logits, target_seqs, input_mask, return_details=False, name=None)[source]

Returns the expression of cross-entropy of two sequences, applying softmax internally. Normally used for dynamic RNNs with synced sequence input and output.

Parameters:
  • logits (Tensor) – 2D tensor with shape [batch_size * ?, n_classes], where ? means a dynamic number of steps for each example. It can be obtained from DynamicRNNLayer by setting return_seq_2d to True.
  • target_seqs (Tensor) – An int tensor of e.g. word IDs with shape [batch_size, ?], where ? means a dynamic number of steps for each example.
  • input_mask (Tensor) – The mask used to compute the loss; it has the same shape as target_seqs and normally contains 0s and 1s.
  • return_details (boolean) –
    Whether to return detailed losses.
    • If False (default), only returns the loss.
    • If True, returns the loss, losses, weights and targets (see source code).

Examples

>>> batch_size = 64
>>> vocab_size = 10000
>>> embedding_size = 256
>>> input_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="input")
>>> target_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target")
>>> input_mask = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="mask")
>>> net = tl.layers.EmbeddingInputlayer(
...         inputs = input_seqs,
...         vocabulary_size = vocab_size,
...         embedding_size = embedding_size,
...         name = 'seq_embedding')
>>> net = tl.layers.DynamicRNNLayer(net,
...         cell_fn = tf.contrib.rnn.BasicLSTMCell,
...         n_hidden = embedding_size,
...         dropout = (0.7 if is_train else None),
...         sequence_length = tl.layers.retrieve_seq_length_op2(input_seqs),
...         return_seq_2d = True,
...         name = 'dynamicrnn')
>>> print(net.outputs)
... (?, 256)
>>> net = tl.layers.DenseLayer(net, n_units=vocab_size, name="output")
>>> print(net.outputs)
... (?, 10000)
>>> loss = tl.cost.cross_entropy_seq_with_mask(net.outputs, target_seqs, input_mask)

Cosine similarity

tensorlayer.cost.cosine_similarity(v1, v2)[source]

Cosine similarity [-1, 1].

Parameters:
  • v1, v2 (Tensor) – Tensors with the same shape [batch_size, n_feature].
Returns: A tensor of shape [batch_size].
Return type: Tensor
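A brief usage sketch (the shapes are illustrative):

>>> v1 = tf.placeholder(tf.float32, shape=[None, 128])
>>> v2 = tf.placeholder(tf.float32, shape=[None, 128])
>>> sim = tl.cost.cosine_similarity(v1, v2)  # shape [batch_size]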


Regularization functions

For tf.nn.l2_loss, tf.contrib.layers.l1_regularizer, tf.contrib.layers.l2_regularizer and tf.contrib.layers.sum_regularizer, see TensorFlow API.

Maxnorm

tensorlayer.cost.maxnorm_regularizer(scale=1.0)[source]

Max-norm regularization returns a function that can be used to apply max-norm regularization to weights.

For more about max-norm, see the Wikipedia entry on max norm. The implementation follows TensorFlow contrib.

Parameters: scale (float) – A scalar multiplier Tensor. 0.0 disables the regularizer.
Returns: A function with signature mn(weights, name=None) that applies max-norm regularization.
Raises: ValueError – If scale is outside of the range [0.0, 1.0] or if scale is not a float.
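A brief usage sketch, applying it to the first weight matrix of the MLP from the tutorial above (an illustration, not from the original docs):

>>> mn = tl.cost.maxnorm_regularizer(scale=1.0)
>>> cost = tl.cost.cross_entropy(y, y_) + mn(network.all_params[0])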

Special

tensorlayer.cost.li_regularizer(scale, scope=None)[source]

Li regularization removes the neurons of the previous layer; the i stands for inputs. Returns a function that can be used to apply group Li regularization to weights. The implementation follows TensorFlow contrib.

Parameters:
  • scale (float) – A scalar multiplier Tensor. 0.0 disables the regularizer.
  • scope (str) – An optional scope name for this function.
Returns: A function with signature li(weights, name=None) that applies Li regularization.
Raises: ValueError – If scale is outside of the range [0.0, 1.0] or if scale is not a float.
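A brief usage sketch (an illustration; the scale value is arbitrary):

>>> li = tl.cost.li_regularizer(scale=0.001)
>>> cost = tl.cost.cross_entropy(y, y_) + li(network.all_params[0])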

tensorlayer.cost.lo_regularizer(scale)[source]

Lo regularization removes the neurons of the current layer; the o stands for outputs. Returns a function that can be used to apply group Lo regularization to weights. The implementation follows TensorFlow contrib.

Parameters: scale (float) – A scalar multiplier Tensor. 0.0 disables the regularizer.
Returns: A function with signature lo(weights, name=None) that applies Lo regularization.
Raises: ValueError – If scale is outside of the range [0.0, 1.0] or if scale is not a float.

tensorlayer.cost.maxnorm_o_regularizer(scale)[source]

Max-norm output regularization removes the neurons of the current layer. Returns a function that can be used to apply max-norm regularization to each column of the weight matrix. The implementation follows TensorFlow contrib.

Parameters: scale (float) – A scalar multiplier Tensor. 0.0 disables the regularizer.
Returns: A function with signature mn_o(weights, name=None) that applies max-norm output regularization.
Raises: ValueError – If scale is outside of the range [0.0, 1.0] or if scale is not a float.

tensorlayer.cost.maxnorm_i_regularizer(scale)[source]

Max-norm input regularization removes the neurons of the previous layer. Returns a function that can be used to apply max-norm regularization to each row of the weight matrix. The implementation follows TensorFlow contrib.

Parameters: scale (float) – A scalar multiplier Tensor. 0.0 disables the regularizer.
Returns: A function with signature mn_i(weights, name=None) that applies max-norm input regularization.
Raises: ValueError – If scale is outside of the range [0.0, 1.0] or if scale is not a float.