Using MXNet with the SageMaker Python SDK

With the SageMaker Python SDK, you can train and host MXNet models on Amazon SageMaker.

For information about supported versions of MXNet, see the MXNet README.

For general information about using the SageMaker Python SDK, see Using the SageMaker Python SDK.

Train a Model with MXNet

To train an MXNet model by using the SageMaker Python SDK:

  1. Prepare a training script
  2. Create a sagemaker.mxnet.MXNet Estimator
  3. Call the estimator’s fit method

Prepare an MXNet Training Script


The structure for training scripts changed starting at MXNet version 1.3. Make sure you refer to the correct section of this README when you prepare your script. For information on how to upgrade an old script to the new format, see “Updating your MXNet training script”.

For versions 1.3 and higher

Your MXNet training script must be a Python 2.7 or 3.6 compatible source file.

The training script is very similar to a training script you might run outside of SageMaker, but you can access useful properties about the training environment through various environment variables, including the following:

  • SM_MODEL_DIR: A string that represents the path where the training job writes the model artifacts. After training, artifacts in this directory are uploaded to S3 for model hosting.
  • SM_NUM_GPUS: An integer representing the number of GPUs available to the host.
  • SM_CHANNEL_XXXX: A string that represents the path to the directory that contains the input data for the specified channel. For example, if you specify two input channels in the MXNet estimator’s fit call, named ‘train’ and ‘test’, the environment variables SM_CHANNEL_TRAIN and SM_CHANNEL_TEST are set.
  • SM_HPS: A JSON dump of the hyperparameters, preserving JSON types (boolean, integer, etc.).

For the exhaustive list of available environment variables, see the SageMaker Containers documentation.
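Reading these variables is ordinary environment access. The sketch below gathers the values listed above into Python types. read_training_env is a hypothetical helper, not part of SageMaker, and it accepts a plain dict in place of os.environ so you can try it outside SageMaker. Lower-casing the channel names is an assumption made here for convenience.

```python
import json
import os


def read_training_env(environ=None):
    """Collect the SM_* values described above (hypothetical helper)."""
    environ = os.environ if environ is None else environ
    prefix = 'SM_CHANNEL_'
    channels = {
        # SM_CHANNEL_TRAIN -> 'train' (lower-cased here for convenience)
        key[len(prefix):].lower(): path
        for key, path in environ.items()
        if key.startswith(prefix)
    }
    return {
        'model_dir': environ.get('SM_MODEL_DIR'),
        # environment values are strings, so convert the GPU count explicitly
        'num_gpus': int(environ.get('SM_NUM_GPUS', '0')),
        'channels': channels,
        # SM_HPS is JSON, so json.loads restores ints, booleans, etc.
        'hyperparameters': json.loads(environ.get('SM_HPS', '{}')),
    }
```

For example, with SM_CHANNEL_TRAIN and SM_CHANNEL_TEST set, channels comes back as a dict with 'train' and 'test' keys.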

A typical training script loads data from the input channels, configures training with hyperparameters, trains a model, and saves a model to model_dir so that it can be deployed for inference later. Hyperparameters are passed to your script as arguments and can be retrieved with an argparse.ArgumentParser instance. For example, a training script might start with the following:

import argparse
import os
import json

if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=100)
    parser.add_argument('--learning-rate', type=float, default=0.1)

    # an alternative way to load hyperparameters via SM_HPS environment variable.
    parser.add_argument('--sm-hps', type=json.loads, default=os.environ['SM_HPS'])

    # input data and model directories
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
    parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

    args, _ = parser.parse_known_args()

    # ... load from args.train and args.test, train a model, write model to args.model_dir.

Because SageMaker imports your training script, you should put your training code in a main guard (if __name__ == '__main__':) if you are using the same script to host your model. This prevents SageMaker from inadvertently running your training code at the wrong point in execution.
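As the comment in the script above says, hyperparameters reach the script as command-line arguments. The sketch below imitates that hand-off so you can exercise your parser locally. hyperparameters_to_argv is hypothetical and only approximates what the container does. The extra, unparsed flag shows why the script uses parse_known_args: unrecognized arguments are returned instead of causing an error.

```python
import argparse


def hyperparameters_to_argv(hyperparameters):
    """Hypothetical helper: {'batch-size': 256} -> ['--batch-size', '256']."""
    argv = []
    for name, value in hyperparameters.items():
        argv.extend(['--' + name, str(value)])
    return argv


parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=100)
parser.add_argument('--learning-rate', type=float, default=0.1)

# 'extra-flag' has no matching add_argument call, so it ends up in `unknown`
argv = hyperparameters_to_argv({'batch-size': 256, 'learning-rate': 0.05, 'extra-flag': 'x'})
args, unknown = parser.parse_known_args(argv)
# args.batch_size == 256, args.learning_rate == 0.05, args.epochs == 10 (the default)
```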

Note that SageMaker doesn’t support argparse actions. If you want to use, for example, boolean hyperparameters, you need to specify type as bool in your script and provide an explicit True or False value for this hyperparameter when instantiating your MXNet estimator.
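One caveat with type=bool: argparse calls bool() on the raw string, and any non-empty string, including 'False', is truthy. A common workaround is a small converter function. The str2bool below is a hypothetical helper, not a SageMaker or argparse API.

```python
import argparse


def str2bool(value):
    """Parse 'True'/'False'-style strings; plain bool('False') would be True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


parser = argparse.ArgumentParser()
parser.add_argument('--use-bn', type=str2bool, default=False)

# An explicit 'False' now really parses as False
args, _ = parser.parse_known_args(['--use-bn', 'False'])
```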

For more on training environment variables, please visit SageMaker Containers.

For versions 1.2 and lower

Your MXNet training script must be a Python 2.7 or 3.6 compatible source file. The MXNet training script must contain a function train, which SageMaker invokes to run training. You can include other functions as well, but it must contain a train function.

When you run your script on SageMaker via the MXNet Estimator, SageMaker injects information about the training environment into your training function via Python keyword arguments. You can choose to take advantage of these by including them as keyword arguments in your train function. The full list of arguments is:

  • hyperparameters (dict[string,string]): The hyperparameters passed to SageMaker TrainingJob that runs your MXNet training script. You can use this to pass hyperparameters to your training script.
  • input_data_config (dict[string,dict]): The SageMaker TrainingJob InputDataConfig object, which is set when the SageMaker TrainingJob is created. This is discussed in more detail below.
  • channel_input_dirs (dict[string,string]): A collection of directories containing training data. When you run training, you can partition your training data into different logical “channels”. Depending on your problem, some common channel ideas are: “train”, “test”, “evaluation” or “images”, “labels”.
  • output_data_dir (str): A directory where your training script can write data that will be moved to S3 after training is complete.
  • num_gpus (int): The number of GPU devices available on your training instance.
  • num_cpus (int): The number of CPU devices available on your training instance.
  • hosts (list[str]): The list of host names running in the SageMaker Training Job cluster.
  • current_host (str): The name of the host executing the script. When you use SageMaker for MXNet training, the script is run on each host in the cluster.

A training script that takes advantage of all arguments would have the following definition:

def train(hyperparameters, input_data_config, channel_input_dirs, output_data_dir,
          num_gpus, num_cpus, hosts, current_host):
    ...

You don’t have to use all the arguments. You can ignore the ones you don’t need by including **kwargs.

# Only use hyperparameters and num_gpus; ignore all other arguments
def train(hyperparameters, num_gpus, **kwargs):
    ...
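Conceptually, the injection works like a keyword-argument call. The sketch below is only an approximation: the values are made up, and it does not show the exact mechanism SageMaker uses to call train. It demonstrates that **kwargs absorbs every argument your signature does not name. Hyperparameter values arrive as strings (dict[string,string]), so the sketch converts them.

```python
def train(hyperparameters, num_gpus, **kwargs):
    """Legacy (1.2 and lower) entry point that uses only two injected values."""
    # Hyperparameter values are strings, so convert them explicitly.
    epochs = int(hyperparameters.get('epochs', '10'))
    # Everything else is collected in kwargs and ignored.
    return {'epochs': epochs, 'num_gpus': num_gpus, 'ignored': sorted(kwargs)}


# Made-up values standing in for what SageMaker would inject.
injected = {
    'hyperparameters': {'epochs': '5'},
    'input_data_config': {},
    'channel_input_dirs': {'train': '/opt/ml/input/data/train'},
    'output_data_dir': '/opt/ml/output/data',
    'num_gpus': 0,
    'num_cpus': 4,
    'hosts': ['algo-1'],
    'current_host': 'algo-1',
}
result = train(**injected)
```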

Note: Writing a training script that imports correctly: When SageMaker runs your training script, it imports it as a Python module and then invokes train on the imported module. Consequently, you should not include any statements that won’t execute successfully in SageMaker when your module is imported. For example, don’t attempt to open any local files in top-level statements in your training script.

If you want to run your training script locally by using the Python interpreter, use an if __name__ == '__main__': guard. For more information, see

Save the Model

Just as you enable training by defining a train function in your training script, you enable model saving by defining a save function in your script. If your script includes a save function, SageMaker invokes it with the return value of train. Model saving is a two-step process: first, you return the model you want to save from train; then you define your model-serialization logic in save.

SageMaker provides a default implementation of save that works with MXNet Module API Module objects. If your training script does not define a save function, then the default save function will be invoked on the return-value of your train function.

The default serialization system generates three files:

  • model-shapes.json: A JSON list containing a serialization of the Module data_shapes property. Each object in the list contains the serialization of one DataShape in the returned Module. Each object has a name property, containing the DataShape name, and a shape property, which is a list of the dimensions of that DataShape. For example:
    [{"name":"images", "shape":[100, 1, 28, 28]},
     {"name":"labels", "shape":[100, 1]}]
  • model-symbol.json: The MXNet Module Symbol serialization, produced by invoking save on the symbol property of the Module being saved.
  • model.params: The MXNet Module parameters, produced by invoking save_params on the Module being saved.
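To read model-shapes.json back, for example in a custom model_fn, you only need the json module. The sketch below round-trips the format described above. The function names are hypothetical. The (name, shape) pairs it returns have the same form that Module.bind accepts for data_shapes.

```python
import json


def shapes_to_json(data_shapes):
    """Serialize (name, shape) pairs into the model-shapes.json layout."""
    return json.dumps([{'name': name, 'shape': list(shape)} for name, shape in data_shapes])


def json_to_shapes(text):
    """Read model-shapes.json content back into (name, shape-tuple) pairs."""
    return [(entry['name'], tuple(entry['shape'])) for entry in json.loads(text)]


shapes = [('images', (100, 1, 28, 28)), ('labels', (100, 1))]
text = shapes_to_json(shapes)
```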

You can provide your own save function. This is useful if you are not working with the Module API or you need special processing.

To provide your own save function, define a save function in your training script:

def save(model, model_dir):

The function should take two arguments:

  • model: This is the object that was returned from your train function. If your train function does not return an object, it will be None. You are free to return an object of any type from train; you do not have to return Module or Gluon API specific objects.
  • model_dir: This is the string path on the SageMaker training host where you save your model. Files created in this directory will be accessible in S3 after your SageMaker Training Job completes.

After your train function completes, SageMaker will invoke save with the object returned from train.
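As an illustration of a save function for a non-Module model, the sketch below assumes, hypothetically, that train returned a plain dict of learned values and writes it as JSON. The file name model.json is an arbitrary choice. A matching model_fn would need to read the same file.

```python
import json
import os
import tempfile


def save(model, model_dir):
    """Hypothetical custom save for a train() that returns a plain dict."""
    if model is None:
        # train() returned nothing, so there is nothing to serialize.
        return None
    path = os.path.join(model_dir, 'model.json')
    with open(path, 'w') as f:
        json.dump(model, f)
    return path


# Usage: write to a throwaway directory instead of SageMaker's model_dir.
scratch_dir = tempfile.mkdtemp()
saved_path = save({'weights': [0.1, 0.2]}, scratch_dir)
```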

Note: How to save Gluon models with SageMaker

If your train function returns a Gluon API net object as its model, you’ll need to write your own save function. You will want to serialize the net parameters. Saving net parameters is covered in the Serialization section of the collaborative Gluon deep-learning book “The Straight Dope”.

Save a Checkpoint

It is good practice to save the best model after each training epoch, so that you can resume a training job if it gets interrupted. This is particularly important if you are using Managed Spot training.

To save MXNet model checkpoints, do the following in your training script:

  • Define CHECKPOINTS_DIR and enable checkpoints if that directory exists.

    CHECKPOINTS_DIR = '/opt/ml/checkpoints'
    checkpoints_enabled = os.path.exists(CHECKPOINTS_DIR)
  • Make sure you are emitting a validation metric to test the model. For information, see Evaluation Metric API.

  • After each training epoch, test whether the current model performs the best with respect to the validation metric, and if it does, save that model to CHECKPOINTS_DIR.

    if checkpoints_enabled and current_host == hosts[0]:
        if val_acc > best_accuracy:
            best_accuracy = val_acc
            logging.info('Saving the model, params and optimizer state')
            net.export(CHECKPOINTS_DIR + "/%.4f-cifar10" % (best_accuracy), epoch)
            trainer.save_states(CHECKPOINTS_DIR + '/%.4f-cifar10-%d.states' % (best_accuracy, epoch))
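You can factor out the best-so-far logic in the checkpoint snippet and test it without MXNet. In this sketch, update_best and fake_save are hypothetical. In a real script, save_fn would call net.export and trainer.save_states as shown above. best_accuracy must also be initialized (here to 0.0) before the first epoch.

```python
def update_best(val_acc, best_accuracy, epoch, save_fn, is_leader=True):
    """Checkpoint only when this epoch beats the best validation score so far."""
    if is_leader and val_acc > best_accuracy:
        save_fn(val_acc, epoch)
        return val_acc
    return best_accuracy


saved = []  # stands in for files written to CHECKPOINTS_DIR


def fake_save(acc, epoch):
    # Mirrors the "%.4f-cifar10" naming above without touching the disk.
    saved.append('%.4f-cifar10-%d' % (acc, epoch))


best = 0.0
for epoch, val_acc in enumerate([0.50, 0.72, 0.65, 0.80]):
    best = update_best(val_acc, best, epoch, fake_save)
```

Epoch 2 does not improve on epoch 1, so only three checkpoints are recorded.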

For a complete example of an MXNet training script that implements checkpointing, see

Updating your MXNet training script

The structure for training scripts changed with MXNet version 1.3. The train function is no longer required; instead, the training script must be runnable as a standalone script. In this way, the training script is similar to a training script you might run outside of SageMaker.

There are a few steps needed to make a training script with the old format compatible with the new format.

First, add a main guard (if __name__ == '__main__':). The code executed from your main guard needs to:

  1. Set hyperparameters and directory locations
  2. Initiate training
  3. Save the model

Hyperparameters will be passed as command-line arguments to your training script. In addition, the container will define the locations of input data and where to save the model artifacts and output data as environment variables rather than passing that information as arguments to the train function. You can find the full list of available environment variables in the SageMaker Containers README.

We recommend using an argument parser for this part. Using the argparse library as an example, the code would look something like this:

import argparse
import os

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=100)
    parser.add_argument('--learning-rate', type=float, default=0.1)

    # input data and model directories
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
    parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

    args, _ = parser.parse_known_args()

The code in the main guard should also take care of training and saving the model. This can be as simple as calling the train and save functions used in the previous training script format:

if __name__ == '__main__':
    # arg parsing (shown above) goes here

    model = train(args.batch_size, args.epochs, args.learning_rate, args.train, args.test)
    # the old-format save function takes (model, model_dir)
    save(model, args.model_dir)

Note that saving the model will no longer be done by default; this must be done by the training script. If you were previously relying on the default save method, you can now import one from the container:

from sagemaker_mxnet_container.training_utils import save

if __name__ == '__main__':
    # arg parsing and training (shown above) goes here

    save(args.model_dir, model)

Lastly, if you were relying on the container launching a parameter server for use with distributed training, you must now set distributions to the following dictionary when creating an MXNet estimator:

from sagemaker.mxnet import MXNet

estimator = MXNet('',
                  distributions={'parameter_server': {'enabled': True}})

Using third-party libraries

Your training script running on SageMaker has access to some pre-installed third-party libraries, including mxnet, numpy, onnx, and keras-mxnet. For more information on the runtime environment, including specific package versions, see SageMaker MXNet Containers.

If there are other packages you want to use with your script, you can include a requirements.txt file in the same directory as your training script to install other dependencies at runtime. A requirements.txt file is a text file that contains a list of items that are installed by using pip install. You can also specify the version of an item to install. For information about the format of a requirements.txt file, see Requirements Files in the pip documentation.

Create an Estimator

You run MXNet training scripts on SageMaker by creating an MXNet estimator. When you call fit on an MXNet estimator, SageMaker starts a training job using your script as training code. The following code sample shows how you train a custom MXNet script “”.

mxnet_estimator = MXNet('',
                        hyperparameters={'batch-size': 100,
                                         'epochs': 10,
                                         'learning-rate': 0.1})
mxnet_estimator.fit('s3://my_bucket/my_training_data/')

For more information about the sagemaker.mxnet.MXNet estimator, see sagemaker.mxnet.MXNet Class.

Call the fit Method

You start your training script by calling fit on an MXNet Estimator. fit takes both required and optional arguments.

fit Required argument

  • inputs: This can take one of the following forms. A string S3 URI, for example s3://my-bucket/my-training-data: in this case, the S3 objects rooted at the my-training-data prefix will be available in the default training channel. A dict from string channel names to S3 URIs: in this case, the objects rooted at each S3 prefix will be available as files in each channel directory.

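For example:

mxnet_estimator.fit({'train': 's3://my_bucket/my_training_data',
                     'test': 's3://my_bucket/my_testing_data'})
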
fit Optional arguments

  • wait: Defaults to True, whether to block and wait for the training script to complete before returning.
  • logs: Defaults to True, whether to show logs produced by training job in the Python session. Only meaningful when wait is True.
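You can sketch in plain Python how fit inputs map to channels, and how channels map to the SM_CHANNEL_XXXX variables described earlier. These helpers are hypothetical and only illustrate the documented behavior. The default channel name 'training' reflects the SDK's behavior as far as we know, so confirm it against the SDK before relying on SM_CHANNEL_TRAINING.

```python
def input_channels(inputs):
    """Map fit()'s `inputs` onto channel names (illustrative only)."""
    if isinstance(inputs, str):
        # A single S3 URI goes to the default channel, assumed to be 'training'.
        return {'training': inputs}
    if isinstance(inputs, dict):
        return dict(inputs)
    raise TypeError('inputs must be an S3 URI string or a dict of channel -> S3 URI')


def channel_env_vars(inputs):
    """Environment variable names a training script would read for each channel."""
    return sorted('SM_CHANNEL_' + name.upper() for name in input_channels(inputs))
```

For example, a {'train': ..., 'test': ...} dict yields SM_CHANNEL_TEST and SM_CHANNEL_TRAIN, matching the environment variable description above.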

Distributed training

If you want to use parameter servers for distributed training, set the following parameter in your MXNet constructor:

distributions={'parameter_server': {'enabled': True}}

Then, when writing a distributed training script, use an MXNet kvstore to store and share model parameters. During training, SageMaker automatically starts an MXNet kvstore server and scheduler processes on hosts in your training job cluster. Your script runs as an MXNet worker task, with one server process on each host in your cluster. One host is selected arbitrarily to run the scheduler process.

To learn more about writing distributed MXNet programs, please see Distributed Training in the MXNet docs.
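Every host runs the same script. For work that should happen only once, such as writing checkpoints (the checkpoint snippet earlier compares current_host with hosts[0]), a version 1.3+ script can read the host list from the environment. This sketch assumes the SM_HOSTS (a JSON list) and SM_CURRENT_HOST variables from SageMaker Containers; verify the names in that documentation. is_first_host is a hypothetical helper. The host it picks is for your own bookkeeping and is unrelated to which host runs the kvstore scheduler.

```python
import json
import os


def is_first_host(environ=None):
    """True on exactly one host: the first entry of the SM_HOSTS list."""
    environ = os.environ if environ is None else environ
    hosts = json.loads(environ['SM_HOSTS'])  # e.g. '["algo-1", "algo-2"]'
    return environ['SM_CURRENT_HOST'] == hosts[0]
```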

Deploy MXNet models

After an MXNet Estimator has been fit, you can host the newly created model in SageMaker.

After calling fit, you can call deploy on an MXNet Estimator to create a SageMaker Endpoint. The Endpoint runs a SageMaker-provided MXNet model server and hosts the model produced by your training script, which was run when you called fit. This was the model object you returned from train and saved with either a custom save function or the default save function.

deploy returns a Predictor object, which you can use to do inference on the Endpoint hosting your MXNet model. Each Predictor provides a predict method which can do inference with numpy arrays or Python lists. Inference arrays or lists are serialized and sent to the MXNet model server by an InvokeEndpoint SageMaker operation.

predict returns the result of inference against your model. By default, the inference result is either a Python list or dictionary.

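# Train my estimator
mxnet_estimator = MXNet('')  # other estimator arguments omitted
mxnet_estimator.fit('s3://my_bucket/my_training_data/')

# Deploy my estimator to a SageMaker Endpoint and get a Predictor
predictor = mxnet_estimator.deploy(instance_type='ml.m4.xlarge',
                                   initial_instance_count=1)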

You use the SageMaker MXNet model server to host your MXNet model when you call deploy on an MXNet Estimator. The model server runs inside a SageMaker Endpoint, which your call to deploy creates. You can access the name of the Endpoint by the name property on the returned Predictor.

MXNet on SageMaker supports Elastic Inference, which adds inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance. To attach an Elastic Inference accelerator to your endpoint, pass the accelerator type as the accelerator_type argument of your deploy call.

predictor = mxnet_estimator.deploy(instance_type='ml.m4.xlarge',
                                   initial_instance_count=1,
                                   accelerator_type='ml.eia1.medium')

The SageMaker MXNet Model Server

The MXNet Endpoint you create with deploy runs a SageMaker MXNet model server. The model server loads the model that was saved by your training script and performs inference on the model in response to SageMaker InvokeEndpoint API calls.

You can configure two components of the SageMaker MXNet model server: Model loading and model serving. Model loading is the process of deserializing your saved model back into an MXNet model. Serving is the process of translating InvokeEndpoint requests to inference calls on the loaded model.

As with MXNet training, you configure the MXNet model server by defining functions in the Python source file you passed to the MXNet constructor.

Load a Model

Before a model can be served, it must be loaded. The SageMaker model server loads your model by invoking a model_fn function on your training script. If you don’t provide a model_fn function, SageMaker will use a default model_fn function. The default function works with MXNet Module model objects, saved via the default save function.

If you wrote a custom save function, you may need to write a custom model_fn function. If your save function serializes Module objects in the same format as the default save function, you won’t need to write a custom model_fn function. If you do write a model_fn function, it must have the following signature:

def model_fn(model_dir)

SageMaker passes in the directory where your model files and sub-directories, saved by save, have been mounted. Your model function should return a model object that can be used for model serving. SageMaker provides automated serving functions that work with Gluon API net objects and Module API Module objects. If you return either of these types of objects, you can use the default serving request handling functions.

The following code snippet shows an example custom model_fn implementation. It returns an MXNet Gluon net model for resnet-34 inference, loading the model parameters from a model.params file in the SageMaker model directory.

import mxnet as mx
from mxnet.gluon.model_zoo import vision as models

def model_fn(model_dir):
    """Load the gluon model. Called once when hosting service starts.

    :param: model_dir The directory where model files are stored.
    :return: a model (in this case a Gluon network)
    """
    net = models.get_model('resnet34_v2', ctx=mx.cpu(), pretrained=False, classes=10)
    net.load_params('%s/model.params' % model_dir, ctx=mx.cpu())
    return net

MXNet on SageMaker has support for Elastic Inference, which allows for inference acceleration to a hosted endpoint for a fraction of the cost of using a full GPU instance. In order to load and serve your MXNet model through Amazon Elastic Inference, the MXNet context passed to your MXNet Symbol or Module object within your model_fn needs to be set to eia, as shown here.

Based on the example above, the following code-snippet shows an example custom model_fn implementation, which enables loading and serving our MXNet model through Amazon Elastic Inference.

import mxnet as mx
from mxnet.gluon.model_zoo import vision as models

def model_fn(model_dir):
    """Load the gluon model in an Elastic Inference context. Called once when hosting service starts.

    :param: model_dir The directory where model files are stored.
    :return: a model (in this case a Gluon network)
    """
    net = models.get_model('resnet34_v2', ctx=mx.eia(), pretrained=False, classes=10)
    net.load_params('%s/model.params' % model_dir, ctx=mx.eia())
    return net

The default_model_fn will load and serve your model through Elastic Inference, if applicable, within the SageMaker MXNet containers.

For more information on how to enable MXNet to interact with Amazon Elastic Inference, see Use Elastic Inference with MXNet.

Serve an MXNet Model

After the SageMaker model server loads your model by calling either the default model_fn or the implementation in your script, SageMaker serves your model. Model serving is the process of responding to inference requests received by SageMaker InvokeEndpoint API calls. Defining how to handle these requests can be done in one of two ways:

  • using input_fn, predict_fn, and output_fn, some of which may be your own implementations
  • writing your own transform_fn for handling input processing, prediction, and output processing

Using input_fn, predict_fn, and output_fn

The SageMaker MXNet model server breaks request handling into three steps:

  • input processing
  • prediction
  • output processing

Just like with model_fn, you configure these steps by defining functions in your Python source file.

Each step has its own Python function, which takes in information about the request and the return value from the previous function in the chain. Inside the SageMaker MXNet model server, the process looks like:

# Deserialize the Invoke request body into an object we can perform prediction on
input_object = input_fn(request_body, request_content_type)

# Perform prediction on the deserialized object, with the loaded model
prediction = predict_fn(input_object, model)

# Serialize the prediction result into the desired response content type
output = output_fn(prediction, response_content_type)

The above code sample shows the three function definitions that correlate to the three steps mentioned above:

  • input_fn: Takes request data and deserializes the data into an object for prediction.
  • predict_fn: Takes the deserialized request object and performs inference against the loaded model.
  • output_fn: Takes the result of prediction and serializes this according to the response content type.
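The hand-off between the three functions can be sketched in plain Python. This is not the model server's code. The model is a hypothetical stand-in lambda that doubles numbers, and the handlers work on JSON and nested lists rather than the NDArrayIters the defaults use. The point is the order of calls and what each step passes to the next.

```python
import json


def model_fn(model_dir):
    # Stand-in model that doubles every number; a real model_fn loads MXNet artifacts.
    return lambda rows: [[2 * x for x in row] for row in rows]


def input_fn(request_body, request_content_type):
    # Step 1: request bytes in, Python object out.
    if request_content_type == 'application/json':
        return json.loads(request_body)
    raise ValueError('Unsupported content type: %s' % request_content_type)


def predict_fn(input_object, model):
    # Step 2: run the loaded model on the deserialized input.
    return model(input_object)


def output_fn(prediction, response_content_type):
    # Step 3: Python object in, response bytes out.
    if response_content_type == 'application/json':
        return json.dumps(prediction).encode('utf-8')
    raise ValueError('Unsupported accept type: %s' % response_content_type)


# The order in which the model server calls them for one request:
model = model_fn('/opt/ml/model')
input_object = input_fn(b'[[1, 2], [3, 4]]', 'application/json')
prediction = predict_fn(input_object, model)
output = output_fn(prediction, 'application/json')
```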

The SageMaker MXNet model server provides default implementations of these functions. These work with both Gluon API and Module API model objects. The following content types are supported:

  • Gluon API: ‘application/json’, ‘application/x-npy’
  • Module API: ‘application/json’, ‘application/x-npy’, ‘text/csv’

You can also provide your own implementations for these functions in your training script. If you omit any definition then the SageMaker MXNet model server will use its default implementation for that function.

If you rely solely on the SageMaker MXNet model server defaults, you get the following functionality:

  • Prediction on MXNet Gluon API net and Module API Module objects.
  • Deserialization from CSV and JSON to NDArrayIters.
  • Serialization of NDArrayIters to CSV or JSON.

In the following sections we describe the default implementations of input_fn, predict_fn, and output_fn. We describe the input arguments and expected return types of each, so you can define your own implementations.

Process Model Input

When an InvokeEndpoint operation is made against an Endpoint running a SageMaker MXNet model server, the model server receives two pieces of information:

  • The request’s content type, for example “application/json”
  • The request data body as a byte array

The SageMaker MXNet model server invokes input_fn, passing in this information. If you define your own input_fn, it should return an object that can be passed to predict_fn, and it must have the following signature:

def input_fn(request_body, request_content_type)

Where request_body is a byte buffer and request_content_type is the content type of the request.

The SageMaker MXNet model server provides a default implementation of input_fn. This function deserializes JSON or CSV encoded data into an MXNet NDArrayIter multi-dimensional array iterator. This works with the default predict_fn implementation, which expects an NDArrayIter as input.

Default JSON deserialization requires that request_body contain a single JSON list. Sending multiple JSON objects within the same request_body is not supported. The list must have a dimensionality compatible with the MXNet net or Module object. Specifically, after the list is loaded, it’s either padded or split to fit the first dimension of the model input shape. The list’s shape must be identical to the model’s input shape for all dimensions after the first.

Default CSV deserialization requires that request_body contain one or more lines of CSV numerical data. The data is loaded into a two-dimensional array, where each line break defines the boundaries of the first dimension. This two-dimensional array is then reshaped to be compatible with the shape expected by the model object. Specifically, the first dimension is kept unchanged, and the second dimension is reshaped to match all of the model's input dimensions after the first.
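A stdlib-only sketch of that CSV behavior is shown below. Each line becomes one example, and each row is reshaped to the trailing model shape. It produces nested Python lists instead of an MXNet NDArrayIter and is not the container's implementation. csv_to_batch, reshape, and the feature_shape parameter are hypothetical names.

```python
import csv
import io


def _size(shape):
    size = 1
    for dim in shape:
        size *= dim
    return size


def reshape(flat, shape):
    """Nest a flat list into `shape` in row-major order (like numpy.reshape)."""
    if len(flat) != _size(shape):
        raise ValueError('cannot reshape %d values into %r' % (len(flat), shape))
    if len(shape) == 1:
        return list(flat)
    step = _size(shape[1:])
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def csv_to_batch(request_body, feature_shape):
    """One CSV line per example; each line is reshaped to feature_shape."""
    if isinstance(request_body, bytes):
        request_body = request_body.decode('utf-8')
    rows = [[float(v) for v in row] for row in csv.reader(io.StringIO(request_body)) if row]
    return [reshape(row, feature_shape) for row in rows]
```

For an MNIST-style model, feature_shape would be something like (1, 28, 28), so each CSV line must hold 784 values.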

If you provide your own implementation of input_fn, you should abide by the input_fn signature. If you want to use this with the default predict_fn, then you should return an NDArrayIter. The NDArrayIter should have a shape identical to the shape of the model being predicted on. The example below shows a custom input_fn for preparing pickled numpy arrays.

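import io

import mxnet as mx
import numpy as np

def input_fn(request_body, request_content_type):
    """An input_fn that loads a pickled numpy array"""
    if request_content_type == 'application/python-pickle':
        # request_body is a byte buffer, so wrap it in BytesIO for np.load
        array = np.load(io.BytesIO(request_body), allow_pickle=True)
        # Wrap the array in an NDArrayIter so the default predict_fn can use it
        return mx.io.NDArrayIter(array)
    # Handle other content-types here, or raise an Exception
    # if the content type is not supported.
    raise ValueError('Unsupported content type: %s' % request_content_type)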

Getting Predictions from a Deployed Model

After the inference request has been deserialized by input_fn, the SageMaker MXNet model server invokes predict_fn. As with input_fn, you can define your own predict_fn or use the SageMaker MXNet default.

The predict_fn function has the following signature:

def predict_fn(input_object, model)

Where input_object is the object returned from input_fn and model is the model loaded by model_fn.

The default implementation of predict_fn requires input_object be an NDArrayIter, which is the return-type of the default input_fn. It also requires that model be either an MXNet Gluon API net object or a Module API Module object.

The default implementation performs inference with the input NDArrayIter on the Gluon or Module object. If the model is a Gluon net it performs: net.forward(input_object). If the model is a Module object it performs module.predict(input_object). In both cases, it returns the result of that call.

If you implement your own prediction function, you should take care to ensure that:

  • The first argument is expected to be the return value from input_fn. If you use the default input_fn, this will be an NDArrayIter.
  • The second argument is the loaded model. If you use the default model_fn implementation, this will be an MXNet Module object. Otherwise, it will be the return value of your model_fn implementation.
  • The return value should be of the correct type to be passed as the first argument to output_fn. If you use the default output_fn, this should be an NDArray.

Processing Model Output

After invoking predict_fn, the model server invokes output_fn, passing in the return value from predict_fn and the InvokeEndpoint requested response content type.

The output_fn has the following signature:

def output_fn(prediction, content_type)

Where prediction is the result of invoking predict_fn and content_type is the requested response content type for InvokeEndpoint. The function should return an array of bytes serialized to the expected content type.

The default implementation expects prediction to be an NDArray and can serialize the result to either JSON or CSV. It accepts response content types of “application/json” and “text/csv”.
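The sketch below shows a custom output_fn for the same two response types. It works on a plain nested list so it runs without MXNet. With an NDArray prediction you could convert first, for example with prediction.asnumpy().tolist(). The error handling is an assumption of this sketch, not the default implementation's behavior.

```python
import csv
import io
import json


def output_fn(prediction, content_type):
    """Serialize a 2-D list of predictions to JSON or CSV bytes."""
    if content_type == 'application/json':
        return json.dumps(prediction).encode('utf-8')
    if content_type == 'text/csv':
        buffer = io.StringIO()
        # One CSV line per prediction row
        csv.writer(buffer, lineterminator='\n').writerows(prediction)
        return buffer.getvalue().encode('utf-8')
    raise ValueError('Unsupported accept type: %s' % content_type)
```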

Using transform_fn

If you would rather not structure your code around the three methods described above, you can instead define your own transform_fn to handle inference requests. An error will be thrown if a transform_fn is present in conjunction with any input_fn, predict_fn, and/or output_fn. transform_fn has the following signature:

def transform_fn(model, request_body, content_type, accept_type)

Where model is the model object loaded by model_fn, request_body is the data from the inference request, content_type is the content type of the request, and accept_type is the requested content type for the response.

This one function should handle processing the input, performing a prediction, and processing the output. The return object should be one of the following:

For versions 1.4 and higher:

  • a tuple with two items: the response data and accept_type (the content type of the response data), or
  • the response data: (the content type of the response will be set to either the accept header in the initial request or default to “application/json”)

For versions 1.3 and lower:

You can find examples of hosting scripts using this structure in the example notebooks, such as the mxnet_gluon_sentiment notebook.
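Below is a minimal transform_fn sketch that handles JSON only and returns the tuple form accepted by versions 1.4 and higher. The model is a hypothetical stand-in callable rather than an MXNet object. Falling back to 'application/json' when accept_type is empty is this sketch's own choice.

```python
import json


def transform_fn(model, request_body, content_type, accept_type):
    """Deserialize, predict, and serialize in one function (JSON only)."""
    if content_type != 'application/json':
        raise ValueError('Unsupported content type: %s' % content_type)
    response_type = accept_type or 'application/json'
    if response_type != 'application/json':
        raise ValueError('Unsupported accept type: %s' % response_type)
    data = json.loads(request_body)
    prediction = model(data)
    # Versions 1.4 and higher: (response data, content type of the response)
    return json.dumps(prediction), response_type


# A stand-in "model" that adds one to each input value.
body, response_content_type = transform_fn(lambda xs: [x + 1 for x in xs],
                                           b'[1, 2]', 'application/json', 'application/json')
```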

Working with existing model data and training jobs

Attach to Existing Training Jobs

You can attach an MXNet Estimator to an existing training job using the attach method.

my_training_job_name = 'MyAwesomeMXNetTrainingJob'
mxnet_estimator = MXNet.attach(my_training_job_name)

After attaching, if the training job is in a Complete status, it can be deployed to create a SageMaker Endpoint and return a Predictor. If the training job is in progress, attach will block and display log messages from the training job, until the training job completes.

The attach method accepts the following arguments:

  • training_job_name (str): The name of the training job to attach to.
  • sagemaker_session (sagemaker.Session or None): The Session used to interact with SageMaker

Deploy Endpoints from Model Data

As well as attaching to existing training jobs, you can deploy models directly from model data in S3. The following code sample shows how to do this, using the MXNetModel class.

from sagemaker.mxnet import MXNetModel

mxnet_model = MXNetModel(model_data='s3://bucket/model.tar.gz', role='SageMakerRole', entry_point='')

predictor = mxnet_model.deploy(instance_type='ml.c4.xlarge', initial_instance_count=1)

The MXNetModel constructor takes the following arguments:

  • model_data (str): An S3 location of a SageMaker model data .tar.gz file.
  • image (str): A Docker image URI.
  • role (str): An IAM role name or ARN for SageMaker to access AWS resources on your behalf.
  • predictor_cls (callable[string,sagemaker.Session]): A function to call to create a predictor. If not None, deploy returns the result of invoking this function on the created endpoint name.
  • env (dict[string,string]): Environment variables to run with image when hosted in SageMaker.
  • name (str): The model name. If None, a default model name will be selected on each deploy.
  • entry_point (str): Path (absolute or relative) to the Python file which should be executed as the entry point to model hosting.
  • source_dir (str): Optional. Path (absolute or relative) to a directory with any other training source code dependencies including the entry point file. Structure within this directory will be preserved when training on SageMaker.
  • container_log_level (int): Log level to use within the container. Valid values are defined in the Python logging module.
  • code_location (str): Optional. Name of the S3 bucket where your custom code will be uploaded. If not specified, the SageMaker default bucket created by sagemaker.Session is used.
  • sagemaker_session (sagemaker.Session): The SageMaker Session object, used for SageMaker interaction.

Your model data must be a .tar.gz file in S3. Model data from SageMaker training jobs is already saved as .tar.gz files in S3; if you have local model data that you want to deploy, you can package it yourself.

Assuming you have a local directory named “my_model” that contains your model data, you can tar and gzip-compress the directory and upload it to S3 using the following commands:

tar -czf model.tar.gz my_model
aws s3 cp model.tar.gz s3://my-bucket/my-path/model.tar.gz

This packages the contents of my_model into a gzip-compressed tar file and uploads it to S3 in the bucket “my-bucket”, under the key “my-path/model.tar.gz”.

To run these commands, you need the AWS CLI installed. See our FAQ for installation instructions.
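If you would rather do the packaging step from Python, the standard library's tarfile module can build the same archive as the tar command above. This is a sketch: make_model_tarball is a hypothetical helper (not part of the SageMaker SDK), and uploading the result is still left to the aws s3 cp command.

```python
import os
import tarfile


def make_model_tarball(model_dir, output_path='model.tar.gz'):
    """Create a gzip-compressed tar archive of model_dir, like `tar -czf`.

    Hypothetical helper for illustration; not part of the SageMaker SDK.
    """
    # Keep the directory's own name as the top-level archive entry,
    # matching what `tar -czf model.tar.gz my_model` produces
    arcname = os.path.basename(os.path.normpath(model_dir))
    with tarfile.open(output_path, 'w:gz') as tar:
        tar.add(model_dir, arcname=arcname)  # adds the directory recursively
    return output_path
```

Calling make_model_tarball('my_model') produces the same model.tar.gz as the tar command, ready for upload.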


Amazon provides several example Jupyter notebooks that demonstrate end-to-end training on Amazon SageMaker using MXNet. Please refer to:

These are also available in SageMaker Notebook Instance hosted Jupyter notebooks under the “sample notebooks” folder.

sagemaker.mxnet.MXNet Class

The following are the most commonly used MXNet constructor arguments.

Required arguments

The following are required arguments to the MXNet constructor. When you create an MXNet object, you must include these in the constructor, either positionally or as keyword arguments.

  • entry_point Path (absolute or relative) to the Python file which should be executed as the entry point to training.
  • role An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role if it accesses AWS resources.
  • train_instance_count Number of Amazon EC2 instances to use for training.
  • train_instance_type Type of EC2 instance to use for training, for example, ‘ml.c4.xlarge’.

Optional arguments

The following are optional arguments. When you create an MXNet object, you can specify these as keyword arguments.

  • source_dir Path (absolute or relative) to a directory with any other training source code dependencies including the entry point file. Structure within this directory will be preserved when training on SageMaker.

  • dependencies (list[str]) A list of paths to directories (absolute or relative) with any additional libraries that will be exported to the container (default: []). The library folders are copied to SageMaker in the same folder as the entry point. If source_dir points to S3, code is uploaded and the S3 location is used instead. For example, the following call

    >>> MXNet(entry_point='', dependencies=['my/libs/common', 'virtual-env'])

    results in the following inside the container:

      ├── common
      └── virtual-env
  • hyperparameters Hyperparameters that will be used for training. Will be made accessible as a dict[str, str] to the training code on SageMaker. For convenience, accepts other types besides str, but str() will be called on keys and values to convert them before training.

  • py_version Python version you want to use for executing your model training code. Valid values: ‘py2’ and ‘py3’.

  • train_volume_size Size in GB of the EBS volume to use for storing input data during training. Must be large enough to store training data if input_mode=‘File’ is used (which is the default).

  • train_max_run Timeout in seconds for training, after which Amazon SageMaker terminates the job regardless of its current status.

  • input_mode The input mode that the algorithm supports. Valid modes: ‘File’ - Amazon SageMaker copies the training dataset from the S3 location to a directory in the Docker container. ‘Pipe’ - Amazon SageMaker streams data directly from S3 to the container via a Unix named pipe.

  • output_path Location where you want the training result (model artifacts and optional output files) saved. This should be an S3 location unless you’re using Local Mode, which also supports local output paths. If not specified, results are stored to a default S3 bucket.

  • output_kms_key Optional KMS key ID used to encrypt the training output.

  • job_name Name to assign to the training job that the fit() method launches. If not specified, the estimator generates a default job name based on the training image name and current timestamp.

  • image_name An alternative Docker image to use for training and serving. If specified, the estimator uses this image for training and hosting instead of selecting the appropriate SageMaker official image based on framework_version and py_version. Refer to SageMaker MXNet Docker Containers for details on what the official images support and where to find the source code to build your custom image.

  • distributions For versions 1.3 and above only. Specifies information for how to run distributed training. To launch a parameter server during training, set this argument to:

  {'parameter_server': {'enabled': True}}
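The hyperparameters argument in the list above accepts non-string values, but str() is applied to every key and value before training, so the training script receives strings. The sketch below illustrates that conversion and how a training script can restore the types with argparse. It assumes, for illustration, that each hyperparameter reaches the script as a --name value command-line pair; the helper names and the epochs and learning-rate hyperparameters are hypothetical.

```python
import argparse


def stringify_hyperparameters(hyperparameters):
    """Apply str() to every key and value, as described for `hyperparameters`."""
    return {str(key): str(value) for key, value in hyperparameters.items()}


def parse_args(argv):
    """Parse hyperparameters inside the training script, restoring their types."""
    parser = argparse.ArgumentParser()
    # Hypothetical hyperparameters; `type=` converts the incoming strings
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=0.1)
    return parser.parse_args(argv)


# What you pass to the estimator: mixed types are allowed
hyperparameters = {'epochs': 20, 'learning-rate': 0.05}
as_strings = stringify_hyperparameters(hyperparameters)

# Assumption: each pair reaches the script as `--name value`
argv = []
for name in sorted(as_strings):
    argv += ['--' + name, as_strings[name]]

args = parse_args(argv)  # args.epochs is an int, args.learning_rate a float
```

Declaring type= matters because every value arrives as a string. Note that type=bool does not behave as expected here, since bool('False') is True.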

SageMaker MXNet Containers

For information about SageMaker MXNet containers, see the following topics:

For information about the dependencies installed in SageMaker MXNet containers, see