API - Distribution (alpha)
Helper sessions and methods to run distributed training. Check this MNIST example.
TaskSpecDef([type, index, trial, ps_hosts, …]) | Specification for the distributed task with the job name, index of the task, the parameter servers and the worker servers. |
TaskSpec() | Returns a TaskSpecDef based on the environment variables for distributed training. |
DistributedSession([task_spec, …]) | Creates a distributed session. |
Distributed training
TaskSpecDef
tensorlayer.distributed.TaskSpecDef(type='master', index=0, trial=None, ps_hosts=None, worker_hosts=None, master=None)
Specification for the distributed task with the job name, index of the task, the parameter servers and the worker servers. If you want to use the last worker for continuous evaluation, you can call the method user_last_worker_as_evaluator, which returns a new TaskSpecDef object without the last worker in the cluster specification.
Parameters:
- type : A string with the job name: master, worker or ps.
- index : The zero-based index of the task. Distributed training jobs have a single master task, one or more parameter servers, and one or more workers.
- trial : The identifier of the trial being run.
- ps_hosts : A string with a comma-separated list of hosts for the parameter servers, or a list of hosts.
- worker_hosts : A string with a comma-separated list of hosts for the worker servers, or a list of hosts.
- master : A string with the master host.
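Since the host parameters accept either a comma-separated string or a list, a framework-free sketch of normalizing them, and of the "reserve the last worker for evaluation" idea, could look like the following. The helpers parse_hosts, build_cluster and without_last_worker are hypothetical, not part of TensorLayer; the dictionary they build has the job-name-to-host-list shape that tf.train.ClusterSpec accepts.

```python
def parse_hosts(hosts):
    """Normalize None, a comma-separated string, or a list into a list of hosts."""
    if hosts is None:
        return []
    if isinstance(hosts, str):
        return [h.strip() for h in hosts.split(',') if h.strip()]
    return list(hosts)


def build_cluster(ps_hosts=None, worker_hosts=None, master=None):
    """Map job names to host lists (hypothetical helper)."""
    cluster = {'ps': parse_hosts(ps_hosts), 'worker': parse_hosts(worker_hosts)}
    if master:
        cluster['master'] = parse_hosts(master)
    return cluster


def without_last_worker(cluster):
    """Return (copy of the cluster without the last worker, that worker's host)."""
    # Assumption: at least one worker must remain for training.
    if len(cluster['worker']) < 2:
        raise ValueError('need at least two workers to reserve one as evaluator')
    trimmed = {job: list(hosts) for job, hosts in cluster.items()}
    evaluator_host = trimmed['worker'].pop()
    return trimmed, evaluator_host


cluster = build_cluster('ps0:2222', 'w0:2222, w1:2222, w2:2222', master='m0:2222')
train_cluster, evaluator_host = without_last_worker(cluster)
```

The original cluster is left untouched, so the evaluator can still tell which host it was assigned.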
Create TaskSpecDef from environment variables
tensorlayer.distributed.TaskSpec()
Returns a TaskSpecDef based on the environment variables for distributed training.
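The exact variables TaskSpec reads are not listed on this page. Assuming a TF_CONFIG-style JSON variable (the format used by tf.estimator, with "cluster" and "task" keys), a minimal sketch of turning the environment into TaskSpecDef-like arguments is shown below. The function task_from_tf_config is hypothetical, and returning None when the variable is unset is also an assumption.

```python
import json
import os


def task_from_tf_config(environ=None):
    """Build TaskSpecDef-like arguments from a TF_CONFIG JSON variable (assumed format)."""
    environ = os.environ if environ is None else environ
    raw = environ.get('TF_CONFIG')
    if raw is None:
        return None  # assumption: no TF_CONFIG means a non-distributed run
    config = json.loads(raw)
    task = config.get('task') or {'type': 'master', 'index': 0}
    cluster = config.get('cluster') or {}
    return {
        'type': task['type'],
        'index': int(task['index']),
        'trial': task.get('trial'),
        'ps_hosts': cluster.get('ps'),
        'worker_hosts': cluster.get('worker'),
        'master': cluster.get('master'),
    }


example_env = {'TF_CONFIG': json.dumps({
    'cluster': {'ps': ['ps0:2222'], 'worker': ['w0:2222', 'w1:2222']},
    'task': {'type': 'worker', 'index': 1},
})}
spec = task_from_tf_config(example_env)
```

Passing the mapping explicitly keeps the sketch testable without touching the real process environment.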
Distributed session object
tensorlayer.distributed.DistributedSession(task_spec=None, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=<object object>, save_summaries_secs=<object object>, config=None, stop_grace_period_secs=120, log_step_count_steps=100)
Creates a distributed session. It calls MonitoredTrainingSession to create a MonitoredSession for distributed training.
Parameters:
- task_spec : TaskSpecDef. The task spec definition from TaskSpec().
- checkpoint_dir : A string. Optional path to a directory from which to restore variables.
- scaffold : A Scaffold used for gathering or building supportive ops. If not specified, a default one is created. It is used to finalize the graph.
- hooks : Optional list of SessionRunHook objects.
- chief_only_hooks : List of SessionRunHook objects. These hooks are activated if is_chief==True and ignored otherwise.
- save_checkpoint_secs : The frequency, in seconds, at which a checkpoint is saved using a default checkpoint saver. If save_checkpoint_secs is set to None, the default checkpoint saver is not used.
- save_summaries_steps : The frequency, in number of global steps, at which summaries are written to disk using a default summary saver. If both save_summaries_steps and save_summaries_secs are set to None, the default summary saver is not used. Default 100.
- save_summaries_secs : The frequency, in seconds, at which summaries are written to disk using a default summary saver. If both save_summaries_steps and save_summaries_secs are set to None, the default summary saver is not used. Not enabled by default.
- config : An instance of tf.ConfigProto used to configure the session. It is passed as the config argument of the tf.Session constructor.
- stop_grace_period_secs : Number of seconds given to threads to stop after close() has been called.
- log_step_count_steps : The frequency, in number of global steps, at which the global step/sec is logged.
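The chief_only_hooks rule can be illustrated without TensorFlow. The sketch below assumes that the master task is the chief (and that a non-distributed run, modeled as task_type=None, is its own chief); active_hooks is a hypothetical helper, not the TensorLayer implementation, which delegates this decision to MonitoredTrainingSession.

```python
def active_hooks(task_type, hooks=None, chief_only_hooks=None):
    """Return (is_chief, hooks that would run) for a task."""
    # Assumption: the master task is the chief; task_type=None models a
    # non-distributed run where the single process is the chief.
    is_chief = task_type in (None, 'master')
    selected = list(hooks or [])
    if is_chief:
        selected.extend(chief_only_hooks or [])
    return is_chief, selected


master_view = active_hooks('master', hooks=['log'], chief_only_hooks=['checkpoint'])
worker_view = active_hooks('worker', hooks=['log'], chief_only_hooks=['checkpoint'])
```

Plain strings stand in for SessionRunHook objects; only the selection logic matters here.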
Examples
A simple example of distributed training where all the workers use the same dataset:
>>> task_spec = TaskSpec()
>>> with tf.device(task_spec.device_fn()):
...     tensors = create_graph()
>>> with tl.distributed.DistributedSession(task_spec=task_spec,
...                                        checkpoint_dir='/tmp/ckpt') as session:
...     while not session.should_stop():
...         session.run(tensors)
An example where the dataset is sharded among the workers (see https://www.tensorflow.org/programmers_guide/datasets):
>>> task_spec = TaskSpec()
>>> # dataset is a tf.data.Dataset with the raw data
>>> dataset = create_dataset()
>>> if task_spec is not None:
...     dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
>>> # shuffle or apply a map function to the new sharded dataset, for example:
>>> dataset = dataset.shuffle(buffer_size=10000)
>>> dataset = dataset.batch(batch_size)
>>> dataset = dataset.repeat(num_epochs)
>>> # create the iterator for the dataset and the input tensor
>>> iterator = dataset.make_one_shot_iterator()
>>> next_element = iterator.get_next()
>>> with tf.device(task_spec.device_fn()):
...     # next_element is the input for the graph
...     tensors = create_graph(next_element)
>>> with tl.distributed.DistributedSession(task_spec=task_spec,
...                                        checkpoint_dir='/tmp/ckpt') as session:
...     while not session.should_stop():
...         session.run(tensors)
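Dataset.shard(num_shards, index) keeps the elements whose position modulo num_shards equals index, so each worker reads a disjoint slice and together they cover the whole dataset. A pure-Python sketch of that rule (the shard helper is illustrative, not the tf.data implementation):

```python
def shard(elements, num_shards, index):
    """Keep the elements whose position modulo num_shards equals index."""
    if not 0 <= index < num_shards:
        raise ValueError('index must be in [0, num_shards)')
    return [x for i, x in enumerate(elements) if i % num_shards == index]


data = list(range(10))
shards = [shard(data, 3, worker) for worker in range(3)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```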
Data sharding
In some cases we want to shard the data among the training servers instead of using all the data on every server. TensorFlow >= 1.4 provides helper classes for working with data that support sharding: Datasets.
When sharding, it is important that shuffling, and any other non-deterministic operation, is applied after the shards are created:
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.distributed import TaskSpec

# files_pattern, your_python_map_function, buffer_size, batch_size,
# num_epochs and create_graph are placeholders for your own code
task_spec = TaskSpec()
task_spec.create_server()
files_dataset = tf.data.Dataset.list_files(files_pattern)
# read the lines of every file matched by the pattern
dataset = files_dataset.flat_map(tf.data.TextLineDataset)
dataset = dataset.map(your_python_map_function, num_parallel_calls=4)
if task_spec is not None:
    dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
# shuffle only after sharding
dataset = dataset.shuffle(buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.device(task_spec.device_fn()):
    tensors = create_graph(next_element)
with tl.distributed.DistributedSession(task_spec=task_spec,
                                       checkpoint_dir='/tmp/ckpt') as session:
    while not session.should_stop():
        session.run(tensors)
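To see why the order matters, the pure-Python sketch below models two workers. If each worker shuffles independently before sharding, the round-robin slices are taken from different orderings, so they can overlap while other elements are never read. Sharding first and shuffling locally keeps the slices disjoint. A reversed list stands in for worker 1's unsynchronized shuffle so that the outcome is deterministic; all helper names are hypothetical.

```python
import random


def shard(elements, num_shards, index):
    """Round-robin split: keep every num_shards-th element starting at index."""
    return elements[index::num_shards]


def shuffle_then_shard(orderings):
    # orderings[w] is the order worker w produced with its own, unsynchronized shuffle
    num_workers = len(orderings)
    return [shard(order, num_workers, w) for w, order in enumerate(orderings)]


def shard_then_shuffle(data, num_workers, seed=0):
    shards = []
    for w in range(num_workers):
        local = shard(list(data), num_workers, w)
        random.Random(seed + w).shuffle(local)  # reorders only this worker's slice
        shards.append(local)
    return shards


data = list(range(6))
bad = shuffle_then_shard([data, data[::-1]])  # [[0, 2, 4], [4, 2, 0]]
good = shard_then_shuffle(data, num_workers=2)
```

In the bad case elements 1, 3 and 5 are never seen and 0, 2 and 4 are read twice; in the good case every element is read exactly once.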
Logging
We can use task_spec to log only on the master server:
import logging

while not session.should_stop():
    should_log = task_spec.is_master() and your_conditions
    if should_log:
        results = session.run(tensors_with_log_info)
        logging.info(...)
    else:
        results = session.run(tensors)
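The condition your_conditions is left to the user. As one hypothetical choice, the sketch below logs only on the master and only every N global steps, using Python's standard logging module; should_log and every_n_steps are illustrative names, not TensorLayer API.

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('training')


def should_log(is_master, global_step, every_n_steps=100):
    # Hypothetical condition: master only, and only every `every_n_steps` steps.
    return is_master and global_step % every_n_steps == 0


for step in (0, 50, 100, 150, 200):
    if should_log(is_master=True, global_step=step):
        logger.info('global step %d', step)
```

Gating on the global step keeps the extra fetches in tensors_with_log_info off the hot path for most steps.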
Continuous evaluation
You can use one of the workers to evaluate the saved checkpoints:
import time

import tensorflow as tf
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.monitored_session import SingularMonitoredSession
from tensorlayer.distributed import StopAtTimeHook, TaskSpec


class Evaluator(session_run_hook.SessionRunHook):
    def __init__(self, checkpoints_path, output_path, max_time_secs=None):
        self.checkpoints_path = checkpoints_path
        self.summary_writer = tf.summary.FileWriter(output_path)
        self.latest_checkpoint = ''
        self.max_time_secs = max_time_secs
        self.saver = None
        self.summary = None

    def after_create_session(self, session, coord):
        checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
        # wait until a new checkpoint is available
        while checkpoint is None or checkpoint == self.latest_checkpoint:
            time.sleep(30)
            checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
        self.saver.restore(session, checkpoint)
        self.latest_checkpoint = checkpoint

    def end(self, session):
        super(Evaluator, self).end(session)
        # save summaries; checkpoint names end with "-<global_step>"
        if self.summary is not None:
            step = int(self.latest_checkpoint.split('-')[-1])
            self.summary_writer.add_summary(self.summary, step)

    def create_graph(self):
        # your code to create the graph with the evaluation dataset;
        # return the (merged) summary tensors to evaluate
        raise NotImplementedError

    def run_evaluation(self):
        with tf.Graph().as_default():
            summary_tensors = self.create_graph()
            self.saver = tf.train.Saver(var_list=tf.trainable_variables())
            hooks = [self]
            if self.max_time_secs and self.max_time_secs > 0:
                hooks.append(StopAtTimeHook(self.max_time_secs))
            # this evaluation runs indefinitely, until the process is killed
            while True:
                with SingularMonitoredSession(hooks=hooks) as session:
                    try:
                        while not session.should_stop():
                            self.summary = session.run(summary_tensors)
                    except tf.errors.OutOfRangeError:
                        pass
                # end of evaluation


task_spec = TaskSpec().user_last_worker_as_evaluator()
if task_spec.is_evaluator():
    # example paths: point checkpoints_path at the training checkpoint directory
    Evaluator(checkpoints_path='/tmp/ckpt', output_path='/tmp/ckpt/eval').run_evaluation()
else:
    task_spec.create_server()
    # run normal training
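The evaluator's checkpoint bookkeeping can be isolated and tested without TensorFlow. tf.train.Saver, when saving with a global step, appends "-<global_step>" to the checkpoint prefix, which is what the step parsing relies on. In the sketch below, wait_for_new_checkpoint and step_from_checkpoint are hypothetical helpers; get_latest stands in for tf.train.latest_checkpoint, and sleep is injectable so the polling loop runs instantly in a test.

```python
import time


def step_from_checkpoint(path):
    """Parse the global step from a prefix such as '/tmp/ckpt/model.ckpt-200'."""
    return int(path.rsplit('-', 1)[-1])


def wait_for_new_checkpoint(get_latest, last_seen, poll_secs=30, sleep=time.sleep):
    """Poll get_latest() until it returns a checkpoint different from last_seen."""
    checkpoint = get_latest()
    while checkpoint is None or checkpoint == last_seen:
        sleep(poll_secs)
        checkpoint = get_latest()
    return checkpoint


# usage with a scripted sequence of "latest checkpoint" answers
responses = iter([None, '/tmp/ckpt/model.ckpt-100', '/tmp/ckpt/model.ckpt-200'])
pauses = []
latest = wait_for_new_checkpoint(lambda: next(responses), '/tmp/ckpt/model.ckpt-100',
                                 sleep=pauses.append)
```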
Session hooks
TensorFlow provides session hooks (SessionRunHook) to run operations at fixed points of a session's life. We added more hooks to help with common operations.
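A SessionRunHook receives callbacks at fixed points: begin, after_create_session, before_run and after_run around every run call, and end. The framework-free sketch below records that order. RecordingHook and run_with_hooks are hypothetical stand-ins, not TensorFlow classes; a real monitored session also handles errors, stop requests and the extra fetches returned by before_run, which this skips.

```python
class RecordingHook:
    """Stand-in hook that records which callbacks fire, in order."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append('begin')

    def after_create_session(self, session, coord):
        self.calls.append('after_create_session')

    def before_run(self, run_context):
        self.calls.append('before_run')

    def after_run(self, run_context, run_values):
        self.calls.append('after_run')

    def end(self, session):
        self.calls.append('end')


def run_with_hooks(hooks, num_steps):
    # begin -> session created -> (before_run, run, after_run) * num_steps -> end
    for h in hooks:
        h.begin()
    session = object()  # placeholder for a real session
    for h in hooks:
        h.after_create_session(session, coord=None)
    for _ in range(num_steps):
        for h in hooks:
            h.before_run(run_context=None)
        for h in hooks:
            h.after_run(run_context=None, run_values=None)
    for h in hooks:
        h.end(session)


hook = RecordingHook()
run_with_hooks([hook], num_steps=2)
```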
Stop after maximum time
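A hook that stops training after a maximum wall-clock time can be sketched by recording the start time in begin and calling request_stop on the run context once the limit has passed. StopAfterSeconds and FakeRunContext are hypothetical, framework-free stand-ins; the clock is injectable so the logic can be checked without waiting.

```python
import time


class FakeRunContext:
    """Minimal stand-in for a session run context that records stop requests."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class StopAfterSeconds:
    """Hypothetical hook: request a stop once max_secs have elapsed since begin()."""

    def __init__(self, max_secs, clock=time.monotonic):
        self.max_secs = max_secs
        self.clock = clock
        self.start = None

    def begin(self):
        self.start = self.clock()

    def after_run(self, run_context, run_values):
        if self.clock() - self.start >= self.max_secs:
            run_context.request_stop()


# usage with a fake clock
now = [0.0]
hook = StopAfterSeconds(10, clock=lambda: now[0])
hook.begin()
ctx = FakeRunContext()
now[0] = 5.0
hook.after_run(ctx, None)
still_running = not ctx.stop_requested
now[0] = 10.0
hook.after_run(ctx, None)
```

Checking in after_run means the stop takes effect at the end of the current step rather than interrupting it.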
Initialize network with checkpoint
tensorlayer.distributed.LoadCheckpoint(saver, checkpoint)
Hook that loads a checkpoint after the session is created.
>>> from tensorflow.python.ops import variables as tf_variables
>>> from tensorflow.python.training.monitored_session import SingularMonitoredSession
>>>
>>> tensors = create_graph()
>>> saver = tf.train.Saver(var_list=tf_variables.trainable_variables())
>>> checkpoint_hook = LoadCheckpoint(saver, my_checkpoint_file)
>>> with SingularMonitoredSession(hooks=[checkpoint_hook]) as session:
...     while not session.should_stop():
...         session.run(tensors)