API - Distribution (alpha)

Helper sessions and methods to run distributed training. Check the MNIST example.

TaskSpecDef([type, index, trial, ps_hosts, …]) Specification for the distributed task with the job name, index of the task, the parameter servers and the worker servers.
TaskSpec() Returns a TaskSpecDef based on the environment variables for distributed training.
DistributedSession([task_spec, …]) Creates a distributed session.

Distributed training

TaskSpecDef

tensorlayer.distributed.TaskSpecDef(type='master', index=0, trial=None, ps_hosts=None, worker_hosts=None, master=None)

Specification for the distributed task with the job name, index of the task, the parameter servers and the worker servers. If you want to use the last worker for continuous evaluation, call the method user_last_worker_as_evaluator, which returns a new TaskSpecDef object without the last worker in the cluster specification.

Parameters:

type : A string with the job name: master, worker or ps.

index : The zero-based index of the task. Distributed training jobs have a single master task, one or more parameter servers, and one or more workers.

trial : The identifier of the trial being run.

ps_hosts : A string with a comma-separated list of hosts for the parameter servers, or a list of hosts.

worker_hosts : A string with a comma-separated list of hosts for the worker servers, or a list of hosts.

master : A string with the master host.
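To make the host arguments concrete, the pure-Python sketch below shows one way the two accepted forms (a comma-separated string or a list) could be normalized into the job-to-hosts dictionary that tf.train.ClusterSpec takes, and what removing the last worker for evaluation does to that dictionary. The helpers normalize_hosts, cluster_dict and without_last_worker are hypothetical; this illustrates the idea and is not TensorLayer's implementation.

```python
def normalize_hosts(hosts):
    """Turn 'h1:2222,h2:2222' or ['h1:2222', 'h2:2222'] into a list of hosts."""
    if hosts is None:
        return []
    if isinstance(hosts, str):
        # split on commas, trimming spaces and dropping empty entries
        return [h.strip() for h in hosts.split(',') if h.strip()]
    return list(hosts)


def cluster_dict(ps_hosts, worker_hosts, master=None):
    """Build a job-name -> host-list mapping, the shape tf.train.ClusterSpec accepts."""
    cluster = {'ps': normalize_hosts(ps_hosts), 'worker': normalize_hosts(worker_hosts)}
    if master:
        cluster['master'] = normalize_hosts(master)
    return cluster


def without_last_worker(cluster):
    """Return a copy of the cluster with the last worker removed (kept for evaluation)."""
    if len(cluster['worker']) < 2:
        # this sketch refuses to leave no worker for training
        raise ValueError('need at least two workers to reserve one for evaluation')
    reduced = dict(cluster)
    reduced['worker'] = cluster['worker'][:-1]
    return reduced


cluster = cluster_dict('ps0:2222', 'w0:2222, w1:2222,w2:2222')
print(cluster)  # {'ps': ['ps0:2222'], 'worker': ['w0:2222', 'w1:2222', 'w2:2222']}
print(without_last_worker(cluster)['worker'])  # ['w0:2222', 'w1:2222']
```

The original dictionary is left untouched, mirroring the documented behaviour of returning a new spec rather than modifying the existing one.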

Create TaskSpecDef from environment variables

tensorlayer.distributed.TaskSpec()

Returns a TaskSpecDef based on the environment variables for distributed training.
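This section does not list the variables TaskSpec() reads. Assuming the TF_CONFIG convention used by Google Cloud ML Engine, where a JSON value holds a cluster object (host lists per job) and a task object (type and index), a pure-Python sketch of the mapping onto TaskSpecDef's parameters could look like this. The helper task_args_from_tf_config is hypothetical and is not TensorLayer's code.

```python
import json
import os


def task_args_from_tf_config(environ=None):
    """Map a TF_CONFIG JSON value onto TaskSpecDef-style keyword arguments.

    Hypothetical helper: assumes the Cloud ML Engine TF_CONFIG layout
    {"cluster": {"ps": [...], "worker": [...], "master": [...]},
     "task": {"type": "...", "index": N, "trial": "..."}}.
    """
    environ = os.environ if environ is None else environ
    if 'TF_CONFIG' not in environ:
        raise KeyError('TF_CONFIG is not set')
    config = json.loads(environ['TF_CONFIG'])
    # without a task entry, fall back to a single master task
    task = config.get('task') or {'type': 'master', 'index': 0}
    cluster = config.get('cluster') or {}
    return {
        'type': task['type'],
        'index': int(task.get('index', 0)),
        'trial': task.get('trial'),
        'ps_hosts': cluster.get('ps'),
        'worker_hosts': cluster.get('worker'),
        'master': cluster.get('master'),
    }


example_env = {'TF_CONFIG': json.dumps({
    'cluster': {'ps': ['ps0:2222'], 'worker': ['w0:2222', 'w1:2222']},
    'task': {'type': 'worker', 'index': 1},
})}
print(task_args_from_tf_config(example_env))
```

Passing the environment as an argument keeps the sketch testable; by default it reads os.environ.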

Distributed session object

tensorlayer.distributed.DistributedSession(task_spec=None, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=<object object>, save_summaries_secs=<object object>, config=None, stop_grace_period_secs=120, log_step_count_steps=100)

Creates a distributed session. It calls MonitoredTrainingSession to create a MonitoredSession for distributed training.

Parameters:

task_spec : TaskSpecDef. The task spec definition from TaskSpec().

checkpoint_dir : A string. Optional path to a directory from which to restore variables.

scaffold : A Scaffold used for gathering or building supportive ops. If not specified, a default one is created. It is used to finalize the graph.

hooks : Optional list of SessionRunHook objects.

chief_only_hooks : A list of SessionRunHook objects. These hooks are activated if is_chief==True and ignored otherwise.

save_checkpoint_secs : The frequency, in seconds, at which a checkpoint is saved using a default checkpoint saver. If save_checkpoint_secs is set to None, the default checkpoint saver isn't used.

save_summaries_steps : The frequency, in number of global steps, at which summaries are written to disk using a default summary saver. If both save_summaries_steps and save_summaries_secs are set to None, the default summary saver isn't used. Default 100.

save_summaries_secs : The frequency, in seconds, at which summaries are written to disk using a default summary saver. If both save_summaries_steps and save_summaries_secs are set to None, the default summary saver isn't used. Not enabled by default.

config : An instance of tf.ConfigProto used to configure the session. It is passed as the config argument of the tf.Session constructor.

stop_grace_period_secs : Number of seconds given to threads to stop after close() has been called.

log_step_count_steps : The frequency, in number of global steps, at which the global step/sec is logged.
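The two summary parameters interact through a sentinel default: leaving both unset means "every 100 steps", setting one leaves the other disabled, and setting both to None turns the default summary saver off. The pure-Python sketch below spells out that rule as described in the parameter list; USE_DEFAULT, resolve_summary_frequency and uses_summary_saver are hypothetical local names, not TensorLayer or TensorFlow code, and treating a zero frequency like None is an assumption of the sketch.

```python
USE_DEFAULT = object()  # sentinel meaning "argument not given"


def resolve_summary_frequency(save_summaries_steps=USE_DEFAULT, save_summaries_secs=USE_DEFAULT):
    """Return the (steps, secs) pair the default summary saver would use."""
    steps_default = save_summaries_steps is USE_DEFAULT
    secs_default = save_summaries_secs is USE_DEFAULT
    if steps_default and secs_default:
        # neither given: save every 100 global steps, no time-based saving
        return 100, None
    if secs_default:
        # only steps given: time-based saving stays disabled
        return save_summaries_steps, None
    if steps_default:
        # only secs given: step-based saving is disabled
        return None, save_summaries_secs
    return save_summaries_steps, save_summaries_secs


def uses_summary_saver(steps, secs):
    """The default summary saver is only created for a positive frequency."""
    return bool(steps and steps > 0) or bool(secs and secs > 0)


print(resolve_summary_frequency())                        # (100, None)
print(resolve_summary_frequency(save_summaries_secs=60))  # (None, 60)
print(uses_summary_saver(*resolve_summary_frequency(None, None)))  # False
```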

Examples

A simple example for distributed training where all the workers use the same dataset:

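>>> import tensorflow as tf
>>> import tensorlayer as tl
>>> # create_graph() stands for your own function that builds the model
>>> task_spec = tl.distributed.TaskSpec()
>>> with tf.device(task_spec.device_fn()):
...     tensors = create_graph()
>>> with tl.distributed.DistributedSession(task_spec=task_spec,
...                                       checkpoint_dir='/tmp/ckpt') as session:
...     while not session.should_stop():
...         session.run(tensors)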

An example where the dataset is split (sharded) among the workers (see https://www.tensorflow.org/programmers_guide/datasets):

>>> task_spec = tl.distributed.TaskSpec()
>>> # dataset is a tf.data.Dataset with the raw data
>>> dataset = create_dataset()
>>> if task_spec is not None:
...     dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
>>> # shuffle or apply a map function to the new sharded dataset, for example:
>>> dataset = dataset.shuffle(buffer_size=10000)
>>> dataset = dataset.batch(batch_size)
>>> dataset = dataset.repeat(num_epochs)
>>> # create the iterator for the dataset and the input tensor
>>> iterator = dataset.make_one_shot_iterator()
>>> next_element = iterator.get_next()
>>> with tf.device(task_spec.device_fn()):
...     # next_element is the input for the graph
...     tensors = create_graph(next_element)
>>> with tl.distributed.DistributedSession(task_spec=task_spec,
...                                       checkpoint_dir='/tmp/ckpt') as session:
...     while not session.should_stop():
...         session.run(tensors)

Data sharding

In some cases we want to shard the data among all the training servers instead of using all the data on every server. TensorFlow >= 1.4 provides helper classes for working with data that support sharding: Datasets (tf.data).

When sharding, it is important that shuffling, or any other non-deterministic operation, is done only after the shards have been created:

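import tensorflow as tf
import tensorlayer as tl

# files_pattern, your_map_function, buffer_size, batch_size, num_epochs and
# create_graph are placeholders for your own values and code
task_spec = tl.distributed.TaskSpec()
task_spec.create_server()
# sort the file names so every worker sees them in the same order before sharding
files_dataset = tf.data.Dataset.from_tensor_slices(sorted(tf.gfile.Glob(files_pattern)))
# read the lines of every file, one file after another
dataset = files_dataset.flat_map(tf.data.TextLineDataset)
# your_map_function must use TensorFlow ops (wrap plain Python code with tf.py_func)
dataset = dataset.map(your_map_function, num_parallel_calls=4)
if task_spec is not None:
    dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
# non-deterministic operations such as shuffle go after the shard
dataset = dataset.shuffle(buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.device(task_spec.device_fn()):
    tensors = create_graph(next_element)
with tl.distributed.DistributedSession(task_spec=task_spec,
                                       checkpoint_dir='/tmp/ckpt') as session:
    while not session.should_stop():
        session.run(tensors)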
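tf.data.Dataset.shard(num_shards, index) keeps the elements whose position modulo num_shards equals index. The pure-Python sketch below (no TensorFlow needed; the file names are made up) mimics that rule to show why the order of operations matters: sharding a list that is identical on every worker gives disjoint shards that together cover the data, while letting each worker shuffle first can produce overlapping shards and skip some files.

```python
import random


def shard(elements, num_shards, index):
    """Keep the elements whose position modulo num_shards equals index."""
    return [e for i, e in enumerate(elements) if i % num_shards == index]


files = ['part-0', 'part-1', 'part-2', 'part-3', 'part-4', 'part-5']
num_workers = 3

# Correct order: shard a deterministic list first, then shuffle locally.
good_shards = []
for worker in range(num_workers):
    local = shard(files, num_workers, worker)
    random.Random(worker).shuffle(local)  # per-worker shuffle after sharding
    good_shards.append(local)

# Wrong order: each worker shuffles with its own randomness, then shards.
bad_shards = []
for worker in range(num_workers):
    order = list(files)
    random.Random(100 + worker).shuffle(order)  # orders differ between workers
    bad_shards.append(shard(order, num_workers, worker))

print(sorted(sum(good_shards, [])) == files)  # True: every file exactly once
# can be fewer than 6: some files are read twice, others never
print(len(set(sum(bad_shards, []))), 'distinct files out of', len(files))
```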

Logging

We can use task_spec to log only on the master server:

while not session.should_stop():
    # your_conditions is a placeholder, e.g. a check on the global step
    should_log = task_spec.is_master() and your_conditions
    if should_log:
        results = session.run(tensors_with_log_info)
        logging.info(...)
    else:
        results = session.run(tensors)

Continuous evaluation

You can use one of the workers to continuously evaluate the saved checkpoints:

import time

import tensorflow as tf
import tensorlayer as tl
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.monitored_session import SingularMonitoredSession


class Evaluator(session_run_hook.SessionRunHook):
    def __init__(self, checkpoints_path, output_path, max_time_secs=None):
        self.checkpoints_path = checkpoints_path
        self.summary_writer = tf.summary.FileWriter(output_path)
        self.max_time_secs = max_time_secs
        self.latest_checkpoint = None
        self.summary = None
        self.saver = None

    def after_create_session(self, session, coord):
        checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
        # wait until a new checkpoint is available (None means none written yet)
        while checkpoint is None or checkpoint == self.latest_checkpoint:
            time.sleep(30)
            checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
        self.saver.restore(session, checkpoint)
        self.latest_checkpoint = checkpoint

    def end(self, session):
        super(Evaluator, self).end(session)
        # save the summaries, tagged with the global step of the checkpoint
        if self.summary is not None:
            step = int(self.latest_checkpoint.split('-')[-1])
            self.summary_writer.add_summary(self.summary, step)
            self.summary_writer.flush()

    def create_graph(self):
        # your code to create the graph with the evaluation dataset;
        # it should return the summary tensor, e.g. tf.summary.merge_all()
        raise NotImplementedError

    def run_evaluation(self):
        with tf.Graph().as_default():
            summary_tensors = self.create_graph()
            self.saver = tf.train.Saver(var_list=tf.trainable_variables())
            hooks = [self]
            if self.max_time_secs and self.max_time_secs > 0:
                hooks.append(tl.distributed.StopAtTimeHook(self.max_time_secs))
            # this evaluation runs indefinitely, until the process is killed
            while True:
                self.summary = None
                with SingularMonitoredSession(hooks=hooks) as session:
                    try:
                        while not session.should_stop():
                            self.summary = session.run(summary_tensors)
                    except tf.errors.OutOfRangeError:
                        pass  # end of the evaluation dataset


task_spec = tl.distributed.TaskSpec().user_last_worker_as_evaluator()
if task_spec.is_evaluator():
    # example paths: use your checkpoint directory and an output directory
    Evaluator('/tmp/ckpt', '/tmp/eval').run_evaluation()
else:
    task_spec.create_server()
    # run normal training
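The evaluator's two pieces of bookkeeping, waiting for a checkpoint it has not evaluated yet and recovering the global step from the checkpoint file name, can be checked without TensorFlow. In this sketch latest_fn stands in for tf.train.latest_checkpoint and sleep_fn for time.sleep; both are injected so the loop can be exercised with simulated answers, and the helper names are hypothetical. Parsing the step assumes checkpoints are named with the Saver convention prefix-<global_step>.

```python
import time


def step_from_checkpoint(path):
    """'/tmp/ckpt/model.ckpt-1200' -> 1200, the suffix after the last '-'."""
    return int(path.rsplit('-', 1)[-1])


def wait_for_new_checkpoint(latest_fn, previous, poll_secs=30, sleep_fn=time.sleep):
    """Poll latest_fn() until it returns a checkpoint that differs from `previous`."""
    checkpoint = latest_fn()
    # None means no checkpoint has been written yet
    while checkpoint is None or checkpoint == previous:
        sleep_fn(poll_secs)
        checkpoint = latest_fn()
    return checkpoint


# Simulated answers from the checkpoint directory over time.
answers = iter([None, '/tmp/ckpt/model.ckpt-100', '/tmp/ckpt/model.ckpt-100',
                '/tmp/ckpt/model.ckpt-200'])
waits = []
first = wait_for_new_checkpoint(lambda: next(answers), None, sleep_fn=waits.append)
second = wait_for_new_checkpoint(lambda: next(answers), first, sleep_fn=waits.append)
print(first, second, step_from_checkpoint(second), waits)
```

Each call sleeps once here: the first waits for any checkpoint to appear, the second skips the one already evaluated.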

Session hooks

TensorFlow provides session hooks (tf.train.SessionRunHook) that run operations at specific points of a session. TensorLayer adds more hooks for common operations.

Stop after maximum time

tensorlayer.distributed.StopAtTimeHook(time_running)

Hook that requests stop after a specified time.

Parameters: time_running : Maximum running time, in seconds.
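The section gives only the hook's contract. A plausible way to implement it, sketched here in plain Python rather than as a real tf.train.SessionRunHook, is to fix a deadline in begin() and compare the clock against it after each step. The class name StopAtTimeSketch and the injectable clock are hypothetical, and this is not TensorLayer's code; in an actual hook the check would live in after_run(run_context, run_values) and call run_context.request_stop().

```python
import time


class StopAtTimeSketch:
    """Mimics a hook that requests a stop after time_running seconds."""

    def __init__(self, time_running, clock=time.time):
        self._time_running = time_running
        self._clock = clock
        self._end_time = None

    def begin(self):
        # the deadline starts counting when the session begins
        self._end_time = self._clock() + self._time_running

    def should_stop(self):
        # checked after every step; a real hook would call run_context.request_stop()
        return self._clock() > self._end_time


now = [1000.0]  # fake clock so the example is deterministic
hook = StopAtTimeSketch(60, clock=lambda: now[0])
hook.begin()
now[0] = 1059.0
print(hook.should_stop())  # False: still within 60 seconds
now[0] = 1060.5
print(hook.should_stop())  # True: deadline passed
```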

Initialize network with checkpoint

tensorlayer.distributed.LoadCheckpoint(saver, checkpoint)

Hook that loads a checkpoint after the session is created.

>>> import tensorflow as tf
>>> from tensorflow.python.ops import variables as tf_variables
>>> from tensorflow.python.training.monitored_session import SingularMonitoredSession
>>> from tensorlayer.distributed import LoadCheckpoint
>>>
>>> # create_graph() and my_checkpoint_file stand for your own graph code and checkpoint path
>>> tensors = create_graph()
>>> saver = tf.train.Saver(var_list=tf_variables.trainable_variables())
>>> checkpoint_hook = LoadCheckpoint(saver, my_checkpoint_file)
>>> with SingularMonitoredSession(hooks=[checkpoint_hook]) as session:
...     while not session.should_stop():
...         session.run(tensors)