TensorFlow
TensorFlow Estimator
class sagemaker.tensorflow.estimator.TensorFlow(training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2', framework_version='1.8', requirements_file='', **kwargs)
Bases: sagemaker.estimator.Framework
Handle end-to-end training and deployment of user-provided TensorFlow code.
Initialize a TensorFlow estimator.
Parameters: - training_steps (int) – Perform this many steps of training. None, the default, means train forever.
- evaluation_steps (int) – Perform this many steps of evaluation. None, the default, means that evaluation runs until input from eval_input_fn is exhausted (or another exception is raised).
- checkpoint_path (str) – Identifies the S3 location where checkpoint data can be saved during model training (default: None). For distributed model training, this parameter is required.
- py_version (str) – Python version you want to use for executing your model training code (default: ‘py2’).
- framework_version (str) – TensorFlow version you want to use for executing your model training code. List of supported versions: https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
- requirements_file (str) – Path to a requirements.txt file (default: ‘’). The path should be within and relative to source_dir. Details on the format can be found in the Pip User Guide.
- **kwargs – Additional kwargs passed to the Framework constructor.
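As a rough illustration of how these constructor arguments fit together, the sketch below collects them into a dict and checks the one constraint stated above (checkpoint_path is required for distributed training). The helper names, the script name, the role name, and the instance settings are hypothetical; entry_point, role, train_instance_count, and train_instance_type are assumed to be accepted by the Framework constructor through **kwargs in this SDK version.

```python
def build_tf_estimator_kwargs(train_instance_count=1, checkpoint_path=None,
                              training_steps=1000, evaluation_steps=100):
    """Assemble keyword arguments for the TensorFlow estimator (hypothetical helper)."""
    # The parameter list above says checkpoint_path is required for
    # distributed training, so check it before constructing the estimator.
    if train_instance_count > 1 and not checkpoint_path:
        raise ValueError("checkpoint_path is required for distributed training")
    return {
        # Arguments documented in this section.
        "training_steps": training_steps,
        "evaluation_steps": evaluation_steps,
        "checkpoint_path": checkpoint_path,
        "py_version": "py2",
        "framework_version": "1.8",
        "requirements_file": "",
        # Assumed Framework arguments, forwarded via **kwargs (placeholder values).
        "entry_point": "train.py",
        "role": "SageMakerRole",
        "train_instance_count": train_instance_count,
        "train_instance_type": "ml.c4.xlarge",
    }


def make_estimator(**overrides):
    """Create the estimator; needs the SageMaker Python SDK and AWS credentials."""
    # Imported lazily so the helper above can be used without the SDK installed.
    from sagemaker.tensorflow import TensorFlow
    return TensorFlow(**build_tf_estimator_kwargs(**overrides))
```

Keeping the argument assembly separate from the SDK call makes the configuration easy to inspect or unit test before any AWS resources are touched.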
fit(inputs, wait=True, logs=True, job_name=None, run_tensorboard_locally=False)
Train a model using the input training dataset.
See fit() in the base class sagemaker.estimator.Framework for more details.
Parameters: - inputs (str or dict or sagemaker.session.s3_input) – Information about the training data. This can be one of three types:
  - (str) – the S3 location where training data is saved.
  - (dict[str, str] or dict[str, sagemaker.session.s3_input]) – if using multiple channels for training data, you can specify a dict mapping channel names to strings or s3_input() objects.
  - (sagemaker.session.s3_input) – channel configuration for S3 data sources that can provide additional information about the training dataset. See sagemaker.session.s3_input() for full details.
- wait (bool) – Whether the call should wait until the job completes (default: True).
- logs (bool) – Whether to show the logs produced by the job. Only meaningful when wait is True (default: True).
- job_name (str) – Training job name. If not specified, the estimator generates a default job name, based on the training image name and current timestamp.
- run_tensorboard_locally (bool) – Whether to execute TensorBoard in a different process with downloaded checkpoint information (default: False). This is an experimental feature, and requires TensorBoard and AWS CLI to be installed. It terminates TensorBoard when execution ends.
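The accepted forms of inputs can be easier to see side by side. The following is a minimal sketch, not SDK code: normalize_inputs and train are hypothetical helpers, the default channel name "training" is an assumption of this sketch, and the estimator is passed in so that nothing here requires AWS access by itself.

```python
def normalize_inputs(inputs, default_channel="training"):
    """Return a dict of channel name -> data source (hypothetical helper).

    Shows how the single-location form relates to the multi-channel dict form.
    """
    if isinstance(inputs, dict):
        # Already a mapping of channel names to S3 URIs or s3_input objects.
        return dict(inputs)
    # A plain S3 URI string or a single s3_input-like object: one channel.
    return {default_channel: inputs}


def train(estimator, inputs, job_name=None):
    """Start a training job and block until it completes, streaming logs."""
    # fit() accepts any of the documented input forms directly.
    estimator.fit(inputs, wait=True, logs=True, job_name=job_name)
    return estimator


# Example input values (placeholder bucket and prefixes).
single_channel = "s3://my-bucket/data/train"
multi_channel = {
    "train": "s3://my-bucket/data/train",
    "eval": "s3://my-bucket/data/eval",
}
```

With a real estimator, a call might look like train(estimator, multi_channel, job_name="my-tf-job"), where the job name is again a placeholder.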
train_image()
Return the Docker image to use for training.
The fit() method, which does the model training, calls this method to find the image to use for model training.
Returns: The URI of the Docker image.
Return type: str
create_model(model_server_workers=None)
Create a SageMaker TensorFlowModel object that can be deployed to an Endpoint.
Parameters: model_server_workers (int) – Optional. The number of worker processes used by the inference server. If None, the server will use one worker per vCPU.
Returns: A SageMaker TensorFlowModel object. See TensorFlowModel() for full details.
Return type: sagemaker.tensorflow.model.TensorFlowModel
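One way create_model might be used after training is sketched below. The function name deploy_from_estimator and the instance settings are hypothetical, and deploy() with initial_instance_count and instance_type is assumed to be provided by the SDK's base Model class rather than documented in this section.

```python
def deploy_from_estimator(estimator, model_server_workers=None,
                          instance_type="ml.c4.xlarge", instance_count=1):
    """Turn a trained estimator into a hosted endpoint (illustrative sketch)."""
    # None leaves the worker count to the server (one worker per vCPU,
    # according to the parameter description above).
    model = estimator.create_model(model_server_workers=model_server_workers)
    # Assumed base-class method: creates the endpoint and returns a predictor.
    return model.deploy(initial_instance_count=instance_count,
                        instance_type=instance_type)
```

Creating the model explicitly, instead of deploying straight from the estimator, leaves room to inspect or adjust the TensorFlowModel before an endpoint is created.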
hyperparameters()
Return hyperparameters used by your custom TensorFlow code during model training.
TensorFlow Model
class sagemaker.tensorflow.model.TensorFlowModel(model_data, role, entry_point, image=None, py_version='py2', framework_version='1.8', predictor_cls=<class 'sagemaker.tensorflow.model.TensorFlowPredictor'>, model_server_workers=None, **kwargs)
Bases: sagemaker.model.FrameworkModel
Initialize a TensorFlowModel.
Parameters: - model_data (str) – The S3 location of a SageMaker model data .tar.gz file.
- role (str) – An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role if it needs to access an AWS resource.
- entry_point (str) – Path (absolute or relative) to the Python source file which should be executed as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5.
- image (str) – A Docker image URI (default: None). If not specified, a default image for TensorFlow will be used.
- py_version (str) – Python version you want to use for executing your model training code (default: ‘py2’).
- framework_version (str) – TensorFlow version you want to use for executing your model training code.
- predictor_cls (callable[str, sagemaker.session.Session]) – A function to call to create a predictor with an endpoint name and SageMaker Session. If specified, deploy() returns the result of invoking this function on the created endpoint name.
- model_server_workers (int) – Optional. The number of worker processes used by the inference server. If None, the server will use one worker per vCPU.
- **kwargs – Keyword arguments passed to the FrameworkModel initializer.
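When a model artifact already exists in S3, a TensorFlowModel can be constructed directly from it. The sketch below only assembles and sanity-checks the arguments; build_tf_model_kwargs and make_model are hypothetical helpers, and the URI check is a convenience of this sketch based on the model_data description above, not something the SDK is known to enforce.

```python
def build_tf_model_kwargs(model_data, role, entry_point, model_server_workers=None):
    """Assemble keyword arguments for TensorFlowModel (hypothetical helper)."""
    # model_data is described above as the S3 location of a .tar.gz file.
    if not (model_data.startswith("s3://") and model_data.endswith(".tar.gz")):
        raise ValueError("model_data should be an S3 URI ending in .tar.gz")
    return {
        "model_data": model_data,
        "role": role,
        "entry_point": entry_point,
        "py_version": "py2",          # documented default
        "framework_version": "1.8",   # documented default
        "model_server_workers": model_server_workers,
    }


def make_model(**kwargs):
    """Create the model; needs the SageMaker Python SDK and AWS credentials."""
    # Imported lazily so the helper above works without the SDK installed.
    from sagemaker.tensorflow.model import TensorFlowModel
    return TensorFlowModel(**build_tf_model_kwargs(**kwargs))
```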
prepare_container_def(instance_type)
Return a container definition with framework configuration set in model environment variables.
This also uploads user-supplied code to S3.
Parameters: instance_type (str) – The EC2 instance type to deploy this Model to. For example, ‘ml.p2.xlarge’.
Returns: A container definition object usable with the CreateModel API.
Return type: dict[str, str]
TensorFlow Predictor
class sagemaker.tensorflow.model.TensorFlowPredictor(endpoint_name, sagemaker_session=None)
Bases: sagemaker.predictor.RealTimePredictor
A RealTimePredictor for inference against TensorFlow Endpoints. It can serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for TensorFlow inference.
Initialize a TensorFlowPredictor.
Parameters: - endpoint_name (str) – The name of the endpoint to perform inference on.
- sagemaker_session (sagemaker.session.Session) – Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the predictor creates one using the default AWS configuration chain.
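A predictor is typically attached to an existing endpoint by name and then called with plain Python data. The following is a minimal sketch under that assumption: connect and predict_rows are hypothetical helpers, the feature values in the comment are made up, and predict() is assumed to be inherited from RealTimePredictor.

```python
def connect(endpoint_name, sagemaker_session=None):
    """Attach to an existing endpoint; needs the SDK and AWS credentials."""
    # Imported lazily so predict_rows below can be exercised with a stub.
    from sagemaker.tensorflow.model import TensorFlowPredictor
    return TensorFlowPredictor(endpoint_name, sagemaker_session=sagemaker_session)


def predict_rows(predictor, rows):
    """Send a batch of feature rows to the endpoint and return its response.

    rows is a plain Python list of lists, e.g. [[6.4, 3.2, 4.5, 1.5]];
    the predictor is described above as serializing such data to tensors.
    """
    if not rows:
        raise ValueError("rows must contain at least one example")
    return predictor.predict(rows)
```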