Supported version: 2.3.1
Important: This API document assumes you use the following import statement in your training scripts.
import smdistributed.modelparallel.tensorflow as smp
Refer to Modify a TensorFlow Training Script to learn how to use the following API in your TensorFlow training script.
smp.DistributedModel
A subclass of the Keras Model class that defines the model to be partitioned. To define a model, subclass smp.DistributedModel and implement its call() method, in the same way as with the Keras model subclassing API. Every operation inside smp.DistributedModel.call() is subject to partitioning: each such operation executes on exactly one of the devices, while operations outside call() run on all devices.
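The following pure-Python sketch models the placement rule without TensorFlow or smp. The function place_operations and its round-robin assignment are hypothetical. The library itself decides where each partitioned operation goes, either automatically or through smp.partition. The toy only shows the invariant: an operation inside call() lands on exactly one device, and an operation outside call() lands on all of them.

```python
from typing import Dict, List, Set


def place_operations(
    inside_call: List[str], outside_call: List[str], num_devices: int
) -> Dict[str, Set[int]]:
    """Hypothetical toy: map each operation name to the devices that run it.

    Round-robin is an illustrative stand-in for the library's real
    partitioning decision, which is not modeled here.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    placement: Dict[str, Set[int]] = {}
    for i, op in enumerate(inside_call):
        # Partitioned operation: executes on exactly one device.
        placement[op] = {i % num_devices}
    all_devices = set(range(num_devices))
    for op in outside_call:
        # Operation outside call(): every device runs it.
        placement[op] = set(all_devices)
    return placement


placement = place_operations(
    inside_call=["dense_1", "dense_2", "dense_3"],
    outside_call=["loss"],
    num_devices=2,
)
print(placement)  # {'dense_1': {0}, 'dense_2': {1}, 'dense_3': {0}, 'loss': {0, 1}}
```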
Similar to the regular Keras API, the forward pass is done by directly calling the model object on the input tensors. For example:
predictions = model(inputs) # model is a smp.DistributedModel object
model() calls can only be made inside a smp.step-decorated function.
The outputs from a smp.DistributedModel are available on all ranks, regardless of which rank computed the last operation.
string): A path to save an unpartitioned model with the latest training weights.
Saves the entire, unpartitioned model with the latest trained weights in TensorFlow SavedModel format. The save path defaults to "/opt/ml/model", which SageMaker monitors to upload the model artifacts to Amazon S3.
smp.partition(index)
index (int): The index of the partition.
A context manager that places all operations defined inside it into the partition whose ID is equal to index. When smp.partition contexts are nested, the innermost context overrides the rest. The index argument must be smaller than the number of partitions.
smp.partition is used in the manual partitioning API; if the "auto_partition" parameter is set to True when launching training, smp.partition contexts are ignored. Any operation that is not placed in any smp.partition context is placed in the default_partition, as shown in the following example:
```python
# auto_partition: False
# default_partition: 0
smp.init()
[...]
x = tf.constant(1.2)                    # placed in partition 0
with smp.partition(1):
    y = tf.add(x, tf.constant(2.3))     # placed in partition 1
    with smp.partition(3):
        z = tf.reduce_sum(y)            # placed in partition 3
```
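The placement rules for smp.partition can be modeled in plain Python with a stack of context managers. The rules are: the innermost context wins, operations outside any context fall back to default_partition, index must be smaller than the number of partitions, and contexts are ignored under auto_partition. ToyPartitioner is a hypothetical illustration built on contextlib, not smp's implementation. Under auto_partition it returns None to stand for "decided by automatic partitioning". It also validates index unconditionally, which is an assumption of the toy.

```python
from contextlib import contextmanager
from typing import Iterator, List, Optional


class ToyPartitioner:
    """Hypothetical model of the documented smp.partition placement rules."""

    def __init__(self, num_partitions: int, default_partition: int = 0,
                 auto_partition: bool = False) -> None:
        self.num_partitions = num_partitions
        self.default_partition = default_partition
        self.auto_partition = auto_partition
        self._stack: List[int] = []  # indices of the currently open contexts

    @contextmanager
    def partition(self, index: int) -> Iterator[None]:
        # The index must be smaller than the number of partitions.
        if not 0 <= index < self.num_partitions:
            raise ValueError(f"index {index} must be in [0, {self.num_partitions})")
        self._stack.append(index)
        try:
            yield
        finally:
            self._stack.pop()  # leaving a context restores the enclosing one

    def place(self) -> Optional[int]:
        """Return the partition an operation created right now would go to."""
        if self.auto_partition:
            return None  # contexts are ignored; automatic partitioning decides
        if self._stack:
            return self._stack[-1]  # innermost context overrides the rest
        return self.default_partition


# Roughly mirrors the example above; 4 partitions so that index 3 is valid.
p = ToyPartitioner(num_partitions=4, default_partition=0)
x_part = p.place()              # no context -> default partition 0
with p.partition(1):
    y_part = p.place()          # -> 1
    with p.partition(3):
        z_part = p.place()      # innermost context -> 3
print(x_part, y_part, z_part)   # 0 1 3
```

Because each context pops its index on exit, leaving the inner block returns placement to partition 1, and leaving both returns it to the default partition.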
smp.CheckpointManager
A subclass of the TensorFlow CheckpointManager that is used to manage checkpoints. Its usage is similar to that of tf.train.CheckpointManager.
The following returns an smp.CheckpointManager object:
smp.CheckpointManager(checkpoint, directory="/opt/ml/checkpoints", max_to_keep=None, checkpoint_name="ckpt")
smp.CheckpointManager.restore() must be called after the first training step. This is because the first call of the smp.step function constructs and partitions the model, which must happen before the checkpoint is restored. Calling restore() before the first smp.step call might result in hangs or unexpected behavior.
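The ordering constraint can be illustrated with a toy loop in plain Python. ToyStepModel and train are hypothetical stand-ins for a model that is built on its first step and for a restore call. The toy raises an error where, per the text above, the real library might instead hang or behave unexpectedly. The restore happens on the second iteration (step == 1), after the first step has built the model.

```python
from typing import Dict, List, Optional


class ToyStepModel:
    """Hypothetical model whose state only exists after the first step runs."""

    def __init__(self) -> None:
        self.built = False
        self.weights: Dict[str, float] = {}

    def step(self, batch: float) -> float:
        if not self.built:
            # First call builds the model (stands in for construction + partitioning).
            self.weights = {"w": 0.0}
            self.built = True
        self.weights["w"] += batch  # placeholder "update"
        return self.weights["w"]

    def restore(self, saved: Dict[str, float]) -> None:
        if not self.built:
            raise RuntimeError("restore() called before the first training step")
        self.weights = dict(saved)


def train(batches: List[float], saved: Optional[Dict[str, float]] = None) -> float:
    model = ToyStepModel()
    result = 0.0
    for step, batch in enumerate(batches):
        if step == 1 and saved is not None:
            model.restore(saved)  # safe: step 0 has already built the model
        result = model.step(batch)
    return result


print(train([1.0, 2.0, 3.0], saved={"w": 10.0}))  # 15.0
```

In this toy, restoring on step 1 discards the update made by step 0, and the remaining steps continue from the restored state.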
checkpoint: A tf.train.Checkpoint instance that represents a model checkpoint.
directory (str): The path to a directory in which to write checkpoints. A file named "checkpoint" is also written to this directory (in a human-readable text format); it contains the state of the CheckpointManager. Defaults to "/opt/ml/checkpoints", which is the directory that SageMaker monitors for uploading checkpoints to Amazon S3.
max_to_keep (int): The number of checkpoints to keep. If None, all checkpoints are kept.
checkpoint_name (str): Custom name for the checkpoint file. Defaults to "ckpt".
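The max_to_keep rule can be sketched as a simple rotation that drops the oldest checkpoint names once the limit is exceeded. ToyCheckpointRotation is hypothetical and writes no files. The "<checkpoint_name>-<n>" naming follows the convention of tf.train.CheckpointManager; treat it as an assumption for smp.CheckpointManager. Rejecting non-positive values of max_to_keep is also an assumption of the toy.

```python
from typing import List, Optional


class ToyCheckpointRotation:
    """Hypothetical model of max_to_keep retention; tracks names only, no files."""

    def __init__(self, max_to_keep: Optional[int] = None,
                 checkpoint_name: str = "ckpt") -> None:
        if max_to_keep is not None and max_to_keep <= 0:
            # Assumption of this toy: only positive limits (or None) are accepted.
            raise ValueError("max_to_keep must be a positive integer or None")
        self.max_to_keep = max_to_keep
        self.checkpoint_name = checkpoint_name
        self.kept: List[str] = []
        self._counter = 0

    def save(self) -> str:
        self._counter += 1
        # Assumed naming scheme "<checkpoint_name>-<n>", as in tf.train.CheckpointManager.
        name = f"{self.checkpoint_name}-{self._counter}"
        self.kept.append(name)
        # None keeps everything; otherwise drop the oldest entries beyond the limit.
        while self.max_to_keep is not None and len(self.kept) > self.max_to_keep:
            self.kept.pop(0)
        return name


rotation = ToyCheckpointRotation(max_to_keep=2)
for _ in range(4):
    rotation.save()
print(rotation.kept)  # ['ckpt-3', 'ckpt-4']
```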
save(): Saves a new checkpoint in the specified directory. Internally uses
restore(): Restores the latest checkpoint in the specified directory. Internally uses
```python
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
ckpt_manager = smp.CheckpointManager(checkpoint, max_to_keep=5)  # use /opt/ml/checkpoints

for inputs in train_ds:
    loss = train_step(inputs)
    # [...]
    ckpt_manager.save()  # save a new checkpoint in /opt/ml/checkpoints
```
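To restore the latest checkpoint, call restore() only after the first training step has run:

```python
for step, inputs in enumerate(train_ds):
    if step == 1:  # NOTE: restore occurs on the second step
        ckpt_manager.restore()
    loss = train_step(inputs)
```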