API - Distributed Training

(Alpha release - usage might change later)

Helper API to run distributed training. Check these examples.

Trainer(training_dataset, …[, batch_size, …])

Trainer for neural networks in a distributed environment.

Distributed training


tensorlayer.distributed.Trainer(training_dataset, build_training_func, optimizer, optimizer_args, batch_size=32, prefetch_size=None, checkpoint_dir=None, scaling_learning_rate=True, log_step_size=1, validation_dataset=None, build_validation_func=None, max_iteration=inf)

Trainer for neural networks in a distributed environment.

TensorLayer Trainer is a high-level training interface built on top of TensorFlow MonitoredSession and Horovod. It transparently scales the training of a TensorLayer model from a single GPU to multiple GPUs, which can be placed on different machines in a single cluster.

To run the trainer, you will need to install Horovod on your machine. Check the installation script at tensorlayer/scripts/download_and_install_openmpi3_ubuntu.sh

The minimal inputs to the Trainer are (1) a training dataset defined using the TensorFlow Dataset API, (2) a model build function that takes the inputs of the training dataset and returns the neural network to train, the loss function to minimize, and a dictionary of named tensors to log during training, and (3) an optimizer and its arguments.

The default parameter choices of the Trainer are inspired by the Facebook paper Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour.
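A central idea of that paper is the linear scaling rule: when the effective mini-batch grows by a factor of k, multiply the learning rate by k. The sketch below is a minimal illustration. It assumes the Trainer applies batch_size per GPU and multiplies learning_rate by the number of GPUs when scaling_learning_rate is True. The helper names effective_batch_size and scaled_learning_rate are hypothetical and are not part of TensorLayer.

```python
def effective_batch_size(batch_size: int, num_gpus: int) -> int:
    """Samples consumed by one synchronous step when every GPU processes `batch_size` samples."""
    if batch_size < 1 or num_gpus < 1:
        raise ValueError("batch_size and num_gpus must be positive")
    return batch_size * num_gpus


def scaled_learning_rate(learning_rate: float, num_gpus: int,
                         scaling_learning_rate: bool = True) -> float:
    """Linear scaling rule (assumed behaviour): multiply the base learning rate by the number of GPUs."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be positive")
    return learning_rate * num_gpus if scaling_learning_rate else learning_rate


if __name__ == "__main__":
    # Assumed setup: 8 GPUs, batch_size=32 per GPU, optimizer_args={'learning_rate': 0.1}.
    print(effective_batch_size(32, 8))            # 256
    print(scaled_learning_rate(0.1, 8))           # 0.8
    print(scaled_learning_rate(0.1, 8, False))    # 0.1
```

The intuition is that with 8 GPUs each step consumes 8 times as many samples, so an epoch contains 8 times fewer updates. The larger learning rate lets each update move the weights proportionally further.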

Parameters:

  • training_dataset (class TensorFlow Dataset) – The training dataset which zips samples and labels. The trainer automatically shards the training dataset based on the number of GPUs.

  • build_training_func (function) – A function that builds the training operator. It takes the training dataset as an input, and returns the neural network, the loss function and a dictionary that maps string tags to tensors to log during training.

  • optimizer (class TensorFlow Optimizer) – The optimizer that minimizes the loss function. The trainer automatically scales the learning rate linearly with the number of GPUs.

  • optimizer_args (dict) – The optimizer argument dictionary. It must contain a learning_rate field of type float. Note that, by default, the learning rate is scaled linearly with the number of GPUs; you can disable this with the scaling_learning_rate option.

  • batch_size (int) – The training mini-batch size (i.e., number of samples per batch).

  • prefetch_size (int or None) – The dataset prefetch buffer size. Set this parameter to overlap the GPU training and data preparation if the data preparation is heavy.

  • checkpoint_dir (None or str) – The path to the TensorFlow model checkpoint. Note that only the trainer master checkpoints the model. If None, checkpointing is disabled.

  • log_step_size (int) – The trainer logs training information every N mini-batches (i.e., step size).

  • validation_dataset (None or class TensorFlow Dataset) – The optional validation dataset that zips samples and labels. Note that only the trainer master needs to run validation.

  • build_validation_func (None or function) – The function that builds the validation operator. It returns the validation neural network (which shares the weights of the training network) and an arbitrary number of validation metrics.

  • scaling_learning_rate (Boolean) – Linearly scale the learning rate by the number of GPUs. Default is True. This linear scaling rule is generally effective and is highly recommended by practitioners. See Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour.

  • max_iteration (int) – The maximum number of iterations (i.e., mini-batches) to train. The default is math.inf. Set it to a small number to end training early, which is typically done for testing purposes.
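This pure-Python sketch shows how sharding of training_dataset, batch_size, log_step_size and max_iteration interact on a single worker. It makes three assumptions:

  • sharding is round-robin with the semantics of tf.data.Dataset.shard(num_shards, index);
  • the final partial batch is kept, which is the tf.data batch default;
  • a log line is emitted whenever global_step is a multiple of log_step_size.

The real Trainer drives this with TensorFlow hooks, whose exact schedule may differ. The functions shard and run_worker are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def shard(samples: Sequence, num_shards: int, index: int) -> list:
    """Round-robin sharding: keep element i when i % num_shards == index
    (the semantics of tf.data.Dataset.shard)."""
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return [s for i, s in enumerate(samples) if i % num_shards == index]


def run_worker(samples: Sequence, num_shards: int, index: int,
               batch_size: int = 32, log_step_size: int = 1,
               max_iteration: float = math.inf) -> Tuple[int, List[int]]:
    """Hypothetical per-worker loop. Returns the final step count and the
    steps at which a log line would be emitted."""
    local = shard(samples, num_shards, index)
    global_step = 0
    logged_steps: List[int] = []
    for start in range(0, len(local), batch_size):
        if global_step >= max_iteration:
            break  # stop early, as max_iteration does
        batch = local[start:start + batch_size]  # the last batch may be smaller
        # A real trainer would run one optimizer step on `batch` here.
        global_step += 1
        if global_step % log_step_size == 0:
            logged_steps.append(global_step)
    return global_step, logged_steps


if __name__ == "__main__":
    data = list(range(10))
    print(shard(data, 2, 1))                                      # [1, 3, 5, 7, 9]
    print(run_worker(data, 2, 1, batch_size=2, log_step_size=2))  # (3, [2])
    print(run_worker(data, 2, 1, batch_size=2, max_iteration=2))  # (2, [1, 2])
```

Each worker only ever sees its own disjoint slice of the data. Adding GPUs therefore reduces the number of steps each worker takes per epoch.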


Attributes:

  • training_network (class TensorLayer Layer) – The training model.

  • session (class TensorFlow MonitoredTrainingSession) – The training session that the Trainer wraps.

  • global_step (int) – The number of training mini-batches run so far.

  • validation_metrics (list of tuples) – The validation metrics. Each tuple pairs a validation metric with its average value.
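The validation metrics pair each metric with its average over the validation set. This plain-Python sketch shows one plausible way to build such a list. It assumes an unweighted mean over validation batches. Metric names stand in for the metric tensors returned by build_validation_func. average_validation_metrics is a hypothetical helper, not a TensorLayer function.

```python
from typing import Iterable, List, Sequence, Tuple


def average_validation_metrics(metric_names: Sequence[str],
                               per_batch_values: Iterable[Sequence[float]]
                               ) -> List[Tuple[str, float]]:
    """Average each metric over all validation batches (unweighted mean)
    and zip the result with the metric it belongs to."""
    sums = [0.0] * len(metric_names)
    num_batches = 0
    for values in per_batch_values:
        if len(values) != len(metric_names):
            raise ValueError("each batch must report one value per metric")
        for i, value in enumerate(values):
            sums[i] += value
        num_batches += 1
    if num_batches == 0:
        raise ValueError("the validation dataset produced no batches")
    return list(zip(metric_names, (s / num_batches for s in sums)))


if __name__ == "__main__":
    # Hypothetical per-batch (cost, accuracy) values from two validation batches.
    batches = [(1.0, 0.5), (3.0, 1.0)]
    print(average_validation_metrics(["cost", "accuracy"], batches))
    # [('cost', 2.0), ('accuracy', 0.75)]
```

A sample-weighted mean would give a different result when the last validation batch is smaller than the others. This section does not state which averaging the Trainer uses.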


See tutorial_mnist_distributed_trainer.py.