Guide for PyTorch
Use this guide to learn how to use the SageMaker distributed data parallel library API for PyTorch.
Topics
- Use the SageMaker Distributed Data Parallel Library as a Backend of torch.distributed
- PyTorch API

Use the SageMaker Distributed Data Parallel Library as a Backend of torch.distributed
To use the SageMaker distributed data parallel library, the only change you need is to import the library's PyTorch client (smdistributed.dataparallel.torch.torch_smddp). The import registers smddp as a backend for PyTorch. When you initialize the PyTorch distributed process group with the torch.distributed.init_process_group API, make sure you specify 'smddp' for the backend argument.
import smdistributed.dataparallel.torch.torch_smddp  # registers the 'smddp' backend
import torch.distributed as dist

dist.init_process_group(backend='smddp')
If you already have a working PyTorch script and only need to add the backend specification, you can proceed to Launch a Distributed Training Job Using the SageMaker Python SDK.
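For reference, the following is a minimal sketch of such a launch using the SageMaker Python SDK's PyTorch estimator. The entry point name, IAM role, instance settings, and framework versions below are placeholder assumptions; adjust them for your account and training script.

from sagemaker.pytorch import PyTorch

# Minimal sketch, not a definitive recipe: entry_point, role, and
# instance settings are placeholders.
estimator = PyTorch(
    entry_point="train.py",            # hypothetical training script
    role="<your-iam-role-arn>",        # placeholder IAM role
    framework_version="1.11.0",        # assumed smddp-compatible version
    py_version="py38",
    instance_count=2,
    instance_type="ml.p4d.24xlarge",   # an instance type the library supports
    # Enable the SageMaker distributed data parallel library:
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
)
estimator.fit()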
Note
The smddp backend currently does not support creating subprocess groups with the torch.distributed.new_group() API. You cannot use the smddp backend concurrently with other backends.
See also
If you still need to modify your training script to properly use the PyTorch distributed package, see Preparing a PyTorch Training Script for Distributed Training in the Amazon SageMaker Developer Guide.
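As a rough illustration of that preparation, a minimal training-script skeleton using the smddp backend might look like the following. The model and device handling are placeholder assumptions, and the snippet assumes the launcher sets the LOCAL_RANK environment variable.

import os

import torch
import torch.distributed as dist
import smdistributed.dataparallel.torch.torch_smddp  # registers 'smddp'
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize the process group with the library as the backend.
dist.init_process_group(backend='smddp')

# Pin each process to its GPU, assuming the launcher sets LOCAL_RANK.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Placeholder model; wrap it with PyTorch's native DDP as usual.
model = torch.nn.Linear(10, 1).to(local_rank)
model = DDP(model, device_ids=[local_rank])

From here, the training loop is unchanged from a standard torch.distributed script.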
PyTorch API
Since v1.4.0, the SageMaker distributed data parallel library supports the PyTorch distributed package as a backend option. To use the library with PyTorch in SageMaker, specify 'smddp' as the backend of the PyTorch distributed package when initializing the process group.
torch.distributed.init_process_group(backend='smddp')
You don't need to modify a script that uses the smdistributed implementation of the PyTorch distributed modules, which the library supported in v1.3.0 and before; those modules continue to work but are deprecated.
Warning
The following APIs of the smdistributed implementation of the PyTorch distributed modules are deprecated; a migration sketch follows the list.
- class smdistributed.dataparallel.torch.parallel.DistributedDataParallel(module, device_ids=None, output_device=None, broadcast_buffers=True, process_group=None, bucket_cap_mb=None)
  Deprecated since version 1.4.0: Use the torch.nn.parallel.DistributedDataParallel API instead.
- smdistributed.dataparallel.torch.distributed.is_available()
  Deprecated since version 1.4.0: Use the torch.distributed package instead. For more information, see Initialization in the PyTorch documentation.
- smdistributed.dataparallel.torch.distributed.init_process_group(*args, **kwargs)
  Deprecated since version 1.4.0: Use the torch.distributed package instead. For more information, see Initialization in the PyTorch documentation.
- smdistributed.dataparallel.torch.distributed.is_initialized()
  Deprecated since version 1.4.0: Use the torch.distributed package instead. For more information, see Initialization in the PyTorch documentation.
- smdistributed.dataparallel.torch.distributed.get_world_size(group=smdistributed.dataparallel.torch.distributed.group.WORLD)
  Deprecated since version 1.4.0: Use the torch.distributed package instead. For more information, see Post-Initialization in the PyTorch documentation.
- smdistributed.dataparallel.torch.distributed.get_rank(group=smdistributed.dataparallel.torch.distributed.group.WORLD)
  Deprecated since version 1.4.0: Use the torch.distributed package instead. For more information, see Post-Initialization in the PyTorch documentation.
- smdistributed.dataparallel.torch.distributed.get_local_rank()
  Deprecated since version 1.4.0: Use the torch.distributed package instead.
- smdistributed.dataparallel.torch.distributed.all_reduce(tensor, op=smdistributed.dataparallel.torch.distributed.ReduceOp.SUM, group=smdistributed.dataparallel.torch.distributed.group.WORLD, async_op=False)
  Deprecated since version 1.4.0: Use the torch.distributed package instead.
- smdistributed.dataparallel.torch.distributed.broadcast(tensor, src=0, group=smdistributed.dataparallel.torch.distributed.group.WORLD, async_op=False)
  Deprecated since version 1.4.0: Use the torch.distributed package instead.
- smdistributed.dataparallel.torch.distributed.all_gather(tensor_list, tensor, group=smdistributed.dataparallel.torch.distributed.group.WORLD, async_op=False)
  Deprecated since version 1.4.0: Use the torch.distributed package instead.
- smdistributed.dataparallel.torch.distributed.all_to_all_single(output_t, input_t, output_split_sizes=None, input_split_sizes=None, group=smdistributed.dataparallel.torch.distributed.group.WORLD, async_op=False)
  Deprecated since version 1.4.0: Use the torch.distributed package instead.
- smdistributed.dataparallel.torch.distributed.barrier(group=smdistributed.dataparallel.torch.distributed.group.WORLD, async_op=False)
  Deprecated since version 1.4.0: Use the torch.distributed package instead.
- class smdistributed.dataparallel.torch.distributed.ReduceOp
  Deprecated since version 1.4.0: Use the torch.distributed package instead.
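Each deprecated call above has a direct torch.distributed counterpart. The following sketch shows one possible mapping; the tensor, reduce op, and environment-variable handling are illustrative, and it assumes the launcher sets LOCAL_RANK.

import os

import torch
import torch.distributed as dist
import smdistributed.dataparallel.torch.torch_smddp  # registers 'smddp'

dist.init_process_group(backend='smddp')    # was smdistributed...init_process_group()
world_size = dist.get_world_size()          # was ...get_world_size()
rank = dist.get_rank()                      # was ...get_rank()
local_rank = int(os.environ["LOCAL_RANK"])  # was ...get_local_rank()

t = torch.ones(1).cuda(local_rank)          # illustrative tensor
dist.all_reduce(t, op=dist.ReduceOp.SUM)    # was ...all_reduce()
dist.broadcast(t, src=0)                    # was ...broadcast()
dist.barrier()                              # was ...barrier()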