Common API¶
The following SageMaker distributed model parallel APIs are common across all frameworks.
Important: This API document assumes you use the following import statement in your training scripts.
TensorFlow:

```python
import smdistributed.modelparallel.tensorflow as smp
```

PyTorch:

```python
import smdistributed.modelparallel.torch as smp
```
- smp.init()

  Initializes the library. Must be called at the beginning of the training script.
- @smp.step(non_split_inputs, input_split_axes, [*args, **kwargs])

  A decorator that must be placed over a function that represents a single forward and backward pass (for training use cases), or a single forward pass (for evaluation use cases). Any computation that is defined inside the smp.step-decorated function is executed in a pipelined manner.

  By default, every tensor input to the function is split across its batch dimension into a number of microbatches specified while launching the training job. This behavior can be customized through the arguments to smp.step, described below. The library then orchestrates the execution of each microbatch across all partitions, based on the chosen pipeline type.

  In a typical use case, the forward pass and back-propagation are executed inside an smp.step-decorated function, and gradients, loss, and other relevant metrics (such as accuracy) are returned from the smp.step-decorated function.

  Any gradient post-processing operation, such as gradient clipping and allreduce, as well as optimizer.apply_gradients calls (for TF) or optimizer.step (for PT), should be applied on the gradients returned from the smp.step function, and not inside the smp.step function. This is because every operation inside smp.step is executed once per microbatch, so having these operations inside smp.step can either be inefficient (in the case of allreduce) or lead to wrong results (in the case of apply_gradients / optimizer.step).

  If the objects returned from the smp.step-decorated function contain tf.Tensors / torch.Tensors, they are converted to StepOutput objects. A StepOutput object encapsulates all versions of the tensor across different microbatches (see the StepOutput entry for more information).

  The arguments to the smp.step-decorated function should either be tensors or instances of list, tuple, dict, or set for them to be split across microbatches. If your object doesn't fall into this category, you can make the library split your object by implementing an smp_slice method.

  Below is an example of how to use it with PyTorch.

  ```python
  class CustomType:
      def __init__(self, tensor):
          self.data = tensor

      # The library calls this to slice the object, passing in the total
      # number of microbatches (num_mb) and the current microbatch index (mb).
      def smp_slice(self, num_mb, mb, axis):
          dim_size = list(self.data.size())[axis]
          split_size = dim_size // num_mb
          sliced_tensor = self.data.narrow(axis, mb * split_size, split_size)
          return CustomType(sliced_tensor)

  custom_obj = CustomType(torch.ones(4,))

  @smp.step()
  def step(custom_obj):
      loss = model(custom_obj)
      model.backward(loss)
      return loss
  ```
  Important: smp.step splits the batch into microbatches and executes everything inside the decorated function once per microbatch. This might affect the behavior of batch normalization, any operation that explicitly uses the batch size information, or any other Python code that is expected to run once.

  TensorFlow-specific behavior

  smp.step is a wrapper that inherits from and extends the behavior of tf.function, and as such, all the caveats that apply to the use of tf.functions also apply to smp.step. In particular, any operation that is inside smp.step executes in graph mode, and not eager mode.

  smp.step traces the wrapped function on the first call, and re-traces it every time one of the tensor arguments changes its shape or dtype, or for every new value of a Python argument, if there is one. Tracing is expensive, so such scenarios should be avoided as much as possible or, alternatively, an input_signature argument must be provided. For more information on the usage of tf.function, refer to the TensorFlow documentation.

  Common parameters
  - non_split_inputs (list): The list of arguments to the decorated function that should not be split along the batch dimension. Should be used for all input tensors that do not have a batch dimension. Should be a list of argument names as str, as they appear in the signature of the smp.step-decorated function. Defaults to an empty list.
  - input_split_axes (dict): A dict that maps the argument name to its batch axis. The keys should be the argument names as str, as they appear in the signature of the smp.step-decorated function. By default all batch axes are assumed to be the 0-axis.
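  The splitting semantics can be illustrated without the library. Below is a framework-free sketch using plain Python lists in place of tensors; `split_inputs` is a hypothetical helper written for this illustration, not part of the library's API, and only axis 0 is handled.

  ```python
  # Framework-free sketch of smp.step's input-splitting semantics.
  # "Tensors" are plain Python lists; split_inputs is a hypothetical helper.

  def split_inputs(kwargs, num_mb, mb, non_split_inputs=()):
      """Return the slice of every batch argument for microbatch index mb."""
      sliced = {}
      for name, value in kwargs.items():
          if name in non_split_inputs:
              # Arguments without a batch dimension pass through unchanged.
              sliced[name] = value
          else:
              split_size = len(value) // num_mb
              sliced[name] = value[mb * split_size:(mb + 1) * split_size]
      return sliced

  batch = {"x": [1, 2, 3, 4], "lr": 0.1}
  mb0 = split_inputs(batch, num_mb=2, mb=0, non_split_inputs=("lr",))
  mb1 = split_inputs(batch, num_mb=2, mb=1, non_split_inputs=("lr",))
  print(mb0)  # {'x': [1, 2], 'lr': 0.1}
  print(mb1)  # {'x': [3, 4], 'lr': 0.1}
  ```

  The decorated function would then run once per such slice, with `lr` passed to every microbatch intact.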
  TensorFlow-only parameters

  - All arguments of tf.function. Note: The experimental_compile argument of tf.function may not work as expected with smp.step, since it interferes with pipelining and model partitioning. To enable XLA with the library, you can instead use tf.config.optimizer.set_jit(True).
  PyTorch-only parameters

  - detach_outputs (bool): If True, calls torch.Tensor.detach() on all returned torch.Tensor outputs. Setting it to False increases memory consumption, unless detach() is manually called on the returned tensors, because the model graph is not cleared from memory after the training step. Set to True by default.
  Returns

  The same object(s) returned from the decorated function. All returned tf.Tensor and tf.Variable objects (for TF) or torch.Tensor objects (for PT) are wrapped inside a StepOutput object, even when they are inside a Python list, tuple, or dict.
- class StepOutput

  A class that encapsulates all versions of a tf.Tensor or torch.Tensor across all microbatches.

  When a particular tf.Tensor or torch.Tensor is computed inside smp.step, a different version of the tensor is computed for each microbatch. When this tensor is returned from smp.step and is accessed outside of the decorated function, it appears as a StepOutput object, which contains all such versions.

  For example, in the case of TensorFlow, the gradient for a particular tf.Variable is computed on each microbatch individually, and if this gradient is returned from smp.step, all gradients for this tf.Variable become part of the same StepOutput object. In the case of PyTorch, the loss for each microbatch is computed individually, and all the torch.Tensors that represent the loss for different microbatches become part of the same StepOutput object, if the loss is returned from the smp.step function.

  The StepOutput class offers the following API for commonly-used post-processing operations on such tensors.

- outputs

  Returns a list of the underlying tensors, indexed by microbatch.
- reduce_mean()

  Returns a tf.Tensor / torch.Tensor that averages the constituent tf.Tensors / torch.Tensors. This is commonly used for averaging loss and gradients across microbatches.

- reduce_sum()

  Returns a tf.Tensor / torch.Tensor that sums the constituent tf.Tensors / torch.Tensors.
- concat()

  Returns a tf.Tensor / torch.Tensor that concatenates the constituent tensors along the batch dimension using tf.concat / torch.cat.

- stack()

  Applies the tf.stack / torch.stack operation to the list of constituent tf.Tensors / torch.Tensors.
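The semantics of these accessors can be sketched on plain Python values. The mock below is illustrative only; the real StepOutput wraps tf.Tensor / torch.Tensor objects, and `MockStepOutput` is a name invented here.

```python
# Illustrative-only mock of StepOutput semantics on plain Python numbers
# and lists; the real class operates on framework tensors.

class MockStepOutput:
    def __init__(self, per_microbatch):
        self._outputs = list(per_microbatch)  # one entry per microbatch

    @property
    def outputs(self):
        # List of the underlying values, indexed by microbatch.
        return self._outputs

    def reduce_mean(self):
        # Average across microbatches (e.g. for loss or gradients).
        return sum(self._outputs) / len(self._outputs)

    def reduce_sum(self):
        # Sum across microbatches.
        return sum(self._outputs)

    def concat(self):
        # Concatenate per-microbatch "tensors" (lists) along the batch axis.
        return [x for mb in self._outputs for x in mb]

losses = MockStepOutput([1.0, 3.0])       # one loss per microbatch
preds = MockStepOutput([[1, 2], [3, 4]])  # predictions per microbatch
print(losses.reduce_mean())  # 2.0
print(losses.reduce_sum())   # 4.0
print(preds.concat())        # [1, 2, 3, 4]
```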
TensorFlow-only methods

- merge()

  Returns a tf.Tensor that concatenates the constituent tf.Tensors along the batch dimension. This is commonly used for merging the model predictions across microbatches.
- accumulate(method="variable", var=None)

  Functionally the same as StepOutput.reduce_mean(). However, it is more memory-efficient, especially for large numbers of microbatches, since it does not wait for all constituent tf.Tensors to be ready before it starts averaging them, thereby saving memory. In some cases (XLA, for example) StepOutput.reduce_mean() might end up being more memory-efficient than StepOutput.accumulate().

  Parameters

  - method ("add_n", "accumulate_n", or "variable"): If "add_n" or "accumulate_n", the library uses tf.add_n or tf.accumulate_n, respectively, to implement accumulation. If "variable", the library uses an internal tf.Variable into which to accumulate the tensors. Default is "variable". Note: Memory usage behavior of these choices can depend on the model and implementation.
  - var: A tf.Variable into which, if provided, the library accumulates the tensors. If None, the library internally creates a variable. If method is not "variable", this argument is ignored.
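  The memory trade-off between the two averaging styles can be illustrated in plain Python. The helpers below are invented for this sketch and are not part of the library.

  ```python
  # Illustrative contrast: reduce_mean-style averaging keeps every
  # per-microbatch value alive before averaging, while accumulate-style
  # averaging folds each value into a single running accumulator.

  def reduce_mean_style(microbatch_grads):
      # All per-microbatch gradients must exist simultaneously.
      return sum(microbatch_grads) / len(microbatch_grads)

  def accumulate_style(microbatch_grads):
      # Each gradient is folded into the accumulator as soon as it is
      # ready, so only one extra value is alive at any time.
      acc, n = 0.0, 0
      for g in microbatch_grads:
          acc += g
          n += 1
      return acc / n

  grads = [2.0, 4.0, 6.0]
  print(reduce_mean_style(grads))  # 4.0
  print(accumulate_style(grads))   # 4.0
  ```

  Both yield the same result; the difference is only in when the constituent values can be released.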
MPI Basics¶
The library exposes the following basic MPI primitives to its Python API:

- smp.rank(): The rank of the current process.
- smp.size(): The total number of processes.
- smp.mp_rank(): The rank of the process among the processes that hold the current model replica.
- smp.dp_rank(): The rank of the process among the processes that hold different replicas of the same model partition.
- smp.dp_size(): The total number of model replicas.
- smp.local_rank(): The rank among the processes on the current instance.
- smp.local_size(): The total number of processes on the current instance.
- smp.get_mp_group(): The list of ranks over which the current model replica is partitioned.
- smp.get_dp_group(): The list of ranks that hold different replicas of the same model partition.
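The relationships between these ranks can be sketched with simple arithmetic, assuming processes of the same model replica occupy consecutive world ranks. The actual placement is decided by the library, so the functions below are hypothetical illustrations only.

```python
# Hypothetical rank arithmetic, assuming each model replica's processes
# get consecutive world ranks. Not the library's actual placement logic.

def mp_rank(world_rank, mp_size):
    # Position within the current model replica.
    return world_rank % mp_size

def dp_rank(world_rank, mp_size):
    # Index of the model replica this process belongs to.
    return world_rank // mp_size

def get_mp_group(world_rank, mp_size):
    # All ranks holding the same model replica as world_rank.
    start = (world_rank // mp_size) * mp_size
    return list(range(start, start + mp_size))

def get_dp_group(world_rank, mp_size, world_size):
    # All ranks holding the same model partition across replicas.
    return list(range(world_rank % mp_size, world_size, mp_size))

# 4 processes, 2-way model parallelism -> 2 model replicas.
print(mp_rank(3, mp_size=2))             # 1
print(dp_rank(3, mp_size=2))             # 1
print(get_mp_group(3, mp_size=2))        # [2, 3]
print(get_dp_group(3, 2, world_size=4))  # [1, 3]
```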
Communication API¶
The library provides a few communication primitives which can be helpful while developing the training script. These primitives use the following enums as arguments to specify which processes the communication should involve.
Helper structures

- smp.CommGroup

  An enum that takes the values CommGroup.WORLD, CommGroup.MP_GROUP, and CommGroup.DP_GROUP. These values can also be accessed as smp.WORLD, smp.MP_GROUP, and smp.DP_GROUP, respectively.

  - CommGroup.WORLD: Represents the entire group of processes used in training.
  - CommGroup.MP_GROUP: Represents the group of processes that hold the same model replica as the current process. The processes in a single MP_GROUP collectively store an entire replica of the model.
  - CommGroup.DP_GROUP: Represents the group of processes that hold the same model partition as the current process. The processes in a single DP_GROUP perform data parallelism/allreduce among themselves.
- smp.RankType

  An enum that takes the values RankType.WORLD_RANK, RankType.MP_RANK, and RankType.DP_RANK.

  - RankType.WORLD_RANK: The associated rank is to be interpreted as the rank of the process across all processes used in training.
  - RankType.MP_RANK: The associated rank is to be interpreted as the rank of the process within the MP_GROUP.
  - RankType.DP_RANK: The associated rank is to be interpreted as the rank of the process within the DP_GROUP.
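How a rank plus a rank type could resolve to a concrete destination process can be sketched as follows, again assuming processes of the same model replica occupy consecutive world ranks. `resolve_rank` and the string constants are invented for this illustration; they are not library APIs.

```python
# Illustrative only: resolving a (rank, rank_type)-style pair to a world
# rank, assuming consecutive world ranks per model replica.

WORLD_RANK, MP_RANK, DP_RANK = "world", "mp", "dp"

def resolve_rank(rank, rank_type, my_world_rank, mp_size):
    if rank_type == WORLD_RANK:
        # Already a global rank.
        return rank
    if rank_type == MP_RANK:
        # Position `rank` within the current process's model replica.
        replica_start = (my_world_rank // mp_size) * mp_size
        return replica_start + rank
    if rank_type == DP_RANK:
        # Replica index `rank`, same partition as the current process.
        return rank * mp_size + my_world_rank % mp_size
    raise ValueError(rank_type)

# Current process: world rank 3, with 2-way model parallelism.
print(resolve_rank(0, MP_RANK, my_world_rank=3, mp_size=2))  # 2
print(resolve_rank(0, DP_RANK, my_world_rank=3, mp_size=2))  # 1
```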
Communication primitives

- smp.broadcast(obj, group)

  Sends the object to all processes in the group. The receiving processes must call smp.recv_from to receive the sent object.

  Inputs

  - obj: An arbitrary picklable Python object that will be broadcast.
  - group: A CommGroup argument that represents to which group of processes the object will be sent.

  Notes

  When you use broadcast on the sender process, there needs to be an accompanying smp.recv_from() call on the receiver processes. This is a synchronous call; the broadcast statement returns only after all ranks participating in the call have made a matching recv_from call.

  Example

  ```python
  if smp.rank() == 0:
      smp.broadcast(something, group=smp.CommGroup.WORLD)
  else:
      smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK)
  ```
- smp.send(obj, dest_rank, rank_type)

  Sends the object obj to dest_rank, which is of a type specified by rank_type.

  Inputs

  - obj: An arbitrary picklable Python object that will be sent.
  - dest_rank (int): An integer denoting the rank of the receiving process.
  - rank_type (enum): A smp.RankType enum that determines how dest_rank is to be interpreted. For example, if dest_rank is 1 and rank_type is MP_RANK, then obj is sent to the process with mp_rank 1 in the MP_GROUP which contains the current process.

  Notes

  This is a synchronous call; the send statement returns only after the destination rank has made a matching recv_from call.
- smp.recv_from(src_rank, rank_type)

  Receives an object from a peer process. Can be used with a matching smp.send or smp.broadcast call.

  Inputs

  - src_rank (int): An integer denoting the rank of the sending process.
  - rank_type (enum): A smp.RankType enum that determines how src_rank is to be interpreted. For example, if src_rank is 1 and rank_type is MP_RANK, then the object is received from the process with mp_rank 1 in the MP_GROUP which contains the current process.

  Returns

  The Python object that was sent by the peer process.

  Notes

  This is a synchronous call; the recv_from statement returns only after the source rank has made a matching send or broadcast call, and the object is received.
- smp.allgather(obj, group)

  A collective call that gathers all the submitted objects across all ranks in the specified group. Returns a list whose i-th index contains the object submitted by the i-th rank in group.

  Inputs

  - obj: An arbitrary picklable Python object that will be allgathered.
  - group: A CommGroup argument that represents which group of processes participate in allgather.

  Notes

  This is a synchronous call; the allgather statement returns only after all ranks participating in the call have made a matching allgather call, and all the objects are received at the current rank.

  Example

  ```python
  # assuming mp_size() == 2
  if smp.mp_rank() == 0:
      out = smp.allgather(obj1, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2]
  else:
      out = smp.allgather(obj2, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2]
  ```
- smp.barrier(group=smp.WORLD)

  A statement that blocks until all processes in the specified group reach the barrier statement, similar to MPI_Barrier().

  Inputs

  - group: An smp.CommGroup enum that specifies the group of processes participating in the barrier call. Defaults to smp.WORLD.

  Example

  Assume there are 8 processes and 2 model partitions, and therefore 4 mp_groups and 2 dp_groups. If the barrier call is passed the value smp.MP_GROUP for its group argument, then each process only waits until the other process of its own mp_group reaches that point. It does not wait for processes outside that mp_group.
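  The blocking semantics resemble threading.Barrier from the Python standard library, shown here as a single-process stand-in for a process group of size 2; this is an analogy, not how the library implements barriers.

  ```python
  # Stand-in for barrier semantics using threads: no participant proceeds
  # past the barrier until every member of the "group" has reached it.
  import threading

  events = []
  lock = threading.Lock()
  barrier = threading.Barrier(2)  # a "group" of 2 participants

  def worker(name):
      with lock:
          events.append((name, "before"))
      barrier.wait()  # blocks until both workers arrive
      with lock:
          events.append((name, "after"))

  threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
  for t in threads:
      t.start()
  for t in threads:
      t.join()

  # Both "before" events are guaranteed to precede any "after" event.
  print(all(phase == "before" for _, phase in events[:2]))  # True
  ```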
- smp.dp_barrier()

  Same as passing smp.DP_GROUP to smp.barrier(). Waits for the processes in the same dp_group as the current process to reach the same point in execution.

- smp.mp_barrier()

  Same as passing smp.MP_GROUP to smp.barrier(). Waits for the processes in the same mp_group as the current process to reach the same point in execution.