PyTorch API¶
To use the PyTorch-specific APIs for SageMaker distributed model parallelism, import the smdistributed.modelparallel.torch package at the top of your training script.

```python
import smdistributed.modelparallel.torch as smp
```
Tip
Refer to Modify a PyTorch Training Script to learn how to use the following API in your PyTorch training script.
smdistributed.modelparallel.torch.DistributedModel¶
class smdistributed.modelparallel.torch.DistributedModel¶

A subclass of torch.nn.Module which specifies the model to be partitioned. Accepts a torch.nn.Module object module, which is the model to be partitioned. The returned DistributedModel object internally manages model parallelism and data parallelism. Only one model in the training script can be wrapped with smdistributed.modelparallel.torch.DistributedModel.

Example:

```python
import smdistributed.modelparallel.torch as smp

model = smp.DistributedModel(model)
```
Important: The __call__ and backward method calls on the smdistributed.modelparallel.torch.DistributedModel object (in the following example, the object is model) can only be made inside a smdistributed.modelparallel.torch.step-decorated function.

Since DistributedModel is a torch.nn.Module, a forward pass can be performed by calling the DistributedModel object on the input tensors.

```python
predictions = model(inputs)  # model is a smp.DistributedModel object
```
For a backward pass, call the backward function on the DistributedModel object with tensors and gradients as arguments, replacing the PyTorch operations torch.Tensor.backward or torch.autograd.backward.

The API for model.backward is very similar to torch.autograd.backward. For example, the following backward calls:

```python
torch.autograd.backward(loss)
# or
loss.backward()
```

should be replaced with:

```python
model.backward(loss)  # loss is a tensor with only one element as its data
```

Similarly, for non-scalar tensors, replace the following backward call containing incoming gradient arguments:

```python
torch.autograd.backward(outputs, out_grads)
```

with the following line:

```python
model.backward(outputs, out_grads)
```
In these examples, all __call__ and backward method calls on the model objects (model(inputs) and model.backward(loss)) must be made inside a smdistributed.modelparallel.torch.step-decorated function.

Using DDP
If DDP is enabled with the SageMaker model parallel library, do not place a PyTorch DistributedDataParallel wrapper around the DistributedModel, because the DistributedModel wrapper also handles data parallelism.

Unlike the original DDP wrapper, when you use DistributedModel, model parameters and buffers are not immediately broadcast across processes when the wrapper is called. Instead, the broadcast is deferred to the first call of the smdistributed.modelparallel.torch.step-decorated function, when the partition is done.

Parameters
- module (torch.nn.Module): Module to be distributed (data parallelism and model parallelism).

- trace_device ("cpu" or "gpu") (default: "gpu"): Whether to perform the tracing step on the GPU or CPU. The tracing step gathers information on the order of execution of modules, the shapes of intermediate outputs, and execution times, to be used by the partitioning algorithm. If trace_device is set to GPU, accurate module execution times can be gathered during tracing for a potentially improved partitioning decision. However, if the model is too large to fit in a single GPU, then trace_device should be set to CPU.

- trace_execution_times (bool) (default: False): If True, the library profiles the execution time of each module during tracing and uses it in the partitioning decision. This improves the partitioning decision, but it might make tracing slower. It may also introduce some degree of non-determinism in partitioning results, because of the inherent randomness in module execution times. Must be False if trace_device is "cpu".

- overlapping_allreduce (bool) (default: True): This is only applicable for hybrid data parallelism/model parallelism use cases (when ddp is set to True while launching training). The library uses this flag to decide whether to do overlapping allreduce whenever a parameter's gradients are ready. This overlaps communication with computation and can improve performance. If this is set to False, allreduce is performed at the end of the step.

- backward_passes_per_step (int) (default: 1): This is only applicable for hybrid data parallelism/model parallelism use cases (when ddp is set to True in the config). This parameter indicates the number of backward passes to perform before calling allreduce on DDP. This allows accumulating updates over multiple mini-batches before reducing and applying them.

- average_grads_across_microbatches (bool) (default: True): Whether or not the computed gradients should be averaged across microbatches. If False, the computed gradients will be summed across microbatches, but not divided by the number of microbatches. In the typical use case where the computed loss is averaged over the mini-batch, this should be left as True. If you use a loss function that only sums the per-sample loss across the batch (and does not divide by the batch size), then this must be set to False for correctness.

- bucket_cap_mb (default: 25): DistributedDataParallel buckets parameters into multiple buckets so that gradient reduction of each bucket can potentially overlap with backward computation. bucket_cap_mb controls the bucket size in megabytes (MB).

- trace_memory_usage (default: False): When set to True, the library attempts to measure memory usage per module during tracing. If this is disabled, memory usage is estimated through the sizes of tensors returned from the module.

- broadcast_buffers (default: True): Flag to be used with ddp=True. This parameter is forwarded to the underlying DistributedDataParallel wrapper. See broadcast_buffers in the PyTorch documentation.

- gradient_as_bucket_view (default: False): To be used with ddp=True. This parameter is forwarded to the underlying DistributedDataParallel wrapper. See gradient_as_bucket_view in the PyTorch documentation.
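The difference between averaging and summing gradients across microbatches (the average_grads_across_microbatches parameter above) can be sketched with plain Python arithmetic; the gradient values below are made up for illustration and no smp installation is needed:

```python
# Made-up per-microbatch gradients for a single parameter, 4 microbatches.
microbatch_grads = [1.0, 2.0, 3.0, 4.0]

# average_grads_across_microbatches=True (default): the accumulated gradient
# is divided by the number of microbatches. Correct when the loss is already
# averaged over the mini-batch.
averaged = sum(microbatch_grads) / len(microbatch_grads)

# average_grads_across_microbatches=False: gradients are only summed.
# Correct when the loss only sums per-sample losses over the batch.
summed = sum(microbatch_grads)

print(averaged)  # 2.5
print(summed)    # 10.0
```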
Properties

- partitioned: Is True if the model is partitioned, False otherwise. Initialized to False when DistributedModel is first created. It becomes True during the first call to a smdistributed.modelparallel.torch.step-decorated function. Once the model is partitioned, the local parameters or local state_dict can be fetched using the following methods.
Methods

- backward(tensors, grad_tensors)¶

  Triggers a distributed backward pass across model partitions. Example usage is provided in the previous section. The API is very similar to torch.autograd.backward (https://pytorch.org/docs/stable/autograd.html#torch.autograd.backward). The retain_grad and create_graph flags are not supported.
- local_buffers()¶

  Returns an iterator over buffers for the modules in the partitioned model that have been assigned to the current process.

- local_named_buffers()¶

  Returns an iterator over buffers for the modules in the partitioned model that have been assigned to the current process. This yields both the name of the buffer and the buffer itself.

- local_parameters()¶

  Returns an iterator over parameters for the modules in the partitioned model that have been assigned to the current process.

- local_named_parameters()¶

  Returns an iterator over parameters for the modules in the partitioned model that have been assigned to the current process. This yields both the name of the parameter and the parameter itself.

- local_modules()¶

  Returns an iterator over the modules in the partitioned model that have been assigned to the current process.

- local_named_modules()¶

  Returns an iterator over the modules in the partitioned model that have been assigned to the current process. This yields both the name of the module and the module itself.
- local_state_dict()¶

  Returns the state_dict that contains the local parameters that belong to the current mp_rank. This state_dict contains the key _smp_is_partial, which indicates whether the state_dict contains elements corresponding to only the current partition or to the entire model.
- state_dict()¶

  Returns the state_dict that contains parameters for the entire model. It first collects the local_state_dict, then gathers and merges the local_state_dicts from all mp_ranks to create a full state_dict. Note that this must be called on all ranks, not only the ranks with dp_rank()==0, to ensure the gather happens properly; if it is called only on those ranks, it can hang.
- load_state_dict()¶

  Same as torch.module.load_state_dict(), except: it first gathers and merges the state_dicts across mp_ranks, if they are partial. The actual loading happens after the model partition so that each rank knows its local parameters.
- register_post_partition_hook(hook)¶

  Registers a callable hook to be executed after the model is partitioned. This is useful in situations where an operation needs to be executed after the model partition during the first call to smdistributed.modelparallel.torch.step, but before the actual execution of the first forward pass. Returns a RemovableHandle object handle, which can be used to remove the hook by calling handle.remove().
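To make the hook semantics concrete, here is a small pure-Python sketch (not the library's implementation; ModelSketch and its internals are hypothetical) of a post-partition hook registry with removable handles: hooks registered before partitioning run exactly once after partitioning, and handle.remove() unregisters a hook.

```python
class RemovableHandle:
    """Sketch of a handle that can unregister its hook."""
    def __init__(self, hooks, key):
        self._hooks, self._key = hooks, key

    def remove(self):
        self._hooks.pop(self._key, None)

class ModelSketch:
    def __init__(self):
        self._post_partition_hooks = {}
        self._next_key = 0
        self.partitioned = False

    def register_post_partition_hook(self, hook):
        key = self._next_key
        self._next_key += 1
        self._post_partition_hooks[key] = hook
        return RemovableHandle(self._post_partition_hooks, key)

    def _partition(self):
        # Stands in for the partitioning done during the first smp.step call.
        self.partitioned = True
        for hook in list(self._post_partition_hooks.values()):
            hook(self)

events = []
model = ModelSketch()
handle = model.register_post_partition_hook(lambda m: events.append("hook ran"))
removed = model.register_post_partition_hook(lambda m: events.append("removed"))
removed.remove()   # this hook never fires
model._partition()
print(events)      # ['hook ran']
```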
- cpu()¶

  Allgathers parameters and buffers across all mp_ranks and moves them to the CPU.
- join()¶

  A context manager to be used in conjunction with an instance of smdistributed.modelparallel.torch.DistributedModel to be able to train with uneven inputs across participating processes. This is only supported when ddp=True. It uses join on the wrapped DistributedDataParallel instance. For more information, see join in the PyTorch documentation.
- register_comm_hook(state, callable)¶

  Available for PyTorch 1.8.1 only. Registers a communication hook, an enhancement that provides a flexible hook callable to users, where they can specify how gradients are aggregated across multiple workers. This method is called on the wrapped DistributedDataParallel instance.

  Note that when you register a comm hook, you have full control of how the gradients are processed. When using only data parallelism with Torch DDP, you are expected to average gradients across data-parallel replicas within the hook. Similarly, when using DistributedModel, you have to average gradients across data-parallel replicas within the hook. In addition, you also have to average gradients across microbatches within the hook, unless you explicitly do not want to average them based on your loss function. See average_grads_across_microbatches for more information about averaging gradients across microbatches.

  This is only supported when ddp=True and overlapping_allreduce=True (default). For more information, see register_comm_hook in the PyTorch documentation.
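The averaging requirement described above can be sketched in plain Python; the function name and the list-of-floats stand-in for gradient tensors are hypothetical:

```python
def average_bucket(summed_grads, dp_size, num_microbatches):
    """Average gradients that arrive summed over dp_size data-parallel
    replicas and num_microbatches microbatches (stand-in for a tensor op
    a comm hook would perform on a gradient bucket)."""
    divisor = dp_size * num_microbatches
    return [g / divisor for g in summed_grads]

# A gradient bucket after an allreduce-sum over 4 replicas,
# each of which accumulated 2 microbatches:
bucket = [8.0, 16.0]
print(average_bucket(bucket, dp_size=4, num_microbatches=2))  # [1.0, 2.0]
```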
Behavior of smdistributed.modelparallel.torch.DistributedModel with Tensor Parallelism

When a model is wrapped by smdistributed.modelparallel.torch.DistributedModel, the library immediately traverses the modules of the model object and replaces the modules that are supported for tensor parallelism with their distributed counterparts. This replacement happens in place. If there are no other references to the original modules in the script, they are garbage-collected. The module attributes that previously referred to the original submodules now refer to the distributed versions of those submodules.

Example:

```python
# register DistributedSubmodule as the distributed version of Submodule
# (note this is a hypothetical example, smp.nn.DistributedSubmodule does not exist)
import smdistributed.modelparallel.torch as smp

smp.tp_register_with_module(Submodule, smp.nn.DistributedSubmodule)

class MyModule(nn.Module):
    def __init__(self):
        ...
        self.submodule = Submodule()
        ...

# enabling tensor parallelism for the entire model
with smp.tensor_parallelism():
    model = MyModule()

# here model.submodule is still a Submodule object
assert isinstance(model.submodule, Submodule)

model = smp.DistributedModel(model)

# now model.submodule is replaced with an equivalent instance
# of smp.nn.DistributedSubmodule
assert isinstance(model.module.submodule, smp.nn.DistributedSubmodule)
```
If pipeline_parallel_degree (equivalently, partitions) is 1, the placement of model partitions into GPUs and the initial broadcast of model parameters and buffers across data-parallel ranks take place immediately, because there is no need to wait for a model partition when the smdistributed.modelparallel.torch.DistributedModel wrapper is called. For other cases, with pipeline_parallel_degree greater than 1, the broadcast and device placement are deferred until the first call of a smdistributed.modelparallel.torch.step-decorated function, because the first smdistributed.modelparallel.torch.step-decorated function call is when the model partitioning happens if pipeline parallelism is enabled.

Because of the module replacement during the smdistributed.modelparallel.torch.DistributedModel call, any load_state_dict calls on the model, as well as any direct access to model parameters, such as during optimizer creation, should be done after the smdistributed.modelparallel.torch.DistributedModel call.

Since the broadcast of the model parameters and buffers happens immediately during the smdistributed.modelparallel.torch.DistributedModel call when the degree of pipeline parallelism is 1, using @smp.step decorators is not required when tensor parallelism is used by itself (without pipeline parallelism).

For more information about the library's tensor parallelism APIs for PyTorch, see PyTorch API for Tensor Parallelism.
Additional Methods of smdistributed.modelparallel.torch.DistributedModel for Tensor Parallelism

The following methods of smdistributed.modelparallel.torch.DistributedModel are available in addition to the ones listed in the preceding documentation.
- distributed_modules()¶

  An iterator that runs over the set of distributed (tensor-parallelized) modules in the model.

- is_distributed_parameter(param)¶

  Returns True if the given nn.Parameter is distributed over tensor-parallel ranks.

- is_distributed_buffer(buf)¶

  Returns True if the given buffer is distributed over tensor-parallel ranks.

- is_scaled_batch_parameter(param)¶

  Returns True if the given nn.Parameter operates on the scaled batch (the batch over the entire TP_GROUP, not only the local batch).

- is_scaled_batch_buffer(buf)¶

  Returns True if the parameter corresponding to the given buffer operates on the scaled batch (the batch over the entire TP_GROUP, not only the local batch).

- default_reducer_named_parameters()¶

  Returns an iterator that runs over (name, param) tuples, for each param that is allreduced over the DP_GROUP.

- scaled_batch_reducer_named_parameters()¶

  Returns an iterator that runs over (name, param) tuples, for each param that is allreduced over the RDP_GROUP.
smdistributed.modelparallel.torch.DistributedOptimizer¶
class smdistributed.modelparallel.torch.DistributedOptimizer(optimizer, static_loss_scale=1.0, dynamic_loss_scale=False, **dynamic_loss_args)¶

An optimizer wrapper for saving and loading optimizer states.
Parameters

- optimizer (object) – An optimizer object.

- static_loss_scale (float) – Effective only for FP16 training. The default value is 1.0.

- dynamic_loss_scale (boolean) – Effective only for FP16 training. Set to True to use dynamic loss scale. The default value is False.

- dynamic_loss_args (dict) – Effective only for FP16 training. If dynamic_loss_scale=True, you can configure additional scale parameters for dynamic loss scale. The following list shows available parameters.

  - "init_scale": Default is 2**32
  - "scale_factor": Default is 2.0
  - "scale_window": Default is 1000
  - "min_scale": Default is 1
  - "delayed_shift": Default is 1
  - "consecutive_hysteresis": Default is False
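The interplay of these parameters can be sketched in plain Python; this is an illustrative model of dynamic loss scaling, not the library's implementation:

```python
def step_scale(scale, overflow, good_steps,
               scale_factor=2.0, scale_window=1000, min_scale=1):
    """One update of a dynamic loss scale: shrink on overflow, grow after
    scale_window consecutive overflow-free steps."""
    if overflow:
        # Gradients overflowed: reduce the scale (never below min_scale)
        # and restart the window of clean steps.
        return max(scale / scale_factor, min_scale), 0
    good_steps += 1
    if good_steps >= scale_window:
        # A full window of clean steps: increase the scale.
        return scale * scale_factor, 0
    return scale, good_steps

scale, good = 2.0 ** 16, 0
scale, good = step_scale(scale, overflow=True, good_steps=good)
print(scale)  # 32768.0 -- halved after an overflow
for _ in range(1000):
    scale, good = step_scale(scale, overflow=False, good_steps=good)
print(scale)  # 65536.0 -- doubled back after 1000 clean steps
```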
Example usage of an FP32 optimizer:

```python
optimizer = torch.optim.Adadelta(...)
optimizer = smdistributed.modelparallel.torch.DistributedOptimizer(optimizer)
```

Example usage of an FP16 optimizer with static loss scale:

```python
optimizer = torch.optim.Adadelta(...)
optimizer = smdistributed.modelparallel.torch.DistributedOptimizer(
    optimizer,
    static_loss_scale=1.0,
)
```

Example usage of an FP16 optimizer with dynamic loss scale:

```python
optimizer = torch.optim.Adadelta(...)
optimizer = smdistributed.modelparallel.torch.DistributedOptimizer(
    optimizer,
    static_loss_scale=None,
    dynamic_loss_scale=True,
    dynamic_loss_args={
        "scale_window": 1000,
        "min_scale": 1,
        "delayed_shift": 2,
    },
)
```
Tip

After you modify training scripts with smdistributed.modelparallel.torch.DistributedModel and smdistributed.modelparallel.torch.DistributedOptimizer, use the SageMaker PyTorch estimator's distribution configuration to enable FP16 training. You simply need to add "fp16": True to the "parameters" key of the smp_options config dictionary, as shown in Using the SageMaker TensorFlow and PyTorch Estimators. For more information about available parameters for the smp_options config, see Run a Distributed Training Job Using the SageMaker Python SDK.

This wrapper returns an optimizer object with the following methods overridden:
- state_dict()¶

  Returns the state_dict that contains the optimizer state for the entire model. It first collects the local_state_dict, then gathers and merges the local_state_dicts from all mp_ranks to create a full state_dict.
- load_state_dict()¶

  Same as torch.optimizer.load_state_dict(), except: it first gathers and merges the local state_dicts if they are partial. The actual loading happens after the model partition so that each rank knows its local parameters.
- local_state_dict()¶

  Returns the state_dict that contains the local optimizer state that belongs to the current mp_rank. This state_dict contains the key _smp_is_partial, which indicates whether the state_dict contains elements corresponding to only the current partition or to the entire model.
smdistributed.modelparallel.torch Context Managers and Util Functions¶
- smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, **tensor_parallel_config)¶

  Context manager to create a torch model. This API combines both the smdistributed.modelparallel.torch.tensor_parallelism and smdistributed.modelparallel.torch.delay_param_initialization decorators, so you can simply use this single context when creating the torch model.

  Parameters

  - tensor_parallelism (boolean) – Whether to enable tensor parallelism during model creation.

  - dtype (torch.dtype) – The dtype to use when creating the model. It has the following rules.

    - If dtype is specified, it will be used during model creation.
    - If dtype is not specified, the default dtype will be used during model creation, which is usually FP32. This is for the best performance on CPU.
    - Any model that causes out-of-memory problems with FP32 initialization is recommended to be created with smdistributed.modelparallel.torch.delayed_parameter_initialization.
    - FP16_Module casts the model back to FP16 if FP16 training is enabled with the smp config. For more information about FP16 training in SageMaker with the model parallel library, see FP16 Training in the Amazon SageMaker Developer Guide.

  - tensor_parallel_config (dict) – kwargs to specify other tensor parallel configs. This is not used if tensor_parallelism is False.

  Example usage:

```python
import smdistributed.modelparallel.torch as smp

with smp.model_creation(
    tensor_parallelism=smp.tp_size() > 1,
    dtype=torch.float16 if args.fp16 else torch.get_default_dtype()
):
    model = MyModel(...)
```
- smdistributed.modelparallel.torch.partition(index)¶

  Parameters

  - index (int) – The index of the partition.

  A context manager which places all modules defined inside it into the partition with ID index. The index argument must be less than the number of partitions.

  Use smdistributed.modelparallel.torch.partition to implement manual partitioning. If "auto_partition" is True, then the smdistributed.modelparallel.torch.partition contexts are ignored. Any module that is not placed in any smdistributed.modelparallel.torch.partition context is placed in the default_partition defined through the SageMaker Python SDK.

  When smdistributed.modelparallel.torch.partition contexts are nested, the innermost context overrides the rest (see the following example). In PyTorch, manual partitioning should be done inside the module __init__, and the partition assignment applies to the modules that are created inside the smdistributed.modelparallel.torch.partition context.

  Example:

```python
import smdistributed.modelparallel.torch as smp

class Model(torch.nn.Module):
    def __init__(self):
        with smp.partition(1):
            self.child0 = Child0()        # child0 on partition 1
            with smp.partition(2):
                self.child1 = Child1()    # child1 on partition 2
            self.child2 = Child2()        # child2 on partition 1
        self.child3 = Child3()            # child3 on default_partition
```
- smdistributed.modelparallel.torch.amp.GradScaler¶

  Torch AMP GradScaler currently doesn't work with the library. smdistributed.modelparallel.torch.amp.GradScaler replaces torch.amp.GradScaler and provides the same functionality.
- smdistributed.modelparallel.torch.delay_param_initialization(enabled=True)¶

  If enabled, delays the initialization of parameters to save CPU memory. That is, parameter initialization takes place after the model is partitioned on GPUs.
- smdistributed.modelparallel.torch.get_world_process_group()¶

  Returns a torch.distributed ProcessGroup that consists of all processes, which can be used with the torch.distributed API. Requires "ddp": True in the SageMaker Python SDK parameters.
- smdistributed.modelparallel.torch.get_mp_process_group()¶

  Returns a torch.distributed ProcessGroup that consists of the processes in the MP_GROUP which contains the current process, which can be used with the torch.distributed API. Requires "ddp": True in the SageMaker Python SDK parameters.
- smdistributed.modelparallel.torch.get_dp_process_group()¶

  Returns a torch.distributed ProcessGroup that consists of the processes in the DP_GROUP which contains the current process, which can be used with the torch.distributed API. Requires "ddp": True in the SageMaker Python SDK parameters.
- smdistributed.modelparallel.torch.is_initialized()¶

  Returns True if smdistributed.modelparallel.torch.init has already been called for the process, and False otherwise.
- smdistributed.modelparallel.torch.nn.FusedLayerNorm¶

  Apex FusedLayerNorm is currently not supported by the library. smdistributed.modelparallel.torch.nn.FusedLayerNorm replaces apex FusedLayerNorm and provides the same functionality. This requires apex to be installed on the system.
- smdistributed.modelparallel.torch.optimizers.FusedNovoGrad¶

  The Apex FusedNovoGrad optimizer is currently not supported by the library. smdistributed.modelparallel.torch.optimizers.FusedNovoGrad replaces the apex FusedNovoGrad optimizer and provides the same functionality. This requires apex to be installed on the system.
- smdistributed.modelparallel.torch.optimizers.FusedLamb¶

  The Apex FusedLamb optimizer currently doesn't work with the library. smdistributed.modelparallel.torch.optimizers.FusedLamb replaces the apex FusedLamb optimizer and provides the same functionality. This requires apex to be installed on the system.
smdistributed.modelparallel.torch APIs for Saving and Loading¶
- smdistributed.modelparallel.torch.save(obj, f, partial=True, pickle_module=picklemodule, pickle_protocol=2)¶

  Saves an object. This operation is similar to torch.save(), except that it has an additional keyword argument, partial, and accepts only string type for the argument f (file). If partial=True, each mp_rank saves a separate checkpoint file and the library adds an mp_rank index to your saved file.

  Parameters

  - obj (dict): An object to save.
  - f (str): A string containing a file name.
  - partial (bool, default=True): When set to True, each mp_rank saves a separate checkpoint file and the library adds an mp_rank index to the saved file. If you want to be able to load and further train a model that you save with smdistributed.modelparallel.torch.save(), you must set partial=True.
  - pickle_module (picklemodule, default = module "pickle" from "/opt/conda/lib/python3.6/pickle.py"): A module used for pickling metadata and objects.
  - pickle_protocol (int, default=2): Can be specified to override the default protocol.
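For illustration, here is a sketch of how a partial save fans out into one file per mp_rank. The "_{mp_rank}" suffix is hypothetical; the exact naming the library uses on disk is an implementation detail:

```python
def partial_checkpoint_name(f, mp_rank, partial=True):
    # With partial=True, each mp_rank writes its own file, distinguished
    # by a rank index appended to the user-provided name (illustrative only).
    return f"{f}_{mp_rank}" if partial else f

print(partial_checkpoint_name("/checkpoint.pt", mp_rank=3))                 # /checkpoint.pt_3
print(partial_checkpoint_name("/checkpoint.pt", mp_rank=0, partial=False))  # /checkpoint.pt
```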
- smdistributed.modelparallel.torch.load(f, map_location, pickle_module, pickle_load_args, partial=True)¶

  Loads an object saved with smdistributed.modelparallel.torch.save() from a file.

  Similar to torch.load(), except it has an additional keyword argument, partial, and accepts only string type for the argument f (file). If partial=True, then each mp_rank loads a separate checkpoint file.

  Parameters

  - f (string): A string containing a file name.
  - map_location (function): A function, torch.device, a string, or a dict specifying how to remap storage locations.
  - pickle_module (pickle module): A module used for unpickling metadata and objects (has to match the pickle_module used to serialize the file).
  - pickle_load_args (Python 3 only): Optional keyword arguments passed to pickle_module.load() and pickle_module.Unpickler().
  - partial (bool, default=True): When set to True, each mp_rank loads the checkpoint corresponding to the mp_rank. Should be used when loading a model trained with the library.
- smdistributed.modelparallel.torch.save_checkpoint(path, tag, partial=True, model=None, optimizer=None, user_content=None, translate_if_full=True, num_kept_partial_checkpoints=None)¶

  Saves a checkpoint. While smdistributed.modelparallel.torch.save saves model and optimizer objects, this function checkpoints the model and optimizer and saves the checkpoints as separate files. It creates checkpoint folders in the following structure.

```
- path
  - ${tag}_partial        (folder for partial checkpoints)
    - model_rankinfo.pt
    - optimizer_rankinfo.pt
    - fp16_states_rankinfo.pt
    - user_content.pt
  - $tag                  (checkpoint file for full checkpoint)
  - user_content_$tag     (user_content file for full checkpoint)
  - newest                (a file that indicates the newest checkpoint)
```

  Parameters

  - path (str) (required): Path to save the checkpoint. The library creates the directory if it does not already exist. For example, /opt/ml/checkpoint/model_parallel.
  - tag (str) (required): A tag for the current checkpoint, usually the train steps. Note: the tag needs to be the same across all ranks (GPU workers). When partial=False, this is the checkpoint file name.
  - partial (boolean) (default: True): Whether to save the partial checkpoint.
  - model (smdistributed.modelparallel.torch.DistributedModel) (default: None): The model to save. It needs to be an smp.DistributedModel object.
  - optimizer (smdistributed.modelparallel.torch.DistributedOptimizer) (default: None): The optimizer to save. It needs to be an smp.DistributedOptimizer object.
  - user_content (any) (default: None): User-defined content to save.
  - translate_if_full (boolean) (default: True): Whether to translate the full state_dict to an HF state_dict if possible.
  - num_kept_partial_checkpoints (int) (default: None): The maximum number of partial checkpoints to keep on disk.
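The retention behavior implied by num_kept_partial_checkpoints can be sketched in plain Python; the tag names are hypothetical, and on disk the library manages this pruning for you:

```python
def prune_checkpoints(tags, num_kept):
    """Keep only the newest num_kept checkpoint tags.
    tags are ordered oldest -> newest; None means keep everything."""
    if num_kept is None:
        return tags
    return tags[-num_kept:]

tags = ["total_steps100", "total_steps200", "total_steps300"]
print(prune_checkpoints(tags, num_kept=2))             # ['total_steps200', 'total_steps300']
print(prune_checkpoints(tags, num_kept=None) == tags)  # True
```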
- smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer=True, load_sharded_optimizer_state=True, translate_function=None)¶

  While smdistributed.modelparallel.torch.load loads saved model and optimizer objects, this function resumes from a saved checkpoint file.

  Parameters

  - path (str) (required): Path to load the checkpoint from.
  - tag (str) (default: None): Tag of the checkpoint to resume. If not provided, the library tries to locate the newest checkpoint from the saved newest file.
  - partial (boolean) (default: True): Whether to load the partial checkpoint.
  - strict (boolean) (default: True): Load with strict load; no extra keys or missing keys are allowed.
  - load_optimizer (boolean) (default: True): Whether to load the optimizer.
  - load_sharded_optimizer_state (boolean) (default: True): Whether to load the sharded optimizer state of a model. It can be used only when you activate the sharded data parallelism feature of the SageMaker model parallel library. When this is False, the library only loads the FP16 states, such as FP32 master parameters and the loss scaling factor, not the sharded optimizer states.
  - translate_function (function) (default: None): A function to translate the full checkpoint into smdistributed.modelparallel format. For supported models, this is not required.
Example usage:

```python
# Save
smp.save_checkpoint(
    checkpoint_dir,
    tag=f"total_steps{total_steps}",
    partial=True,
    model=model,
    optimizer=optimizer,
    user_content=user_content,
    num_kept_partial_checkpoints=args.num_kept_checkpoints,
)

# Load: this will automatically load the newest checkpoint
user_content = smp.resume_from_checkpoint(path, partial=partial)
```
General instruction on saving and loading¶

The library can save partial or full checkpoints.

- For partial checkpoints, each mp_rank saves its own checkpoint file with only the parameters that belong to that rank.
- For full checkpoints, the library saves a single checkpoint that contains the entire model's parameters.

When saving using smdistributed.modelparallel.torch.save(), each rank only holds its own parameters. If you want to save the full model, there will be some communication between the ranks to create the full model. If you save checkpoints often, you should save partial checkpoints for best performance.

When loading using smdistributed.modelparallel.torch.load(), the library can load either partial or full checkpoints, or full checkpoints saved by a non-model-parallel model. If you want to resume training with a non-model-parallel model or do inference, you need a full checkpoint.
The following is an example of how you can save and load a checkpoint:

```python
import smdistributed.modelparallel.torch as smp

# Original model and optimizer
model = MyModel(...)
optimizer = MyOpt(...)

# model parallel wrapper
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

# To save, always save on dp_rank 0 to avoid data racing.

# To save the partial model on each mp rank,
# the library will create `checkpoint.pt_{mprank}` for each mp rank
if save_partial_model:
    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()    # save the partial model
        opt_dict = optimizer.local_state_dict()  # save the partial optimizer state
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            "/checkpoint.pt",
            partial=True,
        )

# To save the full model
if save_full_model:
    if smp.dp_rank() == 0:
        model_dict = model.state_dict()          # save the full model
        opt_dict = optimizer.state_dict()        # save the full optimizer state
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            "/checkpoint.pt",
            partial=False,
        )

# To load, load on all ranks.
# The only difference for partial/full loading is the partial flag in smp.load.

# Load partial checkpoint
if partial_checkpoint:
    checkpoint = smp.load("/checkpoint.pt", partial=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Load full checkpoint
if full_checkpoint:
    checkpoint = smp.load("/checkpoint.pt", partial=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
```