TensorFlow

TensorFlow Estimator

class sagemaker.tensorflow.estimator.TensorFlow(training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2', framework_version='1.6', requirements_file='', **kwargs)

Bases: sagemaker.estimator.Framework

Handle end-to-end training and deployment of user-provided TensorFlow code.

Initialize a TensorFlow estimator.

Parameters:
  • training_steps (int) – Perform this many steps of training. If None (the default), training continues indefinitely.
  • evaluation_steps (int) – Perform this many steps of evaluation. If None (the default), evaluation runs until input from eval_input_fn is exhausted (or another exception is raised).
  • checkpoint_path (str) – S3 location where checkpoint data is saved during model training (default: None). This parameter is required for distributed training.
  • py_version (str) – Python version you want to use for executing your model training code (default: ‘py2’).
  • framework_version (str) – TensorFlow version you want to use for executing your model training code (default: ‘1.6’). For a list of supported versions, see https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
  • requirements_file (str) – Path to a requirements.txt file (default: ‘’). The path should be within and relative to source_dir. Details on the format can be found in the Pip User Guide.
  • **kwargs – Additional kwargs passed to the Framework constructor.
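A minimal sketch of assembling constructor arguments and enforcing the documented rule that checkpoint_path is required for distributed training. The helper tensorflow_estimator_kwargs is hypothetical, the S3 bucket and role names are placeholders, and entry_point, role, train_instance_count and train_instance_type are assumed to be Framework/EstimatorBase arguments forwarded through **kwargs; check them against your SDK version.

```python
import os


def tensorflow_estimator_kwargs(entry_point, role, train_instance_count,
                                train_instance_type, checkpoint_path=None,
                                training_steps=None, evaluation_steps=None,
                                py_version="py2", framework_version="1.6",
                                requirements_file=""):
    """Hypothetical helper that collects arguments for TensorFlow(**kwargs).

    entry_point, role, train_instance_count and train_instance_type are
    assumed Framework/EstimatorBase arguments passed through **kwargs.
    """
    # Distributed training needs an S3 checkpoint location (see checkpoint_path).
    if train_instance_count > 1 and checkpoint_path is None:
        raise ValueError("checkpoint_path is required for distributed training")
    # requirements_file must be a path relative to source_dir.
    if requirements_file and os.path.isabs(requirements_file):
        raise ValueError("requirements_file must be relative to source_dir")
    return {
        "entry_point": entry_point,
        "role": role,
        "train_instance_count": train_instance_count,
        "train_instance_type": train_instance_type,
        "checkpoint_path": checkpoint_path,
        "training_steps": training_steps,
        "evaluation_steps": evaluation_steps,
        "py_version": py_version,
        "framework_version": framework_version,
        "requirements_file": requirements_file,
    }


# Placeholder role and bucket names; two instances means distributed training.
kwargs = tensorflow_estimator_kwargs(
    "train.py", "SageMakerRole", 2, "ml.c4.xlarge",
    checkpoint_path="s3://my-bucket/checkpoints",
    training_steps=1000, evaluation_steps=100)
# With the SageMaker SDK installed and AWS credentials configured:
# from sagemaker.tensorflow import TensorFlow
# estimator = TensorFlow(**kwargs)
```

Validating early like this surfaces a missing checkpoint_path before any training job is submitted.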
fit(inputs, wait=True, logs=True, job_name=None, run_tensorboard_locally=False)

Train a model using the input training dataset.

See fit() for more details.

Parameters:
  • inputs (str or dict or sagemaker.session.s3_input) –

    Information about the training data. This can be one of three types:
      ◦ (str) – the S3 location where training data is saved.
      ◦ (dict[str, str] or dict[str, sagemaker.session.s3_input]) – if using multiple channels for training data, you can specify a dict mapping channel names to strings or s3_input() objects.
      ◦ (sagemaker.session.s3_input) – channel configuration for S3 data sources that can provide additional information about the training dataset. See sagemaker.session.s3_input() for full details.
  • wait (bool) – Whether the call should wait until the job completes (default: True).
  • logs (bool) – Whether to show the logs produced by the job. Only meaningful when wait is True (default: True).
  • job_name (str) – Training job name. If not specified, the estimator generates a default job name, based on the training image name and current timestamp.
  • run_tensorboard_locally (bool) – Whether to execute TensorBoard in a different process with downloaded checkpoint information (default: False). This is an experimental feature, and requires TensorBoard and AWS CLI to be installed. It terminates TensorBoard when execution ends.
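The three accepted forms of inputs can be pictured as a mapping from channel names to data sources. This is a hypothetical helper, not SDK code; in particular, the default channel name "training" for a single S3 URI or s3_input object is an assumption, and the bucket paths are placeholders.

```python
def to_channels(inputs, default_channel="training"):
    """Hypothetical helper: map the three accepted `inputs` forms to a
    {channel_name: source} dict. The default channel name is an assumption.
    """
    if isinstance(inputs, str):
        # A single S3 URI becomes one channel.
        return {default_channel: inputs}
    if isinstance(inputs, dict):
        # Multiple channels: names map to S3 URIs or s3_input objects.
        return dict(inputs)
    # Anything else (for example an s3_input object) is treated as one channel.
    return {default_channel: inputs}


single = to_channels("s3://my-bucket/data")
multi = to_channels({"training": "s3://my-bucket/train",
                     "eval": "s3://my-bucket/eval"})
# With a configured estimator: estimator.fit(multi, wait=True, logs=True)
```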
train_image()

Return the Docker image to use for training.

The fit() method, which does the model training, calls this method to find the image to use for model training.

Returns: The URI of the Docker image.
Return type: str
create_model(model_server_workers=None)

Create a SageMaker TensorFlowModel object that can be deployed to an Endpoint.

Parameters: model_server_workers (int) – Optional. The number of worker processes used by the inference server. If None, the server uses one worker per vCPU.
Returns: A SageMaker TensorFlowModel object. See TensorFlowModel() for full details.
Return type: sagemaker.tensorflow.model.TensorFlowModel
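The documented default for model_server_workers can be expressed as a small rule. The helper below is hypothetical; it takes the instance's vCPU count as an explicit argument because the inference server, not your local machine, determines the actual value.

```python
def effective_model_server_workers(model_server_workers, instance_vcpus):
    """Hypothetical helper mirroring the documented default:
    None means one worker per vCPU on the hosting instance."""
    if model_server_workers is None:
        return instance_vcpus
    if model_server_workers < 1:
        raise ValueError("model_server_workers must be a positive integer")
    return model_server_workers


# On an instance with 4 vCPUs:
default_workers = effective_model_server_workers(None, 4)
explicit_workers = effective_model_server_workers(2, 4)
# With a trained estimator: model = estimator.create_model(model_server_workers=2)
```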
hyperparameters()

Return hyperparameters used by your custom TensorFlow code during model training.
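A hedged sketch of what this method might return, assuming the estimator merges the user's hyperparameters with its own settings and JSON-encodes every value (the SageMaker CreateTrainingJob API accepts hyperparameters as a string-to-string map). The function name, the key sagemaker_requirements, and the exact set of keys are assumptions not confirmed by this page.

```python
import json


def tensorflow_hyperparameters(user_hyperparameters, training_steps=None,
                               evaluation_steps=None, checkpoint_path=None,
                               requirements_file=""):
    """Hypothetical: merge user and estimator hyperparameters as JSON strings."""
    # SageMaker hyperparameters are strings, so every value is JSON-encoded.
    merged = {k: json.dumps(v) for k, v in user_hyperparameters.items()}
    # Estimator-specific settings (key names are assumptions).
    merged.update({
        "training_steps": json.dumps(training_steps),
        "evaluation_steps": json.dumps(evaluation_steps),
        "checkpoint_path": json.dumps(checkpoint_path),
        "sagemaker_requirements": json.dumps(requirements_file),
    })
    return merged


hp = tensorflow_hyperparameters({"learning_rate": 0.01}, training_steps=1000)
```

JSON encoding keeps None distinguishable from a string: it becomes "null" rather than "None".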

TensorFlow Model

class sagemaker.tensorflow.model.TensorFlowModel(model_data, role, entry_point, image=None, py_version='py2', framework_version='1.6', predictor_cls=<class 'sagemaker.tensorflow.model.TensorFlowPredictor'>, model_server_workers=None, **kwargs)

Bases: sagemaker.model.FrameworkModel

Initialize a TensorFlowModel.

Parameters:
  • model_data (str) – The S3 location of a SageMaker model data .tar.gz file.
  • role (str) – An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource.
  • entry_point (str) – Path (absolute or relative) to the Python source file which should be executed as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
  • image (str) – A Docker image URI (default: None). If not specified, a default image for TensorFlow will be used.
  • py_version (str) – Python version you want to use for executing your model hosting code (default: ‘py2’).
  • framework_version (str) – TensorFlow version you want to use for executing your model hosting code (default: ‘1.6’).
  • predictor_cls (callable[str, sagemaker.session.Session]) – A function to call to create a predictor with an endpoint name and SageMaker Session. If specified, deploy() returns the result of invoking this function on the created endpoint name.
  • model_server_workers (int) – Optional. The number of worker processes used by the inference server. If None, the server uses one worker per vCPU.
  • **kwargs – Keyword arguments passed to the FrameworkModel initializer.
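The predictor_cls contract can be illustrated without AWS access: after the endpoint exists, the class is called with the endpoint name and the session. Both EchoPredictor and finish_deploy are hypothetical stand-ins; endpoint creation itself is omitted.

```python
class EchoPredictor(object):
    """Hypothetical stand-in for a predictor class such as TensorFlowPredictor."""

    def __init__(self, endpoint_name, sagemaker_session=None):
        self.endpoint_name = endpoint_name
        self.sagemaker_session = sagemaker_session


def finish_deploy(endpoint_name, sagemaker_session=None, predictor_cls=None):
    """Hypothetical final step of deploy(); endpoint creation is omitted."""
    if predictor_cls is not None:
        # predictor_cls is invoked with the endpoint name and the session.
        return predictor_cls(endpoint_name, sagemaker_session)
    # Without predictor_cls there is nothing to return.
    return None


predictor = finish_deploy("my-endpoint", predictor_cls=EchoPredictor)
```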
prepare_container_def(instance_type)

Return a container definition with framework configuration set in model environment variables.

This also uploads user-supplied code to S3.

Parameters: instance_type (str) – The EC2 instance type to deploy this Model to. For example, ‘ml.p2.xlarge’.
Returns: A container definition object usable with the CreateModel API.
Return type: dict[str, str]
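A minimal sketch of the shape such a container definition could take. The keys Image, ModelDataUrl and Environment follow the CreateModel ContainerDefinition structure; the environment variable names below are assumptions. The sketch takes the image URI as an input instead of deriving it from instance_type, and it skips the S3 upload of user code. All paths and the region are placeholders.

```python
def container_def_sketch(image, model_data_url, entry_point, submit_dir,
                         region, model_server_workers=None):
    """Hypothetical: build a CreateModel-style container definition."""
    env = {
        "SAGEMAKER_PROGRAM": entry_point,          # script run for hosting (assumed name)
        "SAGEMAKER_SUBMIT_DIRECTORY": submit_dir,  # S3 location of uploaded code (assumed name)
        "SAGEMAKER_REGION": region,                # assumed name
    }
    if model_server_workers is not None:
        # Environment values must be strings.
        env["SAGEMAKER_MODEL_SERVER_WORKERS"] = str(model_server_workers)
    return {"Image": image, "ModelDataUrl": model_data_url, "Environment": env}


container = container_def_sketch(
    "<image-uri>", "s3://my-bucket/model.tar.gz", "inference.py",
    "s3://my-bucket/source/sourcedir.tar.gz", "us-west-2",
    model_server_workers=2)
```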

TensorFlow Predictor

class sagemaker.tensorflow.model.TensorFlowPredictor(endpoint_name, sagemaker_session=None)

Bases: sagemaker.predictor.RealTimePredictor

A RealTimePredictor for inference against TensorFlow endpoints.

This predictor can serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for TensorFlow inference.

Initialize a TensorFlowPredictor.

Parameters:
  • endpoint_name (str) – The name of the endpoint to perform inference on.
  • sagemaker_session (sagemaker.session.Session) – Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created using the default AWS configuration chain.
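To make "serialize to multidimensional tensors" concrete, here is a pure-Python sketch of the core step for nested lists only: inferring a shape and flattening the values in row-major order. It is not the SDK's serializer and does not produce its wire format; all function names are hypothetical. With the SDK, you would pass the data to the predictor's predict() method instead.

```python
def infer_shape(data):
    """Infer a tensor shape by walking the first element at each nesting level."""
    shape = []
    while isinstance(data, (list, tuple)):
        shape.append(len(data))
        if not data:
            break
        data = data[0]
    return shape


def flatten(data):
    """Flatten nested lists/tuples in row-major order."""
    if isinstance(data, (list, tuple)):
        return [value for item in data for value in flatten(item)]
    return [data]


def to_tensor_dict(data):
    """Hypothetical: describe nested Python lists as a shape plus flat values."""
    shape = infer_shape(data)
    values = flatten(data)
    expected = 1
    for dim in shape:
        expected *= dim
    # Basic ragged-input check: element count must match the inferred shape.
    if len(values) != expected:
        raise ValueError("ragged input cannot be represented as a tensor")
    return {"shape": shape, "values": values}


tensor = to_tensor_dict([[1, 2, 3], [4, 5, 6]])
```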