API - Layers

To keep TensorLayer simple, we minimize the number of layer classes as much as we can and encourage you to use TensorFlow's own functions where possible. For example, TensorLayer does not provide a layer for local response normalization; instead, apply tf.nn.lrn to Layer.outputs. More functions can be found in the TensorFlow API.

Understand layer

All TensorLayer layers have a number of properties in common:

  • layer.outputs : Tensor, the outputs of current layer.
  • layer.all_params : a list of Tensor, all network variables in order.
  • layer.all_layers : a list of Tensor, the outputs of all layers in order.
  • layer.all_drop : a dictionary of {placeholder : float}, the keeping probabilities of all noise layers.

All TensorLayer layers have a number of methods in common:

  • layer.print_params() : print information about the network variables in order (after sess.run(tf.initialize_all_variables())). Alternatively, print all variables with tl.layers.print_all_variables().
  • layer.print_layers() : print information about the network layers in order.
  • layer.count_params() : return the number of parameters in the network.

A network starts with an input layer, and further layers are stacked on top of it as shown below; the resulting network is itself a Layer instance. Its most important properties are network.all_params, network.all_layers and network.all_drop. all_params is a list that stores references to all network parameters in order. For the 3-layer network defined in the script below, it is:

all_params = [W1, b1, W2, b2, W_out, b_out]

all_layers is a list that stores references to the outputs of all layers in order. For the same network, it is:

all_layers = [drop(?,784), relu(?,800), drop(?,800), relu(?,800), drop(?,800), identity(?,10)]

where ? reflects any batch size. You can print the layer information and parameters information by using network.print_layers() and network.print_params(). To count the number of parameters in a network, run network.count_params().

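import tensorflow as tf
import tensorlayer as tl

sess = tf.InteractiveSession()
learning_rate = 0.0001  # example value, not a recommendation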

x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
y_ = tf.placeholder(tf.int64, shape=[None, ], name='y_')

network = tl.layers.InputLayer(x, name='input_layer')
network = tl.layers.DropoutLayer(network, keep=0.8, name='drop1')
network = tl.layers.DenseLayer(network, n_units=800,
                                act = tf.nn.relu, name='relu1')
network = tl.layers.DropoutLayer(network, keep=0.5, name='drop2')
network = tl.layers.DenseLayer(network, n_units=800,
                                act = tf.nn.relu, name='relu2')
network = tl.layers.DropoutLayer(network, keep=0.5, name='drop3')
network = tl.layers.DenseLayer(network, n_units=10,
                                act = tl.activation.identity,
                                name='output_layer')

y = network.outputs
y_op = tf.argmax(tf.nn.softmax(y), 1)

cost = tl.cost.cross_entropy(y, y_)

train_params = network.all_params

train_op = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999,
                                  epsilon=1e-08, use_locking=False).minimize(
                                      cost, var_list=train_params)

sess.run(tf.initialize_all_variables())

network.print_params()
network.print_layers()
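As a rough check of what network.count_params() should report for this 784-800-800-10 network, the plain-Python sketch below sums the sizes of the six tensors in all_params. Dropout layers contribute no parameters. The helper name `count_dense_params` is hypothetical and is not part of TensorLayer; it only does the arithmetic on the shapes.

```python
def count_dense_params(layer_sizes):
    """Hypothetical helper: total weights + biases of a stack of dense layers."""
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_in * n_out  # weight matrix W, shape (n_in, n_out)
        total += n_out         # bias vector b, shape (n_out,)
    return total

# [W1, b1, W2, b2, W_out, b_out] for 784 -> 800 -> 800 -> 10
n_params = count_dense_params([784, 800, 800, 10])
print(n_params)  # 1276810
```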

In addition, network.all_drop is a dictionary that stores the keeping probabilities of all noise layers. In the network above, these are the keeping probabilities of the dropout layers.

For training, enable all dropout layers as follows:

feed_dict = {x: X_train_a, y_: y_train_a}
feed_dict.update( network.all_drop )
loss, _ = sess.run([cost, train_op], feed_dict=feed_dict)

For evaluation and testing, disable all dropout layers as follows:

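import numpy as np

# Set every keeping probability to 1, which disables dropout.
dp_dict = tl.utils.dict_to_one(network.all_drop)
feed_dict = {x: X_val, y_: y_val}
feed_dict.update(dp_dict)
print("   val loss: %f" % sess.run(cost, feed_dict=feed_dict))
print("   val acc: %f" % np.mean(y_val ==
                        sess.run(y_op, feed_dict=feed_dict)))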
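The only difference between the training and evaluation feeds is the value given to each dropout placeholder. The plain-Python sketch below mimics this, with strings standing in for the placeholder tensors and input arrays. The function `dict_to_one_sketch` is a hypothetical stand-in for TensorLayer's tl.utils.dict_to_one.

```python
def dict_to_one_sketch(dp_dict):
    """Return a copy of dp_dict with every keeping probability set to 1."""
    return {placeholder: 1.0 for placeholder in dp_dict}

# Hypothetical stand-ins for the placeholders of drop1, drop2 and drop3.
all_drop = {"drop1_keep": 0.8, "drop2_keep": 0.5, "drop3_keep": 0.5}

train_feed = {"x": "X_train_a", "y_": "y_train_a"}
train_feed.update(all_drop)                     # training: dropout enabled

eval_feed = {"x": "X_val", "y_": "y_val"}
eval_feed.update(dict_to_one_sketch(all_drop))  # evaluation: dropout disabled
```

Because the sketch builds a new dictionary, network.all_drop itself keeps its training values and can be reused for the next training step.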

For more details, please read the MNIST examples.

Creating custom layers

Understand Dense layer

Before creating your own TensorLayer layer, let's look at the Dense layer. It creates a weight matrix and a bias vector if they do not exist yet, then implements the output expression. Finally, because it is a layer with parameters, it appends those parameters to all_params.

import tensorflow as tf
from tensorlayer.layers import Layer

class DenseLayer(Layer):
    """
    The :class:`DenseLayer` class is a fully connected layer.

    Parameters
    ----------
    layer : a :class:`Layer` instance
        The `Layer` class feeding into this layer.
    n_units : int
        The number of units of the layer.
    act : activation function
        The function that is applied to the layer activations.
    W_init : weights initializer
        The initializer for initializing the weight matrix.
    b_init : biases initializer
        The initializer for initializing the bias vector.
    W_init_args : dictionary
        The arguments for the weights tf.get_variable.
    b_init_args : dictionary
        The arguments for the biases tf.get_variable.
    name : a string or None
        An optional name to attach to this layer.
    """
    def __init__(
        self,
        layer = None,
        n_units = 100,
        act = tf.nn.relu,
        W_init = tf.truncated_normal_initializer(stddev=0.1),
        b_init = tf.constant_initializer(value=0.0),
        W_init_args = {},
        b_init_args = {},
        name ='dense_layer',
    ):
        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        if self.inputs.get_shape().ndims != 2:
            raise Exception("The input dimension must be rank 2")
        n_in = int(self.inputs.get_shape()[-1])
        self.n_units = n_units
        print("  tensorlayer:Instantiate DenseLayer %s: %d, %s" % (self.name, self.n_units, act))
        with tf.variable_scope(name) as vs:
            W = tf.get_variable(name='W', shape=(n_in, n_units), initializer=W_init, **W_init_args )
            b = tf.get_variable(name='b', shape=(n_units,), initializer=b_init, **b_init_args )
            self.outputs = act(tf.matmul(self.inputs, W) + b)

        # Hint: list() and dict() make shallow copies, so extending them below
        # does not modify the previous layer's lists.
        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        self.all_layers.extend( [self.outputs] )
        self.all_params.extend( [W, b] )
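Numerically, the output expression of DenseLayer is just act(inputs · W + b). The NumPy sketch below reproduces that computation outside TensorFlow. It assumes a ReLU activation and uses a plain normal draw scaled by 0.1 as a rough stand-in for tf.truncated_normal_initializer(stddev=0.1); the function name `dense_forward` is hypothetical.

```python
import numpy as np

def dense_forward(inputs, W, b, act=lambda z: np.maximum(z, 0.0)):
    """NumPy version of DenseLayer's output: act(inputs @ W + b)."""
    if inputs.ndim != 2:
        raise ValueError("The input dimension must be rank 2")
    return act(inputs @ W + b)

rng = np.random.default_rng(0)
n_in, n_units = 784, 800
x = rng.standard_normal((32, n_in)).astype(np.float32)
# Rough stand-in for truncated_normal(stddev=0.1): a plain normal draw.
W = (0.1 * rng.standard_normal((n_in, n_units))).astype(np.float32)
b = np.zeros(n_units, dtype=np.float32)  # like constant_initializer(0.0)
y = dense_forward(x, W, b)
print(y.shape)  # (32, 800)
```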

A simple layer

To implement a custom layer in TensorLayer, write a Python class that subclasses Layer and implements the output expression.

The following is an example implementation of a layer that multiplies its input by 2:

from tensorlayer.layers import Layer

class DoubleLayer(Layer):
    def __init__(
        self,
        layer = None,
        name ='double_layer',
    ):
        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        self.outputs = self.inputs * 2

        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        self.all_layers.extend( [self.outputs] )
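The copying of all_layers, all_params and all_drop is what lets each layer act as a handle to the whole network without corrupting the layers below it. The sketch below reproduces only that bookkeeping in plain Python. `MiniLayer`, `MiniInput` and `MiniDouble` are hypothetical stand-ins for Layer, InputLayer and DoubleLayer, and Python lists replace tensors.

```python
class MiniLayer:
    """Hypothetical stand-in for tl.layers.Layer (bookkeeping only)."""
    def __init__(self, name):
        self.name = name

class MiniInput(MiniLayer):
    def __init__(self, outputs, name="input"):
        super().__init__(name)
        self.outputs = outputs
        self.all_layers, self.all_params, self.all_drop = [], [], {}

class MiniDouble(MiniLayer):
    def __init__(self, layer, name="double"):
        super().__init__(name)
        self.inputs = layer.outputs
        self.outputs = [2 * v for v in self.inputs]
        # Shallow copies: extending them leaves `layer` untouched.
        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        self.all_layers.append(self.outputs)

net0 = MiniInput([1, 2, 3])
net1 = MiniDouble(net0, name="double1")
net2 = MiniDouble(net1, name="double2")
print(net2.outputs)  # [4, 8, 12]
```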

Modifying Pre-train Behaviour

Greedy layer-wise pre-training is an important step in initializing deep neural networks, and there are many kinds of pre-training methods for different network architectures and applications.

For example, pre-training a vanilla sparse autoencoder can be implemented with a KL-divergence sparsity penalty (for sigmoid units), as in the following code. For a deep rectifier network, the sparsity can instead be implemented with an L1 regularization of the activation outputs.

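# Vanilla Sparse Autoencoder
# activation_out: the hidden-layer activations, shape [batch_size, n_units]
beta = 4      # weight of the sparsity penalty
rho = 0.15    # target average activation
p_hat = tf.reduce_mean(activation_out, reduction_indices=0)  # average activation per unit
KLD = beta * tf.reduce_sum(rho * tf.log(tf.div(rho, p_hat))
        + (1 - rho) * tf.log((1 - rho) / (1.0 - p_hat)))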
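The two sparsity penalties above can be evaluated outside TensorFlow to see how they behave. The NumPy sketch below computes the same KL term and one common form of the L1 activation penalty. The function names and the L1 coefficient `lam` are hypothetical choices for illustration, not TensorLayer defaults.

```python
import numpy as np

def kl_sparsity(activation_out, rho=0.15, beta=4.0):
    """beta * sum_j KL(rho || p_hat_j), where p_hat_j is the mean activation of unit j."""
    p_hat = activation_out.mean(axis=0)
    return beta * np.sum(rho * np.log(rho / p_hat)
                         + (1 - rho) * np.log((1 - rho) / (1 - p_hat)))

def l1_activation_sparsity(activation_out, lam=1e-3):
    """L1 penalty on activations (lam is an assumed value), averaged over the batch."""
    return lam * np.mean(np.sum(np.abs(activation_out), axis=1))

# The KL term vanishes when every unit's average activation equals rho.
print(kl_sparsity(np.full((10, 5), 0.15)))
```

The KL penalty is only defined for activations strictly between 0 and 1, which is why it suits sigmoid units. The L1 form works for unbounded rectifier outputs.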

Because there are many pre-training methods, TensorLayer provides a simple way to modify them or to design your own. For autoencoders, TensorLayer uses ReconLayer.__init__() to define the reconstruction layer and its cost function. To use your own cost function, modify self.cost in ReconLayer.__init__(). To create your own cost expression, please read TensorFlow Math. By default, ReconLayer only updates the weights and biases of the previous layer, using self.train_params = self.all_params[-4:]. The 4 parameters are [W_encoder, b_encoder, W_decoder, b_decoder]: W_encoder and b_encoder belong to the previous DenseLayer, and W_decoder and b_decoder belong to this ReconLayer. If you want to update the parameters of the previous 2 layers at the same time, simply change [-4:] to [-6:].

ReconLayer.__init__(...):
    ...
    self.train_params = self.all_params[-4:]
    ...
    self.cost = mse + L1_a + L2_w

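Which parameters a given slice selects can be checked with plain list slicing. The sketch below uses strings in place of the variables. It assumes a hypothetical network InputLayer, then DenseLayer (W1, b1), then DenseLayer (W2, b2), then ReconLayer (W_decoder, b_decoder).

```python
# Hypothetical stand-in for ReconLayer.all_params (names only).
all_params = ["W1", "b1", "W2", "b2", "W_decoder", "b_decoder"]

train_last_layer = all_params[-4:]  # default: previous DenseLayer + decoder
train_last_two = all_params[-6:]    # also update the layer before that
print(train_last_layer)  # ['W2', 'b2', 'W_decoder', 'b_decoder']
```
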
Layer([inputs, name]) The Layer class represents a single layer of a neural network.
InputLayer([inputs, n_features, name]) The InputLayer class is the starting layer of a neural network.
Word2vecEmbeddingInputlayer([inputs, …]) The Word2vecEmbeddingInputlayer class is a fully connected layer, for Word Embedding.
EmbeddingInputlayer([inputs, …]) The EmbeddingInputlayer class is a fully connected layer, for Word Embedding.
DenseLayer([layer, n_units, act, W_init, …]) The DenseLayer class is a fully connected layer.
ReconLayer([layer, x_recon, name, n_units, act]) The ReconLayer class is a reconstruction layer (a DenseLayer) used to pre-train a DenseLayer.
DropoutLayer([layer, keep, name]) The DropoutLayer class is a noise layer which randomly sets some values to zero according to a given keeping probability.
DropconnectDenseLayer([layer, keep, …]) The DropconnectDenseLayer class is a DenseLayer with DropConnect behaviour, which randomly removes connections between this layer and the previous layer according to a given keeping probability.
Conv2dLayer([layer, act, shape, strides, …]) The Conv2dLayer class is a 2D CNN layer, see tf.nn.conv2d.
Conv3dLayer([layer, act, shape, strides, …]) The Conv3dLayer class is a 3D CNN layer, see tf.nn.conv3d.
DeConv3dLayer([layer, act, shape, …]) The DeConv3dLayer class is a 3D deconvolutional layer, see tf.nn.conv3d_transpose.
PoolLayer([layer, ksize, strides, padding, …]) The PoolLayer class is a pooling layer; you can choose tf.nn.max_pool and tf.nn.avg_pool for 2D or tf.nn.max_pool3d() and tf.nn.avg_pool3d() for 3D.
RNNLayer([layer, cell_fn, cell_init_args, …]) The RNNLayer class is an RNN layer; you can implement vanilla RNN, LSTM and GRU with it.
FlattenLayer([layer, name]) The FlattenLayer class is a layer which reshapes a high-dimensional input to a vector.
ConcatLayer([layer, concat_dim, name]) The ConcatLayer class is a layer which concatenates (merges) two or more DenseLayer outputs into a single DenseLayer.
ReshapeLayer([layer, shape, name]) The ReshapeLayer class is a layer which reshapes the tensor.
SlimNetsLayer([layer, slim_layer, …]) The SlimNetsLayer class can be used to merge all TF-Slim nets into TensorLayer.
MultiplexerLayer([layer, name]) The MultiplexerLayer selects one of several inputs and forwards the selected input to the output, see tutorial_mnist_multiplexer.py.
EmbeddingAttentionSeq2seqWrapper(…[, …]) Sequence-to-sequence model with attention, for multiple buckets.
flatten_reshape(variable[, name]) Reshapes a high-dimensional input to a vector.
clear_layers_name() Clear all layer names in set_keep[‘_layers_name_list’] to enable layer name reuse.
set_name_reuse([enable]) Enable or disable reuse layer name.
print_all_variables([train_only]) Print all trainable and non-trainable variables without initialize_all_variables()
initialize_rnn_state(state) Return the initialized RNN state.

Basic layer

class tensorlayer.layers.Layer(inputs=None, name='layer')

The Layer class represents a single layer of a neural network. It should be subclassed when implementing new types of layers. Because each layer can keep track of the layer(s) feeding into it, a network’s output Layer instance can double as a handle to the full network.

Parameters:
inputs : a Layer instance

The Layer class feeding into this layer.

name : a string or None

An optional name to attach to this layer.

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Input layer

class tensorlayer.layers.InputLayer(inputs=None, n_features=None, name='input_layer')

The InputLayer class is the starting layer of a neural network.

Parameters:
inputs : a TensorFlow placeholder

The input tensor data.

name : a string or None

An optional name to attach to this layer.

n_features : a int

The number of features. If not specified, the input is assumed to have the shape [batch_size, n_features], and its second dimension is used as n_features. It is used to determine the matrix size of the next layer. If a convolutional layer follows the InputLayer, n_features is not important.

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Word Embedding Input layer

Word2vec layer for training

class tensorlayer.layers.Word2vecEmbeddingInputlayer(inputs=None, train_labels=None, vocabulary_size=80000, embedding_size=200, num_sampled=64, nce_loss_args={}, E_init=<tensorflow.python.ops.init_ops.RandomUniform object>, E_init_args={}, nce_W_init=<tensorflow.python.ops.init_ops.TruncatedNormal object>, nce_W_init_args={}, nce_b_init=<tensorflow.python.ops.init_ops.Constant object>, nce_b_init_args={}, name='word2vec_layer')

The Word2vecEmbeddingInputlayer class is a fully connected layer for word embedding. Words are input as integer indices. The output is the embedded word vector.

Parameters:
inputs : placeholder

Word inputs, in integer index format.

train_labels : placeholder

Word labels, in integer index format.

vocabulary_size : int

The size of vocabulary, number of words.

embedding_size : int

The number of embedding dimensions.

num_sampled : int

The number of negative examples for the NCE loss.

nce_loss_args : a dictionary

The arguments for tf.nn.nce_loss()

E_init : embedding initializer

The initializer for initializing the embedding matrix.

E_init_args : a dictionary

The arguments for embedding initializer

nce_W_init : NCE decoder weights initializer

The initializer for initializing the nce decoder weight matrix.

nce_W_init_args : a dictionary

The arguments for initializing the nce decoder weight matrix.

nce_b_init : NCE decoder biases initializer

The initializer for tf.get_variable() of the nce decoder bias vector.

nce_b_init_args : a dictionary

The arguments for tf.get_variable() of the nce decoder bias vector.

name : a string or None

An optional name to attach to this layer.

References

tensorflow/examples/tutorials/word2vec/word2vec_basic.py

Examples

>>> # Without TensorLayer: see tensorflow/examples/tutorials/word2vec/word2vec_basic.py
>>> train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
>>> train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
>>> embeddings = tf.Variable(
...     tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
>>> embed = tf.nn.embedding_lookup(embeddings, train_inputs)
>>> nce_weights = tf.Variable(
...     tf.truncated_normal([vocabulary_size, embedding_size],
...                    stddev=1.0 / math.sqrt(embedding_size)))
>>> nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
>>> cost = tf.reduce_mean(
...    tf.nn.nce_loss(weights=nce_weights, biases=nce_biases,
...               inputs=embed, labels=train_labels,
...               num_sampled=num_sampled, num_classes=vocabulary_size,
...               num_true=1))
>>> # With TensorLayer: see tutorial_word2vec_basic.py
>>> train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
>>> train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
>>> emb_net = tl.layers.Word2vecEmbeddingInputlayer(
...         inputs = train_inputs,
...         train_labels = train_labels,
...         vocabulary_size = vocabulary_size,
...         embedding_size = embedding_size,
...         num_sampled = num_sampled,
...         nce_loss_args = {},
...         E_init = tf.random_uniform_initializer(minval=-1.0, maxval=1.0),
...         E_init_args = {},
...         nce_W_init = tf.truncated_normal_initializer(
...                          stddev=float(1.0/np.sqrt(embedding_size))),
...         nce_W_init_args = {},
...         nce_b_init = tf.constant_initializer(value=0.0),
...         nce_b_init_args = {},
...        name ='word2vec_layer',
...    )
>>> cost = emb_net.nce_cost
>>> train_params = emb_net.all_params
>>> train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(
...                                             cost, var_list=train_params)
>>> normalized_embeddings = emb_net.normalized_embeddings
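The normalized embeddings are typically used to find similar words. The NumPy sketch below assumes that normalized_embeddings is computed as in word2vec_basic.py, where each row of the embedding matrix is divided by its L2 norm. With unit-length rows, cosine similarity reduces to a dot product. The random matrix and the choice of 8 neighbours are illustrative only.

```python
import numpy as np

def normalize_rows(embeddings):
    """Divide each embedding (row) by its L2 norm."""
    norm = np.sqrt(np.sum(np.square(embeddings), axis=1, keepdims=True))
    return embeddings / norm

rng = np.random.default_rng(0)
vocabulary_size, embedding_size = 1000, 200
embeddings = rng.uniform(-1.0, 1.0, size=(vocabulary_size, embedding_size))
normalized_embeddings = normalize_rows(embeddings)

# Cosine similarity between word 0 and every word is now a dot product.
similarity = normalized_embeddings @ normalized_embeddings[0]
nearest = np.argsort(-similarity)[1:9]  # 8 nearest neighbours, skipping word 0
```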

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Embedding Input layer

class tensorlayer.layers.EmbeddingInputlayer(inputs=None, vocabulary_size=80000, embedding_size=200, E_init=<tensorflow.python.ops.init_ops.RandomUniform object>, E_init_args={}, name='embedding_layer')

The EmbeddingInputlayer class is a fully connected layer for word embedding. Words are input as integer indices. The output is the embedded word vector.

This class cannot be used to train a word embedding matrix, so you should assign a trained matrix to it. To train a word embedding matrix, use Word2vecEmbeddingInputlayer.

Note: do not update this embedding matrix.

Parameters:
inputs : placeholder

Word inputs, in integer index format: a 2D tensor of shape [batch_size, num_steps (num_words)].

vocabulary_size : int

The size of vocabulary, number of words.

embedding_size : int

The number of embedding dimensions.

E_init : embedding initializer

The initializer for initializing the embedding matrix.

E_init_args : a dictionary

The arguments for embedding initializer

name : a string or None

An optional name to attach to this layer.

Examples

>>> vocabulary_size = 50000
>>> embedding_size = 200
>>> model_file_name = "model_word2vec_50k_200"
>>> batch_size = None
...
>>> all_var = tl.files.load_npy_to_any(name=model_file_name+'.npy')
>>> data = all_var['data']; count = all_var['count']
>>> dictionary = all_var['dictionary']
>>> reverse_dictionary = all_var['reverse_dictionary']
>>> tl.files.save_vocab(count, name='vocab_'+model_file_name+'.txt')
>>> del all_var, data, count
...
>>> load_params = tl.files.load_npz(name=model_file_name+'.npz')
>>> x = tf.placeholder(tf.int32, shape=[batch_size])
>>> y_ = tf.placeholder(tf.int32, shape=[batch_size, 1])
>>> emb_net = tl.layers.EmbeddingInputlayer(
...                inputs = x,
...                vocabulary_size = vocabulary_size,
...                embedding_size = embedding_size,
...                name ='embedding_layer')
>>> sess.run(tf.initialize_all_variables())
>>> tl.files.assign_params(sess, [load_params[0]], emb_net)
>>> word = b'hello'
>>> word_id = dictionary[word]
>>> print('word_id:', word_id)
word_id: 6428
...
>>> words = [b'i', b'am', b'hao', b'dong']
>>> word_ids = tl.files.words_to_word_ids(words, dictionary)
>>> context = tl.files.word_ids_to_words(word_ids, reverse_dictionary)
>>> print('word_ids:', word_ids)
word_ids: [72, 1226, 46744, 20048]
>>> print('context:', context)
context: [b'i', b'am', b'hao', b'dong']
...
>>> vector = sess.run(emb_net.outputs, feed_dict={x : [word_id]})
>>> print('vector:', vector.shape)
vector: (1, 200)
>>> vectors = sess.run(emb_net.outputs, feed_dict={x : word_ids})
>>> print('vectors:', vectors.shape)
vectors: (4, 200)
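An embedding lookup simply selects rows of the embedding matrix, which explains the (1, 200) and (4, 200) shapes above. The NumPy sketch below shows the same idea. The random matrix is a stand-in for the trained matrix that would be assigned with tl.files.assign_params, and the row indexing plays the role of tf.nn.embedding_lookup.

```python
import numpy as np

vocabulary_size, embedding_size = 50000, 200
rng = np.random.default_rng(0)
# Stand-in for a trained embedding matrix, values in [-1, 1).
E = rng.random((vocabulary_size, embedding_size), dtype=np.float32) * 2 - 1

word_ids = [72, 1226, 46744, 20048]
vectors = E[word_ids]  # one row per word id, like an embedding lookup
vector = E[[6428]]     # a single word still yields a 2D (1, 200) result
print(vector.shape, vectors.shape)  # (1, 200) (4, 200)
```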

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Dense layer

Dense layer

class tensorlayer.layers.DenseLayer(layer=None, n_units=100, act=<function relu>, W_init=<tensorflow.python.ops.init_ops.TruncatedNormal object>, b_init=<tensorflow.python.ops.init_ops.Constant object>, W_init_args={}, b_init_args={}, name='dense_layer')

The DenseLayer class is a fully connected layer.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

n_units : int

The number of units of the layer.

act : activation function

The function that is applied to the layer activations.

W_init : weights initializer

The initializer for initializing the weight matrix.

b_init : biases initializer

The initializer for initializing the bias vector.

W_init_args : dictionary

The arguments for the weights tf.get_variable.

b_init_args : dictionary

The arguments for the biases tf.get_variable.

name : a string or None

An optional name to attach to this layer.

Examples

>>> network = tl.layers.InputLayer(x, name='input_layer')
>>> network = tl.layers.DenseLayer(
...                 network,
...                 n_units=800,
...                 act = tf.nn.relu,
...                 W_init=tf.truncated_normal_initializer(stddev=0.1),
...                 name ='relu_layer'
...                 )
>>> # Without TensorLayer, you can do as follows.
>>> W = tf.Variable(
...     tf.truncated_normal([n_in, n_units], stddev=0.1), name='W')
>>> b = tf.Variable(tf.zeros(shape=[n_units]), name='b')
>>> y = tf.nn.relu(tf.matmul(inputs, W) + b)

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Reconstruction layer for Autoencoder

class tensorlayer.layers.ReconLayer(layer=None, x_recon=None, name='recon_layer', n_units=784, act=<function softplus>)

The ReconLayer class is a reconstruction layer (a DenseLayer) used to pre-train a DenseLayer.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

x_recon : tensorflow variable

The variables used for reconstruction.

name : a string or None

An optional name to attach to this layer.

n_units : int

The number of units of the layer; this should equal the dimension of x_recon.

act : activation function

The activation function that is applied to the reconstruction layer. Normally, for a sigmoid layer the reconstruction activation is sigmoid; for a rectifying layer it is softplus.

Notes

The input layer should be a DenseLayer or a layer with only one (non-batch) axis. You may need to modify this part to define your own cost function. By default, the cost is implemented as follows:

For a sigmoid layer, the implementation follows UFLDL.

For a rectifying layer, the implementation follows Glorot et al. (2011), Deep Sparse Rectifier Neural Networks.
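The default cost combines a reconstruction error with sparsity and weight-decay terms (mse + L1_a + L2_w in ReconLayer.__init__()). The NumPy sketch below shows the general shape of such a cost. Its coefficient values and exact reductions are assumptions for illustration, not TensorLayer's actual settings, and `recon_cost_sketch` is a hypothetical name.

```python
import numpy as np

def recon_cost_sketch(x, x_recon, hidden_act, W, l1_coef=0.001, l2_coef=0.004):
    """mse + L1_a + L2_w; the coefficient values are placeholders, not TensorLayer's."""
    mse = np.mean(np.sum(np.square(x - x_recon), axis=1))        # squared error per example, averaged
    L1_a = l1_coef * np.mean(np.sum(np.abs(hidden_act), axis=1))  # sparsity of hidden activations
    L2_w = l2_coef * np.sum(np.square(W))                         # weight decay on the encoder weights
    return mse + L1_a + L2_w
```

A perfect reconstruction with silent hidden units and zero weights gives a cost of exactly 0. Each term can be removed or reweighted independently when you design your own cost.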

Examples

>>> network = tl.layers.InputLayer(x, name='input_layer')
>>> network = tl.layers.DenseLayer(network, n_units=196,
...                                 act=tf.nn.sigmoid, name='sigmoid1')
>>> recon_layer1 = tl.layers.ReconLayer(network, x_recon=x, n_units=784,
...                                 act=tf.nn.sigmoid, name='recon_layer1')
>>> recon_layer1.pretrain(sess, x=x, X_train=X_train, X_val=X_val,
...                         denoise_name=None, n_epoch=1200, batch_size=128,
...                         print_freq=10, save=True, save_name='w1pre_')

Methods

pretrain(self, sess, x, X_train, X_val, denoise_name=None, n_epoch=100, batch_size=128, print_freq=10, save=True, save_name='w1pre_') Start pre-training the parameters of the previous DenseLayer.

Noise layer

Dropout layer

class tensorlayer.layers.DropoutLayer(layer=None, keep=0.5, name='dropout_layer')

The DropoutLayer class is a noise layer which randomly sets some values to zero according to a given keeping probability.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

keep : float

The keeping probability; the lower it is, the more values are set to zero.

name : a string or None

An optional name to attach to this layer.

Examples

>>> network = tl.layers.InputLayer(x, name='input_layer')
>>> network = tl.layers.DropoutLayer(network, keep=0.8, name='drop1')
>>> network = tl.layers.DenseLayer(network, n_units=800, act = tf.nn.relu, name='relu1')
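The sketch below shows in NumPy what a dropout layer does to its input. It assumes DropoutLayer wraps tf.nn.dropout, which zeroes each value with probability 1 - keep and scales the surviving values by 1/keep so that the expected output equals the input. `dropout_sketch` is a hypothetical name. Feeding keep = 1.0 returns the input unchanged, which is why evaluation feeds set every keeping probability to 1.

```python
import numpy as np

def dropout_sketch(x, keep, rng):
    """Zero each value with probability 1 - keep and scale survivors by 1/keep."""
    mask = rng.random(x.shape) < keep
    return np.where(mask, x / keep, 0.0), mask

rng = np.random.default_rng(0)
x = np.ones((1000, 800))
y, mask = dropout_sketch(x, keep=0.8, rng=rng)
print(mask.mean())  # close to 0.8
```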

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Dropconnect + Dense layer

class tensorlayer.layers.DropconnectDenseLayer(layer=None, keep=0.5, n_units=100, act=<function relu>, W_init=<tensorflow.python.ops.init_ops.TruncatedNormal object>, b_init=<tensorflow.python.ops.init_ops.Constant object>, W_init_args={}, b_init_args={}, name='dropconnect_layer')

The DropconnectDenseLayer class is a DenseLayer with DropConnect behaviour, which randomly removes connections between this layer and the previous layer according to a given keeping probability.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

keep : float

The keeping probability; the lower it is, the more connections are removed.

n_units : int

The number of units of the layer.

act : activation function

The function that is applied to the layer activations.

W_init : weights initializer

The initializer for initializing the weight matrix.

b_init : biases initializer

The initializer for initializing the bias vector.

W_init_args : dictionary

The arguments for the weights tf.get_variable().

b_init_args : dictionary

The arguments for the biases tf.get_variable().

name : a string or None

An optional name to attach to this layer.

References

Wan, L. (2013). Regularization of neural networks using dropconnect

Examples

>>> network = tl.layers.InputLayer(x, name='input_layer')
>>> network = tl.layers.DropconnectDenseLayer(network, keep = 0.8,
...         n_units=800, act = tf.nn.relu, name='dropconnect_relu1')
>>> network = tl.layers.DropconnectDenseLayer(network, keep = 0.5,
...         n_units=800, act = tf.nn.relu, name='dropconnect_relu2')
>>> network = tl.layers.DropconnectDenseLayer(network, keep = 0.5,
...         n_units=10, act = tl.activation.identity, name='output_layer')
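DropConnect differs from dropout in what gets masked. Dropout zeroes activations, whereas DropConnect zeroes individual entries of the weight matrix, so each output unit loses a different random subset of its inputs. The NumPy sketch below assumes the weights are masked and rescaled in the same way tf.nn.dropout treats values. `dropconnect_dense` is a hypothetical name, and ReLU is the assumed activation.

```python
import numpy as np

def dropconnect_dense(x, W, b, keep, rng, act=lambda z: np.maximum(z, 0.0)):
    """Dense layer whose weights (connections) are randomly dropped."""
    mask = rng.random(W.shape) < keep           # one keep/drop decision per weight
    W_dropped = np.where(mask, W / keep, 0.0)   # rescale kept weights (assumed)
    return act(x @ W_dropped + b)

rng = np.random.default_rng(1)
x = rng.standard_normal((4, 6))
W = rng.standard_normal((6, 3))
b = np.zeros(3)
out = dropconnect_dense(x, W, b, keep=0.5, rng=rng)
print(out.shape)  # (4, 3)
```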

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Convolutional layer

1D Convolutional layer

TensorLayer does not provide a 1D CNN layer. Because TensorFlow only provides tf.nn.conv2d and tf.nn.conv3d, you can implement a 1D CNN with a reshape layer as follows:

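import tensorflow as tf
import tensorlayer as tl

x = tf.placeholder(tf.float32, shape=[None, 500], name='x')
network = tl.layers.InputLayer(x, name='input_layer')  # layers take a Layer, not a raw tensor
network = tl.layers.ReshapeLayer(network, shape=[-1, 500, 1, 1], name='reshape')
network = tl.layers.Conv2dLayer(network,
                    act = tf.nn.relu,
                    shape = [10, 1, 1, 16], # 16 features
                    strides=[1, 2, 1, 1],   # stride of 2
                    padding='SAME',
                    name = 'cnn')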
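To predict the output shape of the 1D CNN above, apply TensorFlow's padding rules along the time axis. With SAME padding, the output length is ceil(n_in / stride). With VALID padding, it is ceil((n_in - filter_len + 1) / stride). The plain-Python sketch below applies these rules; `conv_out_len` is a hypothetical helper name.

```python
import math

def conv_out_len(n_in, filter_len, stride, padding):
    """Output length along one axis, following TensorFlow's SAME/VALID rules."""
    if padding == "SAME":
        return math.ceil(n_in / stride)
    if padding == "VALID":
        return math.ceil((n_in - filter_len + 1) / stride)
    raise ValueError("padding must be 'SAME' or 'VALID'")

# 1D CNN above: 500 steps, filter length 10, stride 2, 16 filters.
length = conv_out_len(500, 10, 2, "SAME")  # 250 -> outputs (?, 250, 1, 16)
n_params = 10 * 1 * 1 * 16 + 16           # W plus b = 176
```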

2D Convolutional layer

class tensorlayer.layers.Conv2dLayer(layer=None, act=<function relu>, shape=[5, 5, 1, 100], strides=[1, 1, 1, 1], padding='SAME', W_init=<tensorflow.python.ops.init_ops.TruncatedNormal object>, b_init=<tensorflow.python.ops.init_ops.Constant object>, W_init_args={}, b_init_args={}, name='cnn_layer')

The Conv2dLayer class is a 2D CNN layer, see tf.nn.conv2d.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

act : activation function

The function that is applied to the layer activations.

shape : list of shape

Shape of the filters: [filter_height, filter_width, in_channels, out_channels].

strides : a list of ints.

The stride of the sliding window for each dimension of the input. It must be in the same order as the dimensions specified by the data format.

padding : a string from: “SAME”, “VALID”.

The type of padding algorithm to use.

W_init : weights initializer

The initializer for initializing the weight matrix.

b_init : biases initializer

The initializer for initializing the bias vector.

W_init_args : dictionary

The arguments for the weights tf.get_variable().

b_init_args : dictionary

The arguments for the biases tf.get_variable().

name : a string or None

An optional name to attach to this layer.

Examples

>>> x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
>>> network = tl.layers.InputLayer(x, name='input_layer')
>>> network = tl.layers.Conv2dLayer(network,
...                   act = tf.nn.relu,
...                   shape = [5, 5, 1, 32],  # 32 features for each 5x5 patch
...                   strides=[1, 1, 1, 1],
...                   padding='SAME',
...                   W_init=tf.truncated_normal_initializer(stddev=5e-2),
...                   W_init_args={},
...                   b_init = tf.constant_initializer(value=0.0),
...                   b_init_args = {},
...                   name ='cnn_layer1')     # output: (?, 28, 28, 32)
>>> network = tl.layers.PoolLayer(network,
...                   ksize=[1, 2, 2, 1],
...                   strides=[1, 2, 2, 1],
...                   padding='SAME',
...                   pool = tf.nn.max_pool,
...                   name ='pool_layer1',)   # output: (?, 14, 14, 32)
>>> # Without TensorLayer, you can initialize the parameters as follows.
>>> W_init = tf.truncated_normal_initializer(stddev=5e-2)
>>> b_init = tf.constant_initializer(value=0.0)
>>> W = tf.Variable(W_init(shape=[5, 5, 1, 32]), name='W_conv')
>>> b = tf.Variable(b_init(shape=[32]), name='b_conv')
>>> outputs = tf.nn.relu(tf.nn.conv2d(inputs, W,
...                       strides=[1, 1, 1, 1],
...                       padding='SAME') + b)

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

2D Deconvolutional layer

3D Convolutional layer

class tensorlayer.layers.Conv3dLayer(layer=None, act=<function relu>, shape=[], strides=[], padding='SAME', W_init=<tensorflow.python.ops.init_ops.TruncatedNormal object>, b_init=<tensorflow.python.ops.init_ops.Constant object>, W_init_args={}, b_init_args={}, name='cnn3d_layer')

The Conv3dLayer class is a 3D CNN layer, see tf.nn.conv3d.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

act : activation function

The function that is applied to the layer activations.

shape : list of shape

Shape of the filters: [filter_depth, filter_height, filter_width, in_channels, out_channels].

strides : a list of ints, 1-D of length 5.

The stride of the sliding window for each dimension of input. Must be in the same order as the dimension specified with format.

padding : a string from: “SAME”, “VALID”.

The type of padding algorithm to use.

W_init : weights initializer

The initializer for initializing the weight matrix.

b_init : biases initializer

The initializer for initializing the bias vector.

W_init_args : dictionary

The arguments for the weights initializer.

b_init_args : dictionary

The arguments for the biases initializer.

name : a string or None

An optional name to attach to this layer.

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

3D Deconvolutional layer

class tensorlayer.layers.DeConv3dLayer(layer=None, act=<function relu>, shape=[2, 2, 2, 512, 1024], output_shape=[None, 50, 50, 50, 512], strides=[1, 2, 2, 2, 1], padding='SAME', W_init=<tensorflow.python.ops.init_ops.TruncatedNormal object>, b_init=<tensorflow.python.ops.init_ops.Constant object>, W_init_args={}, b_init_args={}, name='decnn_layer')

The DeConv3dLayer class is a 3D deconvolutional layer, see tf.nn.conv3d_transpose.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

act : activation function

The function that is applied to the layer activations.

shape : list of shape

Shape of the filters: [depth, height, width, output_channels, in_channels]. The filter's in_channels dimension must match that of the input.

output_shape : list of output shape

The output shape of the deconvolution op.

strides : a list of ints.

The stride of the sliding window for each dimension of the input tensor.

padding : a string from: “SAME”, “VALID”.

The type of padding algorithm to use.

W_init : weights initializer

The initializer for initializing the weight matrix.

b_init : biases initializer

The initializer for initializing the bias vector.

W_init_args : dictionary

The arguments for the weights initializer.

b_init_args : dictionary

The arguments for the biases initializer.

name : a string or None

An optional name to attach to this layer.
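Choosing output_shape is the usual stumbling block with this layer. As we understand tf.nn.conv3d_transpose, a valid output_shape is one that a forward convolution with the same filter, strides and padding would map back onto the input shape. The plain-Python sketch below checks that condition along one spatial axis and computes the conventional choice; it illustrates the transposed-convolution shape rule and is not TensorLayer code (the helper names are hypothetical). With the defaults (filter size 2, stride 2, SAME, output size 50), the input must have size 25 along each spatial axis.

```python
def _ceil_div(a, b):
    """Integer ceiling division for a positive divisor b."""
    return -(-a // b)


def forward_conv_size(size, filter_size, stride, padding):
    """Size produced by a forward convolution along one axis (TensorFlow rules)."""
    if padding == "SAME":
        return _ceil_div(size, stride)
    if padding == "VALID":
        return _ceil_div(size - filter_size + 1, stride)
    raise ValueError("padding must be 'SAME' or 'VALID'")


def deconv_output_size(in_size, filter_size, stride, padding):
    """Conventional transposed-convolution output size along one axis."""
    if padding == "SAME":
        return in_size * stride
    if padding == "VALID":
        return (in_size - 1) * stride + filter_size
    raise ValueError("padding must be 'SAME' or 'VALID'")


def is_compatible(in_size, out_size, filter_size, stride, padding):
    """True if out_size maps back onto in_size under the forward convolution."""
    return forward_conv_size(out_size, filter_size, stride, padding) == in_size


print(deconv_output_size(25, 2, 2, "SAME"))   # 50
print(is_compatible(25, 50, 2, 2, "SAME"))    # True
print(is_compatible(25, 49, 2, 2, "SAME"))    # True, since ceil(49 / 2) == 25
print(is_compatible(25, 52, 2, 2, "SAME"))    # False
print(deconv_output_size(25, 2, 2, "VALID"))  # (25 - 1) * 2 + 2 = 50
```

Note that with SAME padding more than one output size can be compatible, which is why the layer asks for output_shape explicitly.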

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Pooling layer

Max or mean pooling layer for any number of dimensions

class tensorlayer.layers.PoolLayer(layer=None, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', pool=<function max_pool>, name='pool_layer')

The PoolLayer class is a pooling layer; you can choose tf.nn.max_pool or tf.nn.avg_pool for 2D inputs, or tf.nn.max_pool3d or tf.nn.avg_pool3d for 3D inputs.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

ksize : a list of ints that has length >= 4.

The size of the window for each dimension of the input tensor.

strides : a list of ints that has length >= 4.

The stride of the sliding window for each dimension of the input tensor.

padding : a string from: “SAME”, “VALID”.

The type of padding algorithm to use.

pool : a pooling function

tf.nn.max_pool, tf.nn.avg_pool, tf.nn.max_pool3d, tf.nn.avg_pool3d, …

name : a string or None

An optional name to attach to this layer.

References

TensorFlow Pooling

Examples

see Conv2dLayer
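The pooled output shape follows the same SAME/VALID rules as convolution, applied dimension by dimension. The plain-Python sketch below (a hypothetical helper, not part of TensorLayer) computes it from ksize and strides; with a 2x2 window and stride 2, a 28x28 map becomes 14x14.

```python
def _ceil_div(a, b):
    """Integer ceiling division for a positive divisor b."""
    return -(-a // b)


def pool_output_shape(input_shape, ksize, strides, padding="SAME"):
    """Output shape of a pooling op, computed dimension by dimension.

    input_shape, ksize and strides have the same length (4 for 2D pooling,
    5 for 3D pooling). None (an unknown batch size) is passed through.
    """
    out = []
    for size, k, s in zip(input_shape, ksize, strides):
        if size is None:
            out.append(None)
        elif padding == "SAME":
            out.append(_ceil_div(size, s))
        elif padding == "VALID":
            out.append(_ceil_div(size - k + 1, s))
        else:
            raise ValueError("padding must be 'SAME' or 'VALID'")
    return out


print(pool_output_shape([None, 28, 28, 32], [1, 2, 2, 1], [1, 2, 2, 1]))
# [None, 14, 14, 32]
print(pool_output_shape([None, 7, 7, 32], [1, 2, 2, 1], [1, 2, 2, 1], "VALID"))
# [None, 3, 3, 32]
```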

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Recurrent layer

Recurrent layer for any cell (LSTM, GRU, etc.)

class tensorlayer.layers.RNNLayer(layer=None, cell_fn=<class 'tensorflow.python.ops.rnn_cell_impl.BasicRNNCell'>, cell_init_args={}, n_hidden=100, initializer=<tensorflow.python.ops.init_ops.RandomUniform object>, n_steps=5, return_last=False, return_seq_2d=False, name='rnn_layer')

The RNNLayer class is an RNN layer; you can implement vanilla RNN, LSTM and GRU with it.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

cell_fn : a TensorFlow core RNN cell class, such as the following.

see RNN Cells in TensorFlow

class tf.nn.rnn_cell.BasicRNNCell

class tf.nn.rnn_cell.BasicLSTMCell

class tf.nn.rnn_cell.GRUCell

class tf.nn.rnn_cell.LSTMCell

cell_init_args : a dictionary

The arguments for the cell initializer.

n_hidden : an int

The number of hidden units in the layer.

n_steps : an int

The sequence length.

return_last : boolean

If True, return the last output, “Sequence input and single output”.

If False, return all outputs, “Synced sequence input and output”.

In other words, if you want to stack one or more RNNs on top of this layer, set it to False.

return_seq_2d : boolean

Only used when return_last = False.

If True, return a 2D Tensor [n_example, n_hidden], where n_example = batch_size * n_steps, for stacking a DenseLayer after it. If False, return a 3D Tensor [n_example/n_steps, n_steps, n_hidden], i.e. [batch_size, n_steps, n_hidden], for stacking multiple RNNs after it.

name : a string or None

An optional name to attach to this layer.
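To make the two return_seq_2d layouts concrete, the plain-Python sketch below rearranges per-step outputs (n_steps items, each [batch_size][n_hidden]) into either layout using nested lists. We assume the layer groups outputs example by example, which is consistent with the shapes above, but this is an illustration rather than the layer's code, and the function name is hypothetical. Row b * n_steps + t of the 2D result holds step t of example b, the same order obtained by reshaping a [batch_size, n_steps] target matrix to [-1].

```python
def rnn_outputs_layout(step_outputs, return_seq_2d):
    """Rearrange per-step RNN outputs.

    step_outputs: list of n_steps items, each a [batch_size][n_hidden] list.
    Returns a 2D list [batch_size * n_steps][n_hidden] if return_seq_2d,
    otherwise a 3D list [batch_size][n_steps][n_hidden].
    """
    n_steps = len(step_outputs)
    batch_size = len(step_outputs[0])
    # Group by example first, then by time step.
    per_example = [[step_outputs[t][b] for t in range(n_steps)]
                   for b in range(batch_size)]
    if return_seq_2d:
        # Flatten the example and step axes into one row axis.
        return [row for example in per_example for row in example]
    return per_example


# batch_size = 2, n_steps = 3, n_hidden = 1; each value names (example, step).
steps = [[["b0t0"], ["b1t0"]],
         [["b0t1"], ["b1t1"]],
         [["b0t2"], ["b1t2"]]]
flat = rnn_outputs_layout(steps, return_seq_2d=True)
print(len(flat), flat[4])       # 6 ['b1t1']  (row 1 * 3 + 1)
cube = rnn_outputs_layout(steps, return_seq_2d=False)
print(len(cube), len(cube[0]))  # 2 3
```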

Notes

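The input to this layer should be a 3D tensor [batch_size, n_steps, n_features]. If the input has more axes (e.g. CNN feature maps), flatten it with FlattenLayer and reshape it with ReshapeLayer, as in the CNN+LSTM example below.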

References

Neural Network RNN Cells in TensorFlow

tensorflow/python/ops/rnn.py

tensorflow/python/ops/rnn_cell.py

see TensorFlow tutorial ptb_word_lm.py, TensorLayer tutorials tutorial_ptb_lstm.py and tutorial_generate_text.py

Examples

>>> # For words
>>> input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
>>> network = tl.layers.EmbeddingInputlayer(
...                 inputs = input_data,
...                 vocabulary_size = vocab_size,
...                 embedding_size = hidden_size,
...                 E_init = tf.random_uniform_initializer(-init_scale, init_scale),
...                 name ='embedding_layer')
>>> if is_training:
...     network = tl.layers.DropoutLayer(network, keep=keep_prob, name='drop1')
>>> network = tl.layers.RNNLayer(network,
...             cell_fn=tf.nn.rnn_cell.BasicLSTMCell,
...             cell_init_args={'forget_bias': 0.0},# 'state_is_tuple': True},
...             n_hidden=hidden_size,
...             initializer=tf.random_uniform_initializer(-init_scale, init_scale),
...             n_steps=num_steps,
...             return_last=False,
...             name='basic_lstm_layer1')
>>> lstm1 = network
>>> if is_training:
...     network = tl.layers.DropoutLayer(network, keep=keep_prob, name='drop2')
>>> network = tl.layers.RNNLayer(network,
...             cell_fn=tf.nn.rnn_cell.BasicLSTMCell,
...             cell_init_args={'forget_bias': 0.0}, # 'state_is_tuple': True},
...             n_hidden=hidden_size,
...             initializer=tf.random_uniform_initializer(-init_scale, init_scale),
...             n_steps=num_steps,
...             return_last=False,
...             return_seq_2d=True,
...             name='basic_lstm_layer2')
>>> lstm2 = network
>>> if is_training:
...     network = tl.layers.DropoutLayer(network, keep=keep_prob, name='drop3')
>>> network = tl.layers.DenseLayer(network,
...             n_units=vocab_size,
...             W_init=tf.random_uniform_initializer(-init_scale, init_scale),
...             b_init=tf.random_uniform_initializer(-init_scale, init_scale),
...             act = tl.activation.identity, name='output_layer')
>>> # For CNN+LSTM
>>> x = tf.placeholder(tf.float32, shape=[batch_size, image_size, image_size, 1])
>>> network = tl.layers.InputLayer(x, name='input_layer')
>>> network = tl.layers.Conv2dLayer(network,
...                         act = tf.nn.relu,
...                         shape = [5, 5, 1, 32],  # 32 features for each 5x5 patch
...                         strides=[1, 2, 2, 1],
...                         padding='SAME',
...                         name ='cnn_layer1')
>>> network = tl.layers.PoolLayer(network,
...                         ksize=[1, 2, 2, 1],
...                         strides=[1, 2, 2, 1],
...                         padding='SAME',
...                         pool = tf.nn.max_pool,
...                         name ='pool_layer1')
>>> network = tl.layers.Conv2dLayer(network,
...                         act = tf.nn.relu,
...                         shape = [5, 5, 32, 10], # 10 features for each 5x5 patch
...                         strides=[1, 2, 2, 1],
...                         padding='SAME',
...                         name ='cnn_layer2')
>>> network = tl.layers.PoolLayer(network,
...                         ksize=[1, 2, 2, 1],
...                         strides=[1, 2, 2, 1],
...                         padding='SAME',
...                         pool = tf.nn.max_pool,
...                         name ='pool_layer2')
>>> network = tl.layers.FlattenLayer(network, name='flatten_layer')
>>> network = tl.layers.ReshapeLayer(network, shape=[-1, num_steps, int(network.outputs.get_shape()[-1])])
>>> rnn1 = tl.layers.RNNLayer(network,
...                         cell_fn=tf.nn.rnn_cell.LSTMCell,
...                         cell_init_args={},
...                         n_hidden=200,
...                         initializer=tf.random_uniform_initializer(-0.1, 0.1),
...                         n_steps=num_steps,
...                         return_last=False,
...                         return_seq_2d=True,
...                         name='rnn_layer')
>>> network = tl.layers.DenseLayer(rnn1, n_units=3,
...                         act = tl.activation.identity, name='output_layer')

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Shape layer

Flatten layer

class tensorlayer.layers.FlattenLayer(layer=None, name='flatten_layer')

The FlattenLayer class is a layer that reshapes a high-dimensional input into a vector for each example. Then we can apply DenseLayer, RNNLayer, ConcatLayer, etc. on top of it.

[batch_size, mask_row, mask_col, n_mask] -> [batch_size, mask_row * mask_col * n_mask]

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

name : a string or None

An optional name to attach to this layer.

Examples

>>> x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
>>> network = tl.layers.InputLayer(x, name='input_layer')
>>> network = tl.layers.Conv2dLayer(network,
...                    act = tf.nn.relu,
...                    shape = [5, 5, 1, 64],  # in_channels must match the 1-channel input
...                    strides=[1, 1, 1, 1],
...                    padding='SAME',
...                    name ='cnn_layer')
>>> network = tl.layers.PoolLayer(network,
...                    ksize=[1, 2, 2, 1],
...                    strides=[1, 2, 2, 1],
...                    padding='SAME',
...                    pool = tf.nn.max_pool,
...                    name ='pool_layer')
>>> network = tl.layers.FlattenLayer(network, name='flatten_layer')
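FlattenLayer keeps the batch axis and multiplies the remaining dimensions together. The plain-Python sketch below (a hypothetical helper that assumes every non-batch dimension is known) computes the resulting shape. In the example above, the SAME convolution with stride 1 keeps 28x28, the 2x2 pool with stride 2 gives 14x14, and flattening 64 channels yields 14 * 14 * 64 = 12544 features.

```python
import operator
from functools import reduce


def flatten_shape(shape):
    """[batch_size, d1, d2, ...] -> [batch_size, d1 * d2 * ...]."""
    batch_size, *rest = shape
    return [batch_size, reduce(operator.mul, rest, 1)]


print(flatten_shape([None, 14, 14, 64]))  # [None, 12544]
```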

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Concat layer

class tensorlayer.layers.ConcatLayer(layer=[], concat_dim=1, name='concat_layer')

The ConcatLayer class is a layer that concatenates (merges) the outputs of two or more layers, e.g. DenseLayers, into a single layer.

Parameters:
layer : a list of Layer instances

The Layer instances feeding into this layer.

concat_dim : int

Dimension along which to concatenate.

name : a string or None

An optional name to attach to this layer.

Examples

>>> sess = tf.InteractiveSession()
>>> x = tf.placeholder(tf.float32, shape=[None, 784])
>>> inputs = tl.layers.InputLayer(x, name='input_layer')
>>> net1 = tl.layers.DenseLayer(inputs, n_units=800, act = tf.nn.relu, name='relu1_1')
>>> net2 = tl.layers.DenseLayer(inputs, n_units=300, act = tf.nn.relu, name='relu2_1')
>>> network = tl.layers.ConcatLayer(layer = [net1, net2], name ='concat_layer')
...     tensorlayer:Instantiate InputLayer input_layer (?, 784)
...     tensorlayer:Instantiate DenseLayer relu1_1: 800, <function relu at 0x1108e41e0>
...     tensorlayer:Instantiate DenseLayer relu2_1: 300, <function relu at 0x1108e41e0>
...     tensorlayer:Instantiate ConcatLayer concat_layer, 1100
...
>>> sess.run(tf.initialize_all_variables())
>>> network.print_params()
...     param 0: (784, 800) (mean: 0.000021, median: -0.000020 std: 0.035525)
...     param 1: (800,) (mean: 0.000000, median: 0.000000 std: 0.000000)
...     param 2: (784, 300) (mean: 0.000000, median: -0.000048 std: 0.042947)
...     param 3: (300,) (mean: 0.000000, median: 0.000000 std: 0.000000)
...     num of params: 863500
>>> network.print_layers()
...     layer 0: Tensor("Relu:0", shape=(?, 800), dtype=float32)
...     layer 1: Tensor("Relu_1:0", shape=(?, 300), dtype=float32)
...
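The numbers in the log above can be checked by hand: concatenating along concat_dim=1 adds the widths of the inputs (800 + 300 = 1100), and the four parameter tensors of the two DenseLayers add up to the reported total. A small plain-Python check:

```python
# Output width of ConcatLayer along concat_dim=1.
widths = [800, 300]
print(sum(widths))  # 1100

# Weights and biases of the two DenseLayers, in the order printed above.
param_shapes = [(784, 800), (800,), (784, 300), (300,)]


def n_elements(shape):
    """Number of scalars in a tensor of the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count


print(sum(n_elements(s) for s in param_shapes))  # 863500
```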

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Reshape layer

class tensorlayer.layers.ReshapeLayer(layer=None, shape=[], name='reshape_layer')

The ReshapeLayer class is a layer that reshapes the tensor.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

shape : a list

The output shape. As in tf.reshape, at most one dimension can be -1, in which case it is inferred.

name : a string or None

An optional name to attach to this layer.

Examples

The core of this layer is tf.reshape. Using TensorFlow only:

>>> x = tf.placeholder(tf.float32, shape=[None, 3])
>>> y = tf.reshape(x, shape=[-1, 3, 3])
>>> sess = tf.InteractiveSession()
>>> print(sess.run(y, feed_dict={x:[[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[6,6,6]]}))
... [[[ 1.  1.  1.]
... [ 2.  2.  2.]
... [ 3.  3.  3.]]
... [[ 4.  4.  4.]
... [ 5.  5.  5.]
... [ 6.  6.  6.]]]
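tf.reshape reads the elements in row-major order and refills them into the new shape, inferring at most one -1 dimension from the total element count. The plain-Python sketch below (a hypothetical helper for illustration; shape must be non-empty) reproduces the example above on nested lists.

```python
def _flatten(x, out):
    """Append the scalars of a nested list to out in row-major order."""
    if isinstance(x, list):
        for item in x:
            _flatten(item, out)
    else:
        out.append(x)
    return out


def reshape(nested, shape):
    """Row-major reshape of a nested list; at most one entry of shape may be -1."""
    flat = _flatten(nested, [])
    shape = list(shape)
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be -1")
    known = 1
    for dim in shape:
        if dim != -1:
            known *= dim
    if -1 in shape:
        if known == 0 or len(flat) % known:
            raise ValueError("cannot infer the -1 dimension")
        shape[shape.index(-1)] = len(flat) // known
    elif known != len(flat):
        raise ValueError("shape does not match the number of elements")

    def build(items, dims):
        # Split items into dims[0] equal chunks, recursively.
        if len(dims) == 1:
            return items
        if dims[0] == 0:
            return []
        step = len(items) // dims[0]
        return [build(items[i * step:(i + 1) * step], dims[1:])
                for i in range(dims[0])]

    return build(flat, shape)


x = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]
print(reshape(x, [-1, 3, 3]))
# [[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]
```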

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Merge TF-Slim

Yes! TF-Slim models can be merged into TensorLayer, so all of Google's pre-trained models can be used easily; see Slim-model.

class tensorlayer.layers.SlimNetsLayer(layer=None, slim_layer=None, slim_args={}, name='slim_layer')

The SlimNetsLayer class can be used to merge any TF-Slim net into TensorLayer. Models can be found in slim-model; for more about TF-Slim, see slim-git.

Parameters:
layer : a Layer instance

The Layer class feeding into this layer.

slim_layer : a slim network function

The network you want to stack on top of layer; the function must end with return net, end_points.

slim_args : dictionary

The arguments for the slim network function.

name : a string or None

An optional name to attach to this layer.

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Flow control layer

class tensorlayer.layers.MultiplexerLayer(layer=[], name='mux_layer')

The MultiplexerLayer selects one of several inputs and forwards the selected input to the output; see tutorial_mnist_multiplexer.py.

Parameters:
layer : a list of Layer instances

The Layer class feeding into this layer.

name : a string or None

An optional name to attach to this layer.

References

See tf.pack() and tf.gather() at TensorFlow - Slicing and Joining

Examples

>>> x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
>>> y_ = tf.placeholder(tf.int64, shape=[None, ], name='y_')
>>> # define the network
>>> net_in = tl.layers.InputLayer(x, name='input_layer')
>>> net_in = tl.layers.DropoutLayer(net_in, keep=0.8, name='drop1')
>>> # net 0
>>> net_0 = tl.layers.DenseLayer(net_in, n_units=800,
...                                act = tf.nn.relu, name='net0/relu1')
>>> net_0 = tl.layers.DropoutLayer(net_0, keep=0.5, name='net0/drop2')
>>> net_0 = tl.layers.DenseLayer(net_0, n_units=800,
...                                act = tf.nn.relu, name='net0/relu2')
>>> # net 1
>>> net_1 = tl.layers.DenseLayer(net_in, n_units=800,
...                                act = tf.nn.relu, name='net1/relu1')
>>> net_1 = tl.layers.DropoutLayer(net_1, keep=0.8, name='net1/drop2')
>>> net_1 = tl.layers.DenseLayer(net_1, n_units=800,
...                                act = tf.nn.relu, name='net1/relu2')
>>> net_1 = tl.layers.DropoutLayer(net_1, keep=0.8, name='net1/drop3')
>>> net_1 = tl.layers.DenseLayer(net_1, n_units=800,
...                                act = tf.nn.relu, name='net1/relu3')
>>> # multiplexer
>>> net_mux = tl.layers.MultiplexerLayer(layer = [net_0, net_1], name='mux_layer')
>>> network = tl.layers.ReshapeLayer(net_mux, shape=[-1, 800], name='reshape_layer')
>>> network = tl.layers.DropoutLayer(network, keep=0.5, name='drop3')
>>> # output layer
>>> network = tl.layers.DenseLayer(network, n_units=10,
...                                act = tf.identity, name='output_layer')
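Following the references above, the multiplexer can be pictured as stacking the candidate outputs (tf.pack) and picking one of them by index (tf.gather). The plain-Python sketch below mirrors that selection on nested lists; the function name is hypothetical and the selection index is passed as an ordinary argument for illustration, whereas in TensorFlow it would be a tensor. Stacking requires all candidates to have the same shape, which is why net_0 and net_1 both end with 800 units.

```python
def multiplex(candidate_outputs, sel):
    """Mimic tf.pack followed by tf.gather with a scalar index.

    candidate_outputs: list of n_inputs outputs, each [batch_size][n_units].
    sel: index of the output to forward.
    """
    # Stacking requires every candidate to have the same shape.
    shapes = {(len(out), len(out[0])) for out in candidate_outputs}
    if len(shapes) != 1:
        raise ValueError("all inputs must have the same shape to be stacked")
    stacked = list(candidate_outputs)  # [n_inputs][batch_size][n_units]
    if not 0 <= sel < len(stacked):
        raise IndexError("sel must be in range(n_inputs)")
    return stacked[sel]


# Hypothetical outputs of two sub-networks: batch_size = 2, n_units = 2.
net_0_out = [[0.1, 0.2], [0.3, 0.4]]
net_1_out = [[1.0, 2.0], [3.0, 4.0]]
print(multiplex([net_0_out, net_1_out], sel=1))  # [[1.0, 2.0], [3.0, 4.0]]
```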

Methods

count_params() Return the number of parameters in the network
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network

Wrapper

Embedding + Attention + Seq2seq

class tensorlayer.layers.EmbeddingAttentionSeq2seqWrapper(source_vocab_size, target_vocab_size, buckets, size, num_layers, max_gradient_norm, batch_size, learning_rate, learning_rate_decay_factor, use_lstm=False, num_samples=512, forward_only=False, name='wrapper')

Sequence-to-sequence model with attention and for multiple buckets.

This class implements a multi-layer recurrent neural network as encoder, and an attention-based decoder. This is the same as the model described in this paper:

“Grammar as a Foreign Language” http://arxiv.org/abs/1412.7449

Please look there for details, or into the seq2seq library for the complete model implementation. This class also allows the use of GRU cells in addition to LSTM cells, and sampled softmax to handle a large output vocabulary. A single-layer version of this model, but with a bi-directional encoder, was presented in:

“Neural Machine Translation by Jointly Learning to Align and Translate” http://arxiv.org/abs/1409.0473

The sampled softmax is described in Section 3 of the following paper:

“On Using Very Large Target Vocabulary for Neural Machine Translation” http://arxiv.org/abs/1412.2007

Parameters:
source_vocab_size : size of the source vocabulary.
target_vocab_size : size of the target vocabulary.
buckets : a list of pairs (I, O), where I specifies the maximum input length that will be processed in that bucket, and O specifies the maximum output length. Training instances that have inputs longer than I or outputs longer than O will be pushed to the next bucket and padded accordingly. We assume that the list is sorted, e.g., [(2, 4), (8, 16)].

size : number of units in each layer of the model.
num_layers : number of layers in the model.
max_gradient_norm : gradients will be clipped to maximally this norm.
batch_size : the size of the batches used during training; the model construction is independent of batch_size, so it can be changed after initialization if this is convenient, e.g., for decoding.

learning_rate : learning rate to start with.
learning_rate_decay_factor : decay learning rate by this much when needed.
use_lstm : if True, use LSTM cells instead of GRU cells.
num_samples : number of samples for sampled softmax.
forward_only : if set, we do not construct the backward pass in the model.
name : a string or None

An optional name to attach to this layer.
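The bucketing rule described under buckets can be sketched as follows: each (input, output) pair goes to the first bucket whose limits are large enough, and is then padded to that bucket's size. This is a plain-Python illustration with hypothetical helper names; comparing lengths with <= (rather than leaving room for extra symbols such as GO or EOS) is an assumption made here, and the padding value 0 mirrors the PAD_ID default of get_batch.

```python
def pick_bucket(buckets, input_len, output_len):
    """Index of the first bucket (I, O) with input_len <= I and output_len <= O.

    Assumes buckets are sorted by size, as required above.
    Returns None if the pair is too long for every bucket.
    """
    for bucket_id, (max_in, max_out) in enumerate(buckets):
        if input_len <= max_in and output_len <= max_out:
            return bucket_id
    return None


def pad(ids, length, pad_id=0):
    """Right-pad a list of token ids to the given length."""
    return ids + [pad_id] * (length - len(ids))


buckets = [(2, 4), (8, 16)]
source, target = [5, 6, 7], [8, 9]
bucket_id = pick_bucket(buckets, len(source), len(target))
print(bucket_id)  # 1: the input is longer than 2, so it moves to the next bucket
max_in, max_out = buckets[bucket_id]
print(pad(source, max_in))  # [5, 6, 7, 0, 0, 0, 0, 0]
```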

Methods

count_params() Return the number of parameters in the network
get_batch(data, bucket_id[, PAD_ID, GO_ID, …]) Get a random batch of data from the specified bucket, prepare for step.
print_layers() Print all info of layers in the network
print_params([details]) Print all info of parameters in the network
step(session, encoder_inputs, …) Run a step of the model feeding the given inputs.

get_batch(data, bucket_id, PAD_ID=0, GO_ID=1, EOS_ID=2, UNK_ID=3)

Get a random batch of data from the specified bucket, prepare for step.

To feed data into step(…), it must be a list of batch-major vectors, while data here contains single length-major cases. So the main logic of this function is to re-index data cases into the proper format for feeding.

Parameters:
data : a tuple of size len(self.buckets) in which each element contains lists of pairs of input and output data that we use to create a batch.

bucket_id : integer, which bucket to get the batch for.
PAD_ID : int

Index of the padding symbol in the vocabulary.

GO_ID : int

Index of the GO (start-of-decoding) symbol in the vocabulary.

EOS_ID : int

Index of the end-of-sentence symbol in the vocabulary.

UNK_ID : int

Index of the unknown-word symbol in the vocabulary.

Returns:
The triple (encoder_inputs, decoder_inputs, target_weights) for the constructed batch that has the proper format to call step(…) later.
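The re-indexing described above turns a list of per-example sequences into a list of per-time-step vectors, each holding one entry per example in the batch. The plain-Python sketch below shows only that transposition plus padding; the function name is hypothetical, and it omits other steps a full implementation may perform (for example, reversing the encoder inputs or prepending GO_ID to the decoder inputs).

```python
def to_time_major(sequences, length, pad_id=0):
    """sequences: list of batch_size token-id lists (one per example).

    Returns a list of `length` vectors; vector t holds token t of every
    example, padded with pad_id where an example is shorter.
    """
    padded = [seq + [pad_id] * (length - len(seq)) for seq in sequences]
    return [[padded[b][t] for b in range(len(padded))] for t in range(length)]


batch = [[5, 6, 7], [8, 9]]  # two examples of different lengths
print(to_time_major(batch, length=4))
# [[5, 8], [6, 9], [7, 0], [0, 0]]
```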

step(session, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only)

Run a step of the model feeding the given inputs.

Parameters:
session : tensorflow session to use.
encoder_inputs : list of numpy int vectors to feed as encoder inputs.
decoder_inputs : list of numpy int vectors to feed as decoder inputs.
target_weights : list of numpy float vectors to feed as target weights.
bucket_id : which bucket of the model to use.
forward_only : whether to do the backward step or only forward.
Returns:
A triple consisting of the gradient norm (or None if we did not do backward), average perplexity, and the outputs.
Raises:
ValueError : if the length of encoder_inputs, decoder_inputs, or target_weights disagrees with the bucket size for the specified bucket_id.

Helper functions

tensorlayer.layers.flatten_reshape(variable, name='')

Reshapes a high-dimensional input to a vector per example: [batch_size, mask_row, mask_col, n_mask] -> [batch_size, mask_row * mask_col * n_mask]

Parameters:
variable : a TensorFlow variable or tensor
name : a string or None

An optional name to attach to this layer.

Examples

>>> # weight_variable, bias_variable, conv2d and max_pool_2x2 are helpers as in TensorFlow's Deep MNIST tutorial
>>> W_conv2 = weight_variable([5, 5, 100, 32])   # 32 features for each 5x5 patch
>>> b_conv2 = bias_variable([32])
>>> W_fc1 = weight_variable([7 * 7 * 32, 256])
>>> h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
>>> h_pool2 = max_pool_2x2(h_conv2)
>>> h_pool2.get_shape().as_list()
...         [batch_size, 7, 7, 32]       # [batch_size, mask_row, mask_col, n_mask]
>>> h_pool2_flat = tl.layers.flatten_reshape(h_pool2)
>>> h_pool2_flat.get_shape().as_list()
...         [batch_size, 7 * 7 * 32]     # [batch_size, mask_row * mask_col * n_mask]
>>> h_pool2_flat_drop = tf.nn.dropout(h_pool2_flat, keep_prob)
...

tensorlayer.layers.clear_layers_name()

Clear all layer names in set_keep['_layers_name_list'] so that layer names can be reused.

Examples

>>> network = tl.layers.InputLayer(x, name='input_layer')
>>> network = tl.layers.DenseLayer(network, n_units=800, name='relu1')
...
>>> tl.layers.clear_layers_name()
>>> network2 = tl.layers.InputLayer(x, name='input_layer')
>>> network2 = tl.layers.DenseLayer(network2, n_units=800, name='relu1')
...

tensorlayer.layers.set_name_reuse(enable=True)

Enable or disable layer name reuse. By default, each layer must have a unique name. When you want two or more input placeholders (e.g. one for inference) to share the same model parameters, you need to enable layer name reuse, which allows the parameters to have the same name scope.

Examples

see tutorial_ptb_lstm.py for example.
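The name bookkeeping behind clear_layers_name() and set_name_reuse() can be pictured as a registry of used names that rejects duplicates unless reuse is enabled. The class below is a hypothetical, simplified model written for illustration only; it is not TensorLayer's implementation, which keeps the names in set_keep['_layers_name_list'].

```python
class LayerNameRegistry:
    """Hypothetical, simplified model of layer-name bookkeeping."""

    def __init__(self):
        self.names = []
        self.reuse = False

    def register(self, name):
        # Duplicate names are rejected unless reuse is enabled.
        if name in self.names and not self.reuse:
            raise ValueError("Layer '%s' already exists" % name)
        self.names.append(name)

    def clear(self):
        # Analogue of clear_layers_name(): forget all recorded names.
        self.names = []

    def set_reuse(self, enable=True):
        # Analogue of set_name_reuse(enable).
        self.reuse = enable


registry = LayerNameRegistry()
registry.register("input_layer")
registry.register("relu1")
registry.clear()
registry.register("input_layer")  # allowed again after clear()
registry.set_reuse(True)
registry.register("input_layer")  # allowed because reuse is enabled
```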

tensorlayer.layers.print_all_variables(train_only=False)

Print all trainable and non-trainable variables without calling initialize_all_variables().

Parameters:
train_only : boolean

If True, only print the trainable variables, otherwise, print all variables.

tensorlayer.layers.initialize_rnn_state(state)

Return the initialized RNN state. The input is an LSTMStateTuple or the state of another RNN cell.