Guide for PyTorch

Use this guide to learn how to use the SageMaker distributed data parallel library API for PyTorch.

Use the SageMaker Distributed Data Parallel Library as a Backend of torch.distributed

To use the SageMaker distributed data parallel library, import its PyTorch client (smdistributed.dataparallel.torch.torch_smddp). The import registers smddp as a backend for PyTorch. When you initialize the PyTorch distributed process group with the torch.distributed.init_process_group API, specify 'smddp' as the backend argument.

import smdistributed.dataparallel.torch.torch_smddp
import torch.distributed as dist

dist.init_process_group(backend='smddp')

If you already have a working PyTorch script and only need to add the backend specification, you can proceed to Launch a Distributed Training Job Using the SageMaker Python SDK.
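For orientation, a launch of that kind typically looks like the following sketch, which uses the SageMaker Python SDK's PyTorch estimator and enables the data parallel library through the distribution argument. The entry point name, IAM role, framework version, and instance settings are placeholders; substitute values that match your account and the library's supported configurations.

from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",           # your training script that sets backend='smddp'
    role="<your-iam-role-arn>",       # placeholder IAM role
    instance_count=2,
    instance_type="ml.p4d.24xlarge",  # an instance type supported by the library
    framework_version="1.12.0",       # a framework version supported by the library
    py_version="py38",
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
)

estimator.fit()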

Note

The smddp backend currently does not support creating subprocess groups with the torch.distributed.new_group() API. You cannot use the smddp backend concurrently with other backends.

See also

If you still need to modify your training script to properly use the PyTorch distributed package, see Preparing a PyTorch Training Script for Distributed Training in the Amazon SageMaker Developer Guide.
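As a quick orientation before reading that guide, the following is a minimal sketch of a training script prepared for the PyTorch distributed package on top of the smddp backend. It assumes the launcher sets the LOCAL_RANK environment variable and uses a toy model and synthetic dataset as stand-ins for your own.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

import smdistributed.dataparallel.torch.torch_smddp  # registers the 'smddp' backend

dist.init_process_group(backend='smddp')

# Pin each process to the GPU that matches its local rank.
# LOCAL_RANK is assumed to be set by the launcher.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Toy model and synthetic dataset; replace with your own.
model = DDP(nn.Linear(10, 1).to(local_rank), device_ids=[local_rank])
dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 1))

# Shard the dataset so each process trains on a distinct slice per epoch.
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle the shards each epoch
    for x, y in loader:
        x, y = x.to(local_rank), y.to(local_rank)
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()  # gradients are averaged across processes
        optimizer.step()

# Save from one process only to avoid concurrent writes.
if dist.get_rank() == 0:
    torch.save(model.module.state_dict(), "model.pt")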

PyTorch API

Since v1.4.0, the SageMaker distributed data parallel library supports the PyTorch distributed package as a backend option. To use the library with PyTorch in SageMaker, simply specify 'smddp' as the backend of the PyTorch distributed package when initializing the process group.

torch.distributed.init_process_group(backend='smddp')

You no longer need to modify your script to use the smdistributed implementation of the PyTorch distributed modules that are supported in the library v1.3.0 and before.

Warning

The following APIs of the smdistributed implementation of the PyTorch distributed modules are deprecated.

class smdistributed.dataparallel.torch.parallel.DistributedDataParallel(module, device_ids=None, output_device=None, broadcast_buffers=True, process_group=None, bucket_cap_mb=None)

Deprecated since version 1.4.0: Use the torch.nn.parallel.DistributedDataParallel API instead.

smdistributed.dataparallel.torch.distributed.is_available()

Deprecated since version 1.4.0: Use the torch.distributed package instead. For more information, see Initialization in the PyTorch documentation.

smdistributed.dataparallel.torch.distributed.init_process_group(*args, **kwargs)

Deprecated since version 1.4.0: Use the torch.distributed package instead. For more information, see Initialization in the PyTorch documentation.

smdistributed.dataparallel.torch.distributed.is_initialized()

Deprecated since version 1.4.0: Use the torch.distributed package instead. For more information, see Initialization in the PyTorch documentation.

smdistributed.dataparallel.torch.distributed.get_world_size(group=smdistributed.dataparallel.torch.distributed.group.WORLD)

Deprecated since version 1.4.0: Use the torch.distributed package instead. For more information, see Post-Initialization in the PyTorch documentation.

smdistributed.dataparallel.torch.distributed.get_rank(group=smdistributed.dataparallel.torch.distributed.group.WORLD)

Deprecated since version 1.4.0: Use the torch.distributed package instead. For more information, see Post-Initialization in the PyTorch documentation.

smdistributed.dataparallel.torch.distributed.get_local_rank()

Deprecated since version 1.4.0: Use the torch.distributed package instead.

smdistributed.dataparallel.torch.distributed.all_reduce(tensor, op=smdistributed.dataparallel.torch.distributed.ReduceOp.SUM, group=smdistributed.dataparallel.torch.distributed.group.WORLD, async_op=False)

Deprecated since version 1.4.0: Use the torch.distributed package instead.

smdistributed.dataparallel.torch.distributed.broadcast(tensor, src=0, group=smdistributed.dataparallel.torch.distributed.group.WORLD, async_op=False)

Deprecated since version 1.4.0: Use the torch.distributed package instead.

smdistributed.dataparallel.torch.distributed.all_gather(tensor_list, tensor, group=smdistributed.dataparallel.torch.distributed.group.WORLD, async_op=False)

Deprecated since version 1.4.0: Use the torch.distributed package instead.

smdistributed.dataparallel.torch.distributed.all_to_all_single(output_t, input_t, output_split_sizes=None, input_split_sizes=None, group=smdistributed.dataparallel.torch.distributed.group.WORLD, async_op=False)

Deprecated since version 1.4.0: Use the torch.distributed package instead.

smdistributed.dataparallel.torch.distributed.barrier(group=smdistributed.dataparallel.torch.distributed.group.WORLD, async_op=False)

Deprecated since version 1.4.0: Use the torch.distributed package instead.

class smdistributed.dataparallel.torch.distributed.ReduceOp

Deprecated since version 1.4.0: Use the torch.distributed package instead.
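To migrate a script away from the deprecated smdistributed modules listed above, the general pattern is to replace each call with its torch.distributed counterpart after initializing the process group with the smddp backend. The following sketch pairs the deprecated calls (shown in comments) with standard replacements; note that get_local_rank() has no direct torch.distributed equivalent, so the sketch assumes the launcher provides a LOCAL_RANK environment variable.

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import smdistributed.dataparallel.torch.torch_smddp  # registers the 'smddp' backend

# smdistributed.dataparallel.torch.distributed.init_process_group()  ->
dist.init_process_group(backend='smddp')

# smdistributed.dataparallel.torch.distributed.get_rank() / get_world_size()  ->
rank = dist.get_rank()
world_size = dist.get_world_size()

# smdistributed.dataparallel.torch.distributed.get_local_rank()  ->
# (no direct equivalent; read the launcher-provided environment variable instead)
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# smdistributed.dataparallel.torch.distributed.all_reduce(...)  ->
t = torch.ones(1, device=local_rank)
dist.all_reduce(t, op=dist.ReduceOp.SUM)

# smdistributed.dataparallel.torch.distributed.barrier()  ->
dist.barrier()

# smdistributed.dataparallel.torch.parallel.DistributedDataParallel(...)  ->
model = DDP(torch.nn.Linear(10, 1).to(local_rank), device_ids=[local_rank])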