API - Distributed Training

(Alpha release - usage might change later)

Helper sessions and methods to run distributed training. Check this MNIST example.

TaskSpecDef([task_type, index, trial, …]) – Specification for a distributed task.
TaskSpec() – Returns a TaskSpecDef based on the environment variables for distributed training.
DistributedSession([task_spec, …]) – Creates a distributed session.
StopAtTimeHook(time_running) – Hook that requests stop after a specified time.
LoadCheckpoint(saver, checkpoint) – Hook that loads a checkpoint after the session is created.

Distributed training

TaskSpecDef

tensorlayer.distributed.TaskSpecDef(task_type='master', index=0, trial=None, ps_hosts=None, worker_hosts=None, master=None)[source]

Specification for a distributed task.

It contains the job name, the index of the task, the parameter servers and the worker servers. If you want to use the last worker for continuous evaluation, you can call the method use_last_worker_as_evaluator, which returns a new TaskSpecDef object without the last worker in the cluster specification.

Parameters:
  • task_type (str) – Task type. One of master, worker or ps.
  • index (int) – The zero-based index of the task. Distributed training jobs will have a single master task, one or more parameter servers, and one or more workers.
  • trial (int) – The identifier of the trial being run.
  • ps_hosts (str OR list of str) – A string with a comma-separated list of hosts for the parameter servers, or a list of hosts.
  • worker_hosts (str OR list of str) – A string with a comma-separated list of hosts for the worker servers, or a list of hosts.
  • master (str) – A string with the master hosts.

Notes

master might not be included in TF_CONFIG and can be None. The shard_index is adjusted in any case so that 0 is assigned to the master and >= 1 to the workers. This implementation doesn’t support the sparse arrays shown in the official TensorFlow documentation for the TF_CONFIG variable, because they are not supported by the JSON definition.
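
For illustration only (the host names below are placeholders), the spec can also be built directly with the constructor above:

>>> task_spec = TaskSpecDef(task_type='worker',
...                         index=0,
...                         ps_hosts='ps0:2222',
...                         worker_hosts='worker0:2222,worker1:2222',
...                         master='master0:2222')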

Create TaskSpecDef from environment variables

tensorlayer.distributed.TaskSpec()

Returns a TaskSpecDef based on the environment variables for distributed training.
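
As a sketch, assuming the standard TF_CONFIG layout described in the TensorFlow documentation (the exact keys read by TaskSpec may differ) and a hypothetical two-worker cluster:

>>> import json, os
>>> os.environ['TF_CONFIG'] = json.dumps({
...     'cluster': {'master': ['master0:2222'],
...                 'ps': ['ps0:2222'],
...                 'worker': ['worker0:2222', 'worker1:2222']},
...     'task': {'type': 'worker', 'index': 0},
... })
>>> task_spec = TaskSpec()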

Distributed session object

tensorlayer.distributed.DistributedSession(task_spec=None, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=<object object>, save_summaries_secs=<object object>, config=None, stop_grace_period_secs=120, log_step_count_steps=100)

Creates a distributed session.

It calls MonitoredTrainingSession to create a MonitoredSession for distributed training.

Parameters:
  • task_spec (TaskSpecDef.) – The task spec definition from create_task_spec_def()
  • checkpoint_dir (str.) – Optional path to a directory where to restore variables.
  • scaffold (Scaffold) – A Scaffold used for gathering or building supportive ops. If not specified, a default one is created. It’s used to finalize the graph.
  • hooks (list of SessionRunHook objects.) – Optional hooks to run in the session.
  • chief_only_hooks (list of SessionRunHook objects.) – Activate these hooks if is_chief==True, ignore otherwise.
  • save_checkpoint_secs (int) – The frequency, in seconds, that a checkpoint is saved using a default checkpoint saver. If save_checkpoint_secs is set to None, then the default checkpoint saver isn’t used.
  • save_summaries_steps (int) – The frequency, in number of global steps, that the summaries are written to disk using a default summary saver. If both save_summaries_steps and save_summaries_secs are set to None, then the default summary saver isn’t used. Default 100.
  • save_summaries_secs (int) – The frequency, in seconds, that the summaries are written to disk using a default summary saver. If both save_summaries_steps and save_summaries_secs are set to None, then the default summary saver isn’t used. Not enabled by default.
  • config (tf.ConfigProto) – an instance of tf.ConfigProto proto used to configure the session. It’s the config argument of constructor of tf.Session.
  • stop_grace_period_secs (int) – Number of seconds given to threads to stop after close() has been called.
  • log_step_count_steps (int) – The frequency, in number of global steps, that the global step/sec is logged.

Examples

A simple example for distributed training where all the workers use the same dataset:

>>> task_spec = TaskSpec()
>>> with tf.device(task_spec.device_fn()):
>>>      tensors = create_graph()
>>> with tl.DistributedSession(task_spec=task_spec,
...                            checkpoint_dir='/tmp/ckpt') as session:
>>>      while not session.should_stop():
>>>           session.run(tensors)

An example where the dataset is shared among the workers (see https://www.tensorflow.org/programmers_guide/datasets):

>>> task_spec = TaskSpec()
>>> # dataset is a :class:`tf.data.Dataset` with the raw data
>>> dataset = create_dataset()
>>> if task_spec is not None:
>>>     dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
>>> # shuffle or apply a map function to the new sharded dataset, for example:
>>> dataset = dataset.shuffle(buffer_size=10000)
>>> dataset = dataset.batch(batch_size)
>>> dataset = dataset.repeat(num_epochs)
>>> # create the iterator for the dataset and the input tensor
>>> iterator = dataset.make_one_shot_iterator()
>>> next_element = iterator.get_next()
>>> with tf.device(task_spec.device_fn()):
>>>      # next_element is the input for the graph
>>>      tensors = create_graph(next_element)
>>> with tl.DistributedSession(task_spec=task_spec,
...                            checkpoint_dir='/tmp/ckpt') as session:
>>>      while not session.should_stop():
>>>           session.run(tensors)

Data sharding

In some cases we want to shard the data among all the training servers so that each server works on only part of the dataset, instead of every server reading all the data. TensorFlow >= 1.4 provides some helper classes that support data sharding: Datasets

When sharding, it is important that shuffling or any other non-deterministic operation is done after creating the shards:

from tensorflow.contrib.data import TextLineDataset
from tensorflow.contrib.data import Dataset

task_spec = TaskSpec()
task_spec.create_server()
files_dataset = Dataset.list_files(files_pattern)
dataset = TextLineDataset(files_dataset)
dataset = dataset.map(your_python_map_function, num_threads=4)
if task_spec is not None:
    # assign a disjoint subset of the data to this worker
    dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
# shuffle, batch and repeat only after the shards have been created
dataset = dataset.shuffle(buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.device(task_spec.device_fn()):
    tensors = create_graph(next_element)
with tl.DistributedSession(task_spec=task_spec,
                           checkpoint_dir='/tmp/ckpt') as session:
    while not session.should_stop():
        session.run(tensors)

Logging

We can use task_spec to log only in the master server:

while not session.should_stop():
    should_log = task_spec.is_master() and your_conditions
    if should_log:
        results = session.run(tensors_with_log_info)
        logging.info(...)
    else:
        results = session.run(tensors)

Continuous evaluation

You can use one of the workers to run an evaluation for the saved checkpoints:

import time

import tensorflow as tf
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.monitored_session import SingularMonitoredSession
from tensorlayer.distributed import StopAtTimeHook, TaskSpec

class Evaluator(session_run_hook.SessionRunHook):
    def __init__(self, checkpoints_path, output_path, max_time_secs=None):
        self.checkpoints_path = checkpoints_path
        self.summary_writer = tf.summary.FileWriter(output_path)
        self.latest_checkpoint = ''
        self.max_time_secs = max_time_secs

    def after_create_session(self, session, coord):
        checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
        # wait until a new checkpoint is available
        while self.latest_checkpoint == checkpoint:
            time.sleep(30)
            checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
        self.saver.restore(session, checkpoint)
        self.latest_checkpoint = checkpoint

    def end(self, session):
        super(Evaluator, self).end(session)
        # write the summaries computed for the last checkpoint
        step = int(self.latest_checkpoint.split('-')[-1])
        self.summary_writer.add_summary(self.summary, step)

    def _create_graph(self):
        # your code to create the graph with the dataset
        pass

    def run_evaluation(self):
        with tf.Graph().as_default():
            summary_tensors = self._create_graph()
            self.saver = tf.train.Saver(var_list=tf_variables.trainable_variables())
            hooks = [self]
            if self.max_time_secs and self.max_time_secs > 0:
                hooks.append(StopAtTimeHook(self.max_time_secs))
            # this evaluation runs indefinitely, until the process is killed
            while True:
                with SingularMonitoredSession(hooks=hooks) as session:
                    try:
                        while not session.should_stop():
                            self.summary = session.run(summary_tensors)
                    except tf.errors.OutOfRangeError:
                        pass
                    # end of evaluation for this checkpoint

task_spec = TaskSpec().use_last_worker_as_evaluator()
if task_spec.is_evaluator():
    Evaluator(checkpoints_path, output_path).run_evaluation()
else:
    task_spec.create_server()
    # run normal training

Session hooks

TensorFlow provides some Session Hooks to perform operations in the sessions. We add a few more to help with common operations.

Stop after maximum time

tensorlayer.distributed.StopAtTimeHook(time_running)[source]

Hook that requests stop after a specified time.

Parameters:
  • time_running (int) – Maximum time running in seconds.
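
A short usage sketch, assuming tensorlayer is imported as tl and that task_spec and tensors are built as in the examples above; the one-hour limit is arbitrary:

>>> stop_hook = tl.distributed.StopAtTimeHook(60 * 60)  # request stop after one hour
>>> with tl.DistributedSession(task_spec=task_spec,
...                            hooks=[stop_hook]) as session:
>>>      while not session.should_stop():
>>>           session.run(tensors)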

Initialize network with checkpoint

tensorlayer.distributed.LoadCheckpoint(saver, checkpoint)[source]

Hook that loads a checkpoint after the session is created.

>>> from tensorflow.python.ops import variables as tf_variables
>>> from tensorflow.python.training.monitored_session import SingularMonitoredSession
>>>
>>> tensors = create_graph()
>>> saver = tf.train.Saver(var_list=tf_variables.trainable_variables())
>>> checkpoint_hook = LoadCheckpoint(saver, my_checkpoint_file)
>>> with SingularMonitoredSession(hooks=[checkpoint_hook]) as session:
>>>      while not session.should_stop():
>>>           session.run(tensors)