PyTorch API
Supported versions: 1.7.1, 1.8.1
This API document assumes you use the following import statement in your training scripts.
import smdistributed.modelparallel.torch as smp
Tip
Refer to Modify a PyTorch Training Script to learn how to use the following API in your PyTorch training script.
-
class smp.DistributedModel
A subclass of torch.nn.Module that specifies the model to be partitioned. It accepts a torch.nn.Module object, module, which is the model to be partitioned. The returned DistributedModel object internally manages model parallelism and data parallelism. Only one model in the training script can be wrapped with smp.DistributedModel.
Example:
model = smp.DistributedModel(model)
Important: The __call__ and backward method calls on the smp.DistributedModel object (in the following example, the object is model) can only be made inside an smp.step-decorated function.
Since DistributedModel is a torch.nn.Module, a forward pass can be performed by calling the DistributedModel object on the input tensors.
predictions = model(inputs)  # model is an smp.DistributedModel object
For a backward pass, call the backward function on the DistributedModel object, with tensors and gradients as arguments, in place of the PyTorch operations torch.Tensor.backward or torch.autograd.backward.
The API for model.backward is very similar to torch.autograd.backward. For example, the following backward calls:
torch.autograd.backward(loss) or loss.backward()
should be replaced with:
model.backward(loss) # loss is a tensor with only one element as its data
Similarly, for non-scalar tensors, replace the following backward call containing incoming gradient arguments:
torch.autograd.backward(outputs, out_grads)
with the following line:
model.backward(outputs, out_grads)
In these examples, all __call__ and backward method calls on the model objects (model(inputs) and model.backward(loss)) must be made inside an smp.step-decorated function.
Using DDP
If DDP is enabled, do not place a PyTorch DistributedDataParallel wrapper around the DistributedModel, because the DistributedModel wrapper also handles data parallelism.
Unlike the original DDP wrapper, when you use DistributedModel, model parameters and buffers are not immediately broadcast across processes when the wrapper is called. Instead, the broadcast is deferred to the first call of the smp.step-decorated function, when the partition is done.
Parameters
module (torch.nn.Module): Module to be distributed (data parallelism and model parallelism).
trace_device ("cpu" or "gpu") (default: "gpu"): Whether to perform the tracing step on the GPU or CPU. The tracing step gathers information on the order of execution of modules, the shapes of intermediate outputs, and execution times, to be used by the partitioning algorithm. If trace_device is set to GPU, accurate module execution times can be gathered during tracing for potentially improved partitioning decisions. However, if the model is too large to fit in a single GPU, then trace_device should be set to CPU.
trace_execution_times (bool) (default: False): If True, the library profiles the execution time of each module during tracing and uses it in the partitioning decision. This improves the partitioning decision, but it might make the tracing slower. It may also introduce some degree of non-determinism in partitioning results, because of the inherent randomness in module execution times. Must be False if trace_device is "cpu".
overlapping_allreduce (bool) (default: True): This is only applicable for hybrid data parallelism/model parallelism use cases (when ddp is set to True while launching training). The library uses this flag to decide whether to start an overlapping allreduce as soon as parameter gradients become ready. This overlaps communication with computation and can improve performance. If this is set to False, allreduce is performed at the end of the step.
backward_passes_per_step (int) (default: 1): This is only applicable for hybrid data parallelism/model parallelism use cases (when ddp is set to True in config). This parameter indicates the number of backward passes to perform before calling allreduce on DDP. This allows accumulating updates over multiple mini-batches before reducing and applying them.
average_grads_across_microbatches (bool) (default: True): Whether or not the computed gradients should be averaged across microbatches. If False, the computed gradients are summed across microbatches, but not divided by the number of microbatches. In the typical use case, where the computed loss is averaged over the mini-batch, this should be left as True. If you use a loss function that only sums the per-sample loss across the batch (and does not divide by the batch size), then this must be set to False for correctness.
bucket_cap_mb (default: 25): DistributedDataParallel buckets parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in megabytes (MB).
trace_memory_usage (default: False): When set to True, the library attempts to measure memory usage per module during tracing. If this is disabled, memory usage is estimated from the sizes of tensors returned from the module.
broadcast_buffers (default: True): Flag to be used with ddp=True. This parameter is forwarded to the underlying DistributedDataParallel wrapper. See broadcast_buffers in the PyTorch documentation.
gradient_as_bucket_view (default: False): To be used with ddp=True. This parameter is forwarded to the underlying DistributedDataParallel wrapper. See gradient_as_bucket_view in the PyTorch documentation.
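Why average_grads_across_microbatches must match the loss reduction can be checked with plain arithmetic. The sketch below is a hypothetical pure-Python toy, not an smp call: a one-parameter model with squared-error loss, a mini-batch split into two equal microbatches, and hand-derived gradients. With a mean loss, averaging the microbatch gradients reproduces the full-batch gradient. With a summed loss, summing them does.

```python
# Hypothetical toy model: prediction = w * x, per-sample loss = (w * x - y) ** 2,
# so the per-sample gradient with respect to w is 2 * (w * x - y) * x.
from typing import List


def grad_sum(w: float, xs: List[float], ys: List[float]) -> float:
    """Gradient of the loss summed over the samples in a batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys))


def grad_mean(w: float, xs: List[float], ys: List[float]) -> float:
    """Gradient of the loss averaged over the samples in a batch."""
    return grad_sum(w, xs, ys) / len(xs)


def split(seq: List[float], num_microbatches: int) -> List[List[float]]:
    """Split a mini-batch into equal-sized microbatches (assumes even division)."""
    size = len(seq) // num_microbatches
    return [seq[i * size:(i + 1) * size] for i in range(num_microbatches)]


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
num_mb = 2
mb_xs, mb_ys = split(xs, num_mb), split(ys, num_mb)

# Mean loss per microbatch: averaging the microbatch gradients (True) matches the full batch.
averaged = sum(grad_mean(w, bx, by) for bx, by in zip(mb_xs, mb_ys)) / num_mb

# Summed loss per microbatch: summing the microbatch gradients (False) matches the full batch.
sum_grads = [grad_sum(w, bx, by) for bx, by in zip(mb_xs, mb_ys)]
summed = sum(sum_grads)

print(averaged, grad_mean(w, xs, ys))  # -22.5 -22.5
print(summed, grad_sum(w, xs, ys))     # -90.0 -90.0
```

With the summed loss, averaging instead (summed / num_mb) would give -45.0, half the true gradient. That mismatch is why the parameter must be False for summed losses.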
Properties
partitioned: True if the model is partitioned, False otherwise. Initialized to False when DistributedModel is first created. It becomes True during the first call to the smp.step-decorated function. Once the model is partitioned, the local parameters or local state_dict can be fetched using the following methods.
Methods
-
backward(tensors, grad_tensors) Triggers a distributed backward pass across model partitions. Example usage is provided in the previous section. The API is very similar to torch.autograd.backward (https://pytorch.org/docs/stable/autograd.html#torch.autograd.backward).
The retain_graph and create_graph flags are not supported.
-
local_buffers() Returns an iterator over buffers for the modules in the partitioned model that have been assigned to the current process.
-
local_named_buffers() Returns an iterator over buffers for the modules in the partitioned model that have been assigned to the current process. This yields both the name of the buffer as well as the buffer itself.
-
local_parameters() Returns an iterator over parameters for the modules in the partitioned model that have been assigned to the current process.
-
local_named_parameters() Returns an iterator over parameters for the modules in the partitioned model that have been assigned to the current process. This yields both the name of the parameter as well as the parameter itself.
-
local_modules() Returns an iterator over the modules in the partitioned model that have been assigned to the current process.
-
local_named_modules() Returns an iterator over the modules in the partitioned model that have been assigned to the current process. This yields both the name of the module as well as the module itself.
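The local_* methods above all answer one question: which parts of the full model live on this process? The pure-Python sketch below illustrates that filtering without calling the library. The assignment dict, the string stand-ins for parameters, and both helper functions are hypothetical. They only mirror the described behavior by yielding the entries whose module was assigned to the given mp_rank. The buffer and module variants would follow the same pattern.

```python
from typing import Dict, Iterator, Tuple

# Hypothetical result of partitioning: top-level module name -> partition ID (mp_rank).
assignment: Dict[str, int] = {"embed": 0, "encoder": 0, "decoder": 1, "head": 1}

# Hypothetical parameters keyed like named_parameters(): "<module>.<parameter>".
parameters: Dict[str, str] = {
    "embed.weight": "E",
    "encoder.weight": "W_enc",
    "decoder.weight": "W_dec",
    "head.bias": "b_head",
}


def local_named_parameters(mp_rank: int) -> Iterator[Tuple[str, str]]:
    """Yield (name, parameter) pairs whose top-level module is assigned to mp_rank."""
    for name, param in parameters.items():
        module_name = name.split(".", 1)[0]
        if assignment[module_name] == mp_rank:
            yield name, param


def local_parameters(mp_rank: int) -> Iterator[str]:
    """Same filter as local_named_parameters, but yield only the parameters."""
    for _, param in local_named_parameters(mp_rank):
        yield param


print(list(local_named_parameters(1)))  # [('decoder.weight', 'W_dec'), ('head.bias', 'b_head')]
print(list(local_parameters(0)))        # ['E', 'W_enc']
```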
-
local_state_dict() Returns the state_dict that contains local parameters that belong to the current mp_rank. This state_dict contains a key, _smp_is_partial, that indicates whether the state_dict contains elements corresponding to only the current partition (a partial state_dict) or to the entire model.
-
state_dict() Returns the state_dict that contains parameters for the entire model. It first collects the local_state_dict, then gathers and merges the local_state_dict from all mp_ranks to create a full state_dict. Note that this needs to be called on all ranks with dp_rank()==0 to ensure the gather happens properly. If it is not called on all such ranks, it can hang.
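The gather-and-merge step behind state_dict() can be pictured as a union of the per-mp_rank partial dictionaries. The following pure-Python sketch is hypothetical: the partial dicts are hand-written stand-ins for local_state_dict() output, the helper name is invented, and the real method communicates across processes instead of receiving a list. Only the _smp_is_partial key comes from the description above. Here it is treated as a bookkeeping flag and dropped from the merged result, which is an assumption about the full state_dict's contents.

```python
from typing import Any, Dict, List


def merge_partial_state_dicts(partials: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine per-mp_rank partial state dicts into a single full dict (toy version)."""
    full: Dict[str, Any] = {}
    for partial in partials:
        if not partial.get("_smp_is_partial", False):
            raise ValueError("expected a partial state dict")
        for key, value in partial.items():
            if key == "_smp_is_partial":
                continue  # bookkeeping flag, not a model entry
            if key in full:
                raise KeyError(f"key appears on more than one partition: {key}")
            full[key] = value
    return full


# Hand-written stand-ins for what two mp_ranks might return from local_state_dict().
rank0 = {"_smp_is_partial": True, "embed.weight": [0.1, 0.2]}
rank1 = {"_smp_is_partial": True, "head.weight": [0.3], "head.bias": [0.0]}

full = merge_partial_state_dicts([rank0, rank1])
print(sorted(full))  # ['embed.weight', 'head.bias', 'head.weight']
```

Rejecting duplicate keys reflects the assumption that each parameter lives on exactly one partition.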
-
load_state_dict() Same as torch.nn.Module.load_state_dict(), except: It first gathers and merges the state_dicts across mp_ranks, if they are partial. The actual loading happens after the model partition so that each rank knows its local parameters.
-
register_post_partition_hook(hook) Registers a callable hook to be executed after the model is partitioned. This is useful in situations where an operation needs to be executed after the model partition during the first call to smp.step, but before the actual execution of the first forward pass. Returns a RemovableHandle object, handle, which can be used to remove the hook by calling handle.remove().
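The register/remove pattern can be sketched with a toy class. Everything below is hypothetical and pure Python. ToyDistributedModel does not partition anything, and its step() only flips a partitioned flag on the first call. The hook receives the toy model as its single argument because this section does not specify the real hook's arguments. The RemovableHandle here is a minimal re-implementation of the idea, not PyTorch's torch.utils.hooks.RemovableHandle.

```python
from typing import Callable, Dict


class RemovableHandle:
    """Minimal handle that unregisters one hook from its registry."""

    def __init__(self, registry: Dict[int, Callable], key: int) -> None:
        self._registry = registry
        self._key = key

    def remove(self) -> None:
        self._registry.pop(self._key, None)


class ToyDistributedModel:
    """Hypothetical stand-in that only mimics post-partition hook behavior."""

    def __init__(self) -> None:
        self.partitioned = False
        self._hooks: Dict[int, Callable] = {}
        self._next_key = 0

    def register_post_partition_hook(self, hook: Callable) -> RemovableHandle:
        key = self._next_key
        self._next_key += 1
        self._hooks[key] = hook
        return RemovableHandle(self._hooks, key)

    def step(self) -> None:
        # The first call "partitions" the model and runs the hooks before any forward pass.
        if not self.partitioned:
            self.partitioned = True
            for hook in list(self._hooks.values()):
                hook(self)
        # A real step would run the forward and backward passes here.


calls = []
model = ToyDistributedModel()
model.register_post_partition_hook(lambda m: calls.append("kept"))
handle = model.register_post_partition_hook(lambda m: calls.append("removed"))
handle.remove()  # this hook never runs
model.step()
model.step()  # in this toy, hooks do not run again once the model is partitioned
print(calls)  # ['kept']
```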
-
cpu() All-gathers parameters and buffers across all mp_ranks and moves them to the CPU.
-
join() A context manager to be used in conjunction with an instance of smp.DistributedModel to train with uneven inputs across participating processes. This is only supported when ddp=True. It uses the join context manager of the wrapped DistributedDataParallel instance. For more information, see join in the PyTorch documentation.
-
register_comm_hook(state, callable) Available for PyTorch 1.8.1 only. Registers a communication hook, which lets you supply a flexible hook callable that specifies how gradients are aggregated across multiple workers. This method is called on the wrapped DistributedDataParallel instance.
Note that when you register a comm hook, you have full control of how the gradients are processed. When using only data parallelism with PyTorch DDP, you are expected to average gradients across data-parallel replicas within the hook. Similarly, when using DistributedModel, you have to average gradients across data-parallel replicas within the hook. In addition, you also have to average gradients across microbatches within the hook, unless your loss function calls for not averaging. See average_grads_across_microbatches for more information about averaging gradients across microbatches.
This is only supported when ddp=True and overlapping_allreduce=True (default). For more information, see register_comm_hook in the PyTorch documentation.
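The averaging responsibility described above comes down to one division. The pure-Python sketch below rests on two assumptions: the hook's allreduce sums gradients over data-parallel replicas, and each replica's contribution is summed (not averaged) over microbatches. Under those assumptions, a correct average divides by both counts. The names allreduce_sum, averaging_hook, dp_size, and num_microbatches are hypothetical, and no torch.distributed call is made.

```python
from typing import List


def allreduce_sum(per_replica: List[List[float]]) -> List[float]:
    """Simulated allreduce: elementwise sum of every replica's gradient bucket."""
    return [sum(values) for values in zip(*per_replica)]


def averaging_hook(per_replica: List[List[float]], dp_size: int,
                   num_microbatches: int) -> List[float]:
    """Toy hook body: sum over replicas, then average over replicas and microbatches."""
    reduced = allreduce_sum(per_replica)
    scale = 1.0 / (dp_size * num_microbatches)
    return [g * scale for g in reduced]


# Two data-parallel replicas, each holding gradients summed over 2 microbatches.
per_replica = [[4.0, 8.0], [12.0, 0.0]]
print(averaging_hook(per_replica, dp_size=2, num_microbatches=2))  # [4.0, 2.0]
```

A real hook would operate on the gradient bucket tensors and return a future, for example from an asynchronous torch.distributed.all_reduce. Check the PyTorch register_comm_hook documentation for the exact hook signature in your PyTorch version.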
-
class smp.DistributedOptimizer
Parameters
optimizer: The optimizer to be wrapped.
An optimizer wrapper for saving/loading optimizer states. This wrapper returns optimizer with the following methods overridden:
-
state_dict() Returns the state_dict that contains optimizer state for the entire model. It first collects the local_state_dict, then gathers and merges the local_state_dict from all mp_ranks to create a full state_dict.
-
load_state_dict() Same as torch.optim.Optimizer.load_state_dict(), except:
It first gathers and merges the local state_dicts if they are partial.
The actual loading happens after the model partition so that each rank knows its local parameters.
-
local_state_dict() Returns the state_dict that contains the local optimizer state that belongs to the current mp_rank. This state_dict contains a key, _smp_is_partial, that indicates whether the state_dict contains elements corresponding to only the current partition (a partial state_dict) or to the entire model.
-
smp.partition(index)
Inputs
index (int): The index of the partition.
A context manager that places all modules defined inside it into the partition with ID index. The index argument must be less than the number of partitions.
Use smp.partition to implement manual partitioning. If "auto_partition" is True, then the smp.partition contexts are ignored. Any module that is not placed in any smp.partition context is placed in the default_partition defined through the SageMaker Python SDK.
When smp.partition contexts are nested, the innermost context overrides the rest (see the following example). In PyTorch, manual partitioning should be done inside the module __init__, and the partition assignment applies to the modules that are created inside the smp.partition context.
Example:
import torch
import smdistributed.modelparallel.torch as smp

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        with smp.partition(1):
            self.child0 = Child0()      # child0 on partition 1
            with smp.partition(2):
                self.child1 = Child1()  # child1 on partition 2
            self.child2 = Child2()      # child2 on partition 1
        self.child3 = Child3()          # child3 on default_partition
-
smp.get_world_process_group() Returns a torch.distributed ProcessGroup that consists of all processes, which can be used with the torch.distributed API. Requires "ddp": True in SageMaker Python SDK parameters.
-
smp.get_mp_process_group() Returns a torch.distributed ProcessGroup that consists of the processes in the MP_GROUP that contains the current process, which can be used with the torch.distributed API. Requires "ddp": True in SageMaker Python SDK parameters.
-
smp.get_dp_process_group() Returns a torch.distributed ProcessGroup that consists of the processes in the DP_GROUP that contains the current process, which can be used with the torch.distributed API. Requires "ddp": True in SageMaker Python SDK parameters.
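How the world, MP, and DP groups relate can be sketched without a distributed runtime. The pure-Python example below assumes a hypothetical rank layout in which global rank = dp_rank * mp_size + mp_rank. The library's actual rank ordering may differ. Under that assumption, each MP group collects the processes holding the partitions of one model replica. Each DP group collects the processes holding the same partition across replicas. The build_groups helper is illustrative only.

```python
from typing import Dict, List


def build_groups(world_size: int, mp_size: int) -> Dict[str, List[List[int]]]:
    """Group global ranks, assuming rank = dp_rank * mp_size + mp_rank (hypothetical layout)."""
    if world_size % mp_size != 0:
        raise ValueError("world_size must be divisible by mp_size")
    dp_size = world_size // mp_size
    # One MP group per model replica: all partitions of replica dp.
    mp_groups = [[dp * mp_size + mp for mp in range(mp_size)] for dp in range(dp_size)]
    # One DP group per partition: the same partition across all replicas.
    dp_groups = [[dp * mp_size + mp for dp in range(dp_size)] for mp in range(mp_size)]
    return {"world": [list(range(world_size))], "mp_groups": mp_groups, "dp_groups": dp_groups}


groups = build_groups(world_size=4, mp_size=2)
print(groups["mp_groups"])  # [[0, 1], [2, 3]]
print(groups["dp_groups"])  # [[0, 2], [1, 3]]
```

Rank lists like these are what torch.distributed.new_group(ranks=...) accepts. With "ddp": True, however, the getters above already provide the groups, so building them by hand is normally unnecessary.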
-
smp.is_initialized() Returns True if smp.init has already been called for the process, and False otherwise.
-
smp.nn.FusedLayerNorm Apex Fused Layer Norm is currently not supported by the library.
smp.nn.FusedLayerNorm replaces apex FusedLayerNorm and provides the same functionality. This requires apex to be installed on the system.
-
smp.optimizers.FusedNovoGrad Fused Novo Grad optimizer is currently not supported by the library.
smp.optimizers.FusedNovoGrad replaces the apex FusedNovoGrad optimizer and provides the same functionality. This requires apex to be installed on the system.
-
smp.optimizers.FusedLamb FusedLamb optimizer currently doesn’t work with the library.
smp.optimizers.FusedLamb replaces the apex FusedLamb optimizer and provides the same functionality. This requires apex to be installed on the system.
-
smp.amp.GradScaler PyTorch AMP GradScaler currently doesn’t work with the library.
smp.amp.GradScaler replaces torch.cuda.amp.GradScaler and provides the same functionality.
APIs for Saving and Loading
-
smp.save() Saves an object. This operation is similar to torch.save(), except that it has an additional keyword argument, partial, and accepts only a string for the argument f (file). If partial=True, each mp_rank saves a separate checkpoint file and the library adds an mp_rank index to your saved file.
Parameters
obj (dict): The object to be saved.
f (str): A string containing a file name.
partial (bool, default=True): When set to True, each mp_rank saves a separate checkpoint file and the library adds an mp_rank index to the saved file. If you want to be able to load and further train a model that you save with smp.save(), you must set partial=True.
pickle_module (pickle module, default: Python's built-in pickle module): A module used for pickling metadata and objects.
pickle_protocol (int, default=2): Can be specified to override the default protocol.
-
smp.load() Loads an object saved with smp.save() from a file.
Similar to torch.load(), except that it has an additional keyword argument, partial, and accepts only a string for the argument f (file). If partial=True, each mp_rank loads a separate checkpoint file.
Parameters
f (str): A string containing a file name.
map_location (function, torch.device, str, or dict): Specifies how to remap storage locations.
pickle_module (pickle module): A module used for unpickling metadata and objects (has to match the pickle_module used to serialize the file).
pickle_load_args (Python 3 only): Optional keyword arguments passed to pickle_module.load() and pickle_module.Unpickler().
partial (bool, default=True): When set to True, each mp_rank loads the checkpoint corresponding to that mp_rank. Should be used when loading a model trained with the library.
General Instruction For Saving and Loading
The library can save partial or full checkpoints.
For partial checkpoints, each mp_rank saves its own checkpoint file with only the parameters that belong to that rank.
For full checkpoints, the library saves a single checkpoint that contains the entire model's parameters.
When saving using smp.save(), each rank only holds its own
parameters. If you want to save the full model, there will be some
communication between the ranks to create the full model. If you save
checkpoints often, you should save partial checkpoints for best
performance.
When loading using smp.load(), the library can load partial checkpoints, full checkpoints, or full checkpoints saved by a non-model-parallel model. If you
want to resume training with a non-model-parallel model or do inference, you need
a full checkpoint.
The following is an example of how you can save and load a checkpoint:
import smdistributed.modelparallel.torch as smp

# Original model and optimizer (MyModel and MyOpt stand for your own classes)
model = MyModel(...)
optimizer = MyOpt(...)

# Model parallel wrapper
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

# save_partial_model, save_full_model, partial_checkpoint, and full_checkpoint
# are boolean flags that you define in your own script.

# To save, always save on dp_rank 0 to avoid data races.
# To save the partial model on each mp rank,
# the library will create `checkpoint.pt_{mprank}` for each mp rank
if save_partial_model:
    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()  # save the partial model
        opt_dict = optimizer.local_state_dict()  # save the partial optimizer state
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            "/checkpoint.pt",
            partial=True,
        )

# To save the full model
if save_full_model:
    if smp.dp_rank() == 0:
        model_dict = model.state_dict()  # save the full model
        opt_dict = optimizer.state_dict()  # save the full optimizer state
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            "/checkpoint.pt",
            partial=False,
        )

# To load, load on all ranks.
# The only difference for partial/full loading is the partial flag in smp.load
# Load partial checkpoint
if partial_checkpoint:
    checkpoint = smp.load("/checkpoint.pt", partial=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Load full checkpoint
if full_checkpoint:
    checkpoint = smp.load("/checkpoint.pt", partial=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
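The partial-checkpoint file layout described in the comments of the example above (one checkpoint.pt_{mprank} file per mp_rank) can be imitated with the standard library. This is a hypothetical, single-process stand-in. The helpers save_partial and load_partial are invented for illustration, and they use pickle directly rather than smp.save and smp.load. A loop over ranks in one process replaces one process per mp_rank. The "<f>_<mp_rank>" naming is assumed from that comment, not taken from the library's code.

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def save_partial(obj: Dict[str, Any], f: str, mp_rank: int) -> str:
    """Write one rank's partial checkpoint to '<f>_<mp_rank>' (assumed naming)."""
    path = f"{f}_{mp_rank}"
    with open(path, "wb") as fh:
        pickle.dump(obj, fh, protocol=2)  # protocol 2 mirrors the documented default
    return path


def load_partial(f: str, mp_rank: int) -> Dict[str, Any]:
    """Read back the partial checkpoint that belongs to mp_rank."""
    with open(f"{f}_{mp_rank}", "rb") as fh:
        return pickle.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    base = os.path.join(tmp, "checkpoint.pt")
    # Simulate two mp_ranks, each saving only the parameters it owns.
    partials = [{"layer0.weight": [1.0]}, {"layer1.weight": [2.0]}]
    paths = [save_partial(p, base, rank) for rank, p in enumerate(partials)]
    print(sorted(os.path.basename(p) for p in paths))  # ['checkpoint.pt_0', 'checkpoint.pt_1']
    restored = load_partial(base, 1)  # each rank loads only its own file
    print(restored)  # {'layer1.weight': [2.0]}
```

Because each rank touches only its own file, no cross-rank communication is needed. This is why the text recommends partial checkpoints when saving often.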