Estimators

A high-level interface for SageMaker training

class sagemaker.estimator.Estimator(image_name, role, train_instance_count, train_instance_type, train_volume_size=30, train_max_run=86400, input_mode='File', output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, hyperparameters=None)

Bases: sagemaker.estimator.EstimatorBase

A generic Estimator to train using any supplied algorithm. This class is designed for use with algorithms that don’t have their own custom class.

Initialize an Estimator instance.

Parameters:
  • image_name (str) – The container image to use for training.
  • role (str) – An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role if it needs to access an AWS resource.
  • train_instance_count (int) – Number of Amazon EC2 instances to use for training.
  • train_instance_type (str) – Type of EC2 instance to use for training, for example, ‘ml.c4.xlarge’.
  • train_volume_size (int) – Size in GB of the EBS volume to use for storing input data during training (default: 30). Must be large enough to store training data if File Mode is used (which is the default).
  • train_max_run (int) – Timeout in seconds for training (default: 24 * 60 * 60). After this amount of time Amazon SageMaker terminates the job regardless of its current status.
  • input_mode (str) –

    The input mode that the algorithm supports (default: ‘File’). Valid modes:

    • ‘File’ - Amazon SageMaker copies the training dataset from the S3 location to a local directory.
    • ‘Pipe’ - Amazon SageMaker streams data directly from S3 to the container via a Unix named pipe.
  • output_path (str) – S3 location for saving the training result (model artifacts and output files). If not specified, results are stored in a default bucket. If the bucket with the specified name does not exist, the estimator creates it during the fit() method execution.
  • output_kms_key (str) – Optional. KMS key ID for encrypting the training output (default: None).
  • base_job_name (str) – Prefix for the training job name when the fit() method launches a job. If not specified, the estimator generates a default job name based on the training image name and current timestamp.
  • sagemaker_session (sagemaker.session.Session) – Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain.
  • hyperparameters (dict) – Dictionary containing the hyperparameters to initialize this estimator with.
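The constructor parameters and their defaults can be summarized in plain Python. The sketch below is a hypothetical EstimatorConfig dataclass, not part of the SageMaker SDK: it mirrors a subset of the documented parameters and defaults, checks input_mode against the two documented modes, and encodes the rule that in File mode the training data must fit on the EBS volume. It makes no AWS calls.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional

# The two input modes documented for the input_mode parameter.
VALID_INPUT_MODES = ("File", "Pipe")


@dataclass
class EstimatorConfig:
    """Hypothetical stand-in, NOT sagemaker.estimator.Estimator: it only
    mirrors a subset of the constructor parameters and their documented defaults."""

    image_name: str
    role: str
    train_instance_count: int
    train_instance_type: str
    train_volume_size: int = 30          # EBS volume size in GB
    train_max_run: int = 24 * 60 * 60    # timeout in seconds (86400)
    input_mode: str = "File"
    output_path: Optional[str] = None    # None means "use the default bucket"
    base_job_name: Optional[str] = None  # None means "derive from image name and timestamp"
    hyperparameters: Dict[str, object] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Reject anything other than the two documented input modes.
        if self.input_mode not in VALID_INPUT_MODES:
            raise ValueError(
                f"input_mode must be one of {VALID_INPUT_MODES}, got {self.input_mode!r}"
            )

    def volume_fits(self, dataset_size_gb: float) -> bool:
        """In File mode the dataset is copied to the EBS volume, so it must fit;
        in Pipe mode data is streamed from S3 and this check does not apply."""
        if self.input_mode == "Pipe":
            return True
        return dataset_size_gb <= self.train_volume_size
```

To launch real training, the same values would be passed to the Estimator constructor itself; the dataclass only makes the defaults and constraints explicit.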
train_image()

Returns the Docker image to use for training.

The fit() method, which does the model training, calls this method to find the image to use for model training.

set_hyperparameters(**kwargs)

Set hyperparameters for training. Each keyword argument is stored as a hyperparameter under its own name.

hyperparameters()

Returns the hyperparameters as a dictionary to use for training.

The fit() method, which does the model training, calls this method to find the hyperparameters you specified.
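The relationship between set_hyperparameters() and hyperparameters() can be sketched with a small hypothetical class (HyperparameterStore is an illustrative name, not an SDK class). It assumes that later calls overwrite earlier values for the same key. The helper to_api_hyperparameters() reflects the fact that the Amazon SageMaker CreateTrainingJob API accepts HyperParameters only as a map of strings to strings; where exactly the SDK performs that conversion is not shown here.

```python
from typing import Dict, Optional


class HyperparameterStore:
    """Hypothetical stand-in for the set_hyperparameters()/hyperparameters() pair."""

    def __init__(self, hyperparameters: Optional[Dict[str, object]] = None) -> None:
        # Start from the constructor's hyperparameters dict, if one was given.
        self._params: Dict[str, object] = dict(hyperparameters or {})

    def set_hyperparameters(self, **kwargs: object) -> None:
        # New names are added; names that already exist are overwritten.
        self._params.update(kwargs)

    def hyperparameters(self) -> Dict[str, object]:
        # What fit() would read when building the training request (a copy).
        return dict(self._params)


def to_api_hyperparameters(params: Dict[str, object]) -> Dict[str, str]:
    # CreateTrainingJob expects HyperParameters as a string-to-string map.
    return {str(k): str(v) for k, v in params.items()}
```

For example, starting from {"epochs": 5} and calling set_hyperparameters(epochs=10, learning_rate=0.1) leaves {'epochs': '10', 'learning_rate': '0.1'} after conversion.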

create_model(role=None, image=None, predictor_cls=None, serializer=None, deserializer=None, content_type=None, accept=None, **kwargs)

Create a model to deploy.

Parameters:
  • role (str) – The ExecutionRoleArn IAM Role ARN for the Model, which is also used during transform jobs. If not specified, the role from the Estimator will be used.
  • image (str) – A container image to use for deploying the model. Defaults to the image used for training.
  • predictor_cls (RealTimePredictor) – The predictor class to use when deploying the model.
  • serializer (callable) – Should accept a single argument, the input data, and return a sequence of bytes. May provide a content_type attribute that defines the endpoint request content type.
  • deserializer (callable) – Should accept two arguments, the result data and the response content type, and return the deserialized result. May provide an accept attribute that defines the endpoint response Accept content type.
  • content_type (str) – The invocation ContentType, overriding any content_type from the serializer.
  • accept (str) – The invocation Accept, overriding any accept from the deserializer.
The serializer, deserializer, content_type, and accept arguments are only used to define a default RealTimePredictor. They are ignored if an explicit predictor class is passed in. Other arguments are passed through to the Model class.

Returns: a Model ready for deployment.
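The serializer and deserializer contracts described above can be illustrated with a pair of hypothetical JSON callables (JsonSerializer and JsonDeserializer are illustrative names, not the SDK's built-in serializers). The sketch assumes the deserializer advertises its Accept type through an accept attribute, matching the accept parameter description, and uses an in-memory stream in place of a real endpoint response body.

```python
import io
import json
from typing import Any, BinaryIO


class JsonSerializer:
    """Hypothetical serializer: one argument in, bytes out."""

    content_type = "application/json"  # would become the request ContentType

    def __call__(self, data: Any) -> bytes:
        return json.dumps(data).encode("utf-8")


class JsonDeserializer:
    """Hypothetical deserializer: (result data, response content type) -> result."""

    accept = "application/json"  # would become the request Accept header

    def __call__(self, stream: BinaryIO, content_type: str) -> Any:
        try:
            body = stream.read()
        finally:
            stream.close()  # the deserializer is responsible for closing the stream
        if content_type.startswith("application/json"):
            return json.loads(body.decode("utf-8"))
        return body  # fall back to raw bytes for other content types


# Round trip with an in-memory stream standing in for the endpoint response body.
request_body = JsonSerializer()({"instances": [1, 2, 3]})
result = JsonDeserializer()(io.BytesIO(b'{"score": 0.5}'), "application/json")
```

Objects like these would be passed as serializer= and deserializer= to create_model() when no explicit predictor_cls is given.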

classmethod attach(training_job_name, sagemaker_session=None)

Attach to an existing training job.

Create an Estimator bound to an existing training job. Each subclass is responsible for implementing _prepare_init_params_from_job_description(), because this method delegates the actual conversion of a training job description to the arguments that the class constructor expects. After attaching, if the training job has a Completed status, it can be deployed with deploy() to create a SageMaker Endpoint and return a Predictor.

If the training job is in progress, attach() blocks and displays log messages from the training job until the training job completes.

Parameters:
  • training_job_name (str) – The name of the training job to attach to.
  • sagemaker_session (sagemaker.session.Session) – Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain.

Examples

>>> my_estimator.fit(wait=False)
>>> training_job_name = my_estimator.latest_training_job.name
>>> # Later on:
>>> attached_estimator = Estimator.attach(training_job_name)
>>> attached_estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')

Returns: Instance of the calling Estimator class with the attached training job.
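attach() depends on translating a training job description into constructor arguments. The function below is a hypothetical analogue of _prepare_init_params_from_job_description(), not the SDK's implementation. It reads field names from the Amazon SageMaker DescribeTrainingJob response (AlgorithmSpecification, ResourceConfig, StoppingCondition, OutputDataConfig, HyperParameters, TrainingJobStatus) and maps them onto the Estimator constructor parameters documented above.

```python
from typing import Any, Dict


def init_params_from_job_description(job: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical mapping from a DescribeTrainingJob response to Estimator kwargs."""
    algo = job["AlgorithmSpecification"]
    resources = job["ResourceConfig"]
    output = job.get("OutputDataConfig", {})
    return {
        "image_name": algo["TrainingImage"],
        "role": job["RoleArn"],
        "train_instance_count": resources["InstanceCount"],
        "train_instance_type": resources["InstanceType"],
        "train_volume_size": resources["VolumeSizeInGB"],
        "train_max_run": job["StoppingCondition"]["MaxRuntimeInSeconds"],
        "input_mode": algo["TrainingInputMode"],
        "output_path": output.get("S3OutputPath"),
        # KmsKeyId is optional in the response, so it may be absent.
        "output_kms_key": output.get("KmsKeyId"),
        # The API stores hyperparameters as strings; they are copied unchanged.
        "hyperparameters": dict(job.get("HyperParameters", {})),
    }


def can_deploy(job: Dict[str, Any]) -> bool:
    # Only a job that finished successfully has model artifacts to deploy.
    return job["TrainingJobStatus"] == "Completed"
```

A subclass with extra constructor arguments would extend this mapping rather than replace it, which is why the conversion is delegated per class.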
delete_endpoint()

Delete an Amazon SageMaker Endpoint.

Raises: ValueError – If the endpoint does not exist.

deploy(initial_instance_count, instance_type, endpoint_name=None, **kwargs)

Deploy the trained model to an Amazon SageMaker endpoint and return a sagemaker.RealTimePredictor object.

More information: http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html

Parameters:
  • initial_instance_count (int) – Minimum number of EC2 instances to deploy to an endpoint for prediction.
  • instance_type (str) – Type of EC2 instance to deploy to an endpoint for prediction, for example, ‘ml.c4.xlarge’.
  • endpoint_name (str) – Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of the training job is used.
  • **kwargs – Passed to create_model(). Implementations may customize create_model() to accept **kwargs that customize model creation during deploy. For more, see the implementation docs.
Returns:

A predictor that provides a predict() method, which can be used to send requests to the Amazon SageMaker endpoint and obtain inferences.

Return type:

sagemaker.predictor.RealTimePredictor
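Two documented behaviors of deploy(), the fallback to the training job name when endpoint_name is omitted and the forwarding of **kwargs to create_model(), are modelled below. DeploySketch is a hypothetical class that records what would be requested instead of calling AWS; it is not the SDK's implementation.

```python
from typing import Any, Dict, Optional


class DeploySketch:
    """Hypothetical model of deploy(); returns a plan dict instead of creating an endpoint."""

    def __init__(self, training_job_name: str) -> None:
        self.training_job_name = training_job_name

    def create_model(self, **kwargs: Any) -> Dict[str, Any]:
        # Stand-in for create_model(): just remembers the keyword arguments it received.
        return {"model_kwargs": kwargs}

    def deploy(self, initial_instance_count: int, instance_type: str,
               endpoint_name: Optional[str] = None, **kwargs: Any) -> Dict[str, Any]:
        # Documented default: use the training job name when no endpoint name is given.
        name = endpoint_name or self.training_job_name
        # Extra keyword arguments are forwarded unchanged to create_model().
        model = self.create_model(**kwargs)
        return {
            "endpoint_name": name,
            "initial_instance_count": initial_instance_count,
            "instance_type": instance_type,
            "model": model,
        }
```

For instance, DeploySketch("my-training-job").deploy(1, "ml.c4.xlarge", role="MyRole") plans an endpoint named my-training-job and hands role="MyRole" to create_model().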

fit(inputs, wait=True, logs=True, job_name=None)

Train a model using the input training dataset.

This method calls the Amazon SageMaker CreateTrainingJob API to start model training. It uses the configuration you provided when creating the estimator, together with the specified input training data, to send the CreateTrainingJob request to Amazon SageMaker.

By default (wait=True), this call blocks until the training job finishes. After the model training successfully completes, you can call the deploy() method to host the model using the Amazon SageMaker hosting services.

Parameters:
  • inputs (str or dict or sagemaker.session.s3_input) –

    Information about the training data. This can be one of three types:

    • (str) the S3 location where training data is saved.
    • (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple channels for
      training data, you can specify a dict mapping channel names to strings or s3_input() objects.
    • (sagemaker.session.s3_input) - channel configuration for S3 data sources that can provide
      additional information as well as the path to the training dataset. See sagemaker.session.s3_input() for full details.
  • wait (bool) – Whether the call should wait until the job completes (default: True).
  • logs (bool) – Whether to show the logs produced by the job. Only meaningful when wait is True (default: True).
  • job_name (str) – Training job name. If not specified, the estimator generates a default job name based on the training image name and current timestamp.
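The three accepted forms of inputs can be normalized into a single mapping of channel names to channel configurations. The sketch assumes, as in the SageMaker Python SDK, that a single str or s3_input is treated as the channel named "training"; S3Input is a hypothetical stand-in that carries only the S3 path, not the real sagemaker.session.s3_input. It also encodes the documented rule that logs only matter when wait is True.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass(frozen=True)
class S3Input:
    """Hypothetical stand-in for sagemaker.session.s3_input (path only)."""
    s3_data: str


Inputs = Union[str, S3Input, Dict[str, Union[str, S3Input]]]


def normalize_inputs(inputs: Inputs, default_channel: str = "training") -> Dict[str, S3Input]:
    """Map every accepted form of `inputs` to {channel name: S3Input}."""
    if isinstance(inputs, str):
        # A bare S3 URI becomes the single default channel.
        return {default_channel: S3Input(inputs)}
    if isinstance(inputs, S3Input):
        return {default_channel: inputs}
    if isinstance(inputs, dict):
        # Multiple channels: wrap bare URIs, keep S3Input objects as they are.
        return {
            name: value if isinstance(value, S3Input) else S3Input(value)
            for name, value in inputs.items()
        }
    raise TypeError(f"unsupported inputs type: {type(inputs).__name__}")


def shows_logs(wait: bool = True, logs: bool = True) -> bool:
    # Logs are only displayed while fit() is waiting for the job to finish.
    return wait and logs
```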
model_data

str – The model location in S3. Only set if the Estimator has been fit().

training_job_analytics

Return a TrainingJobAnalytics object for the current training job.

transformer(instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, max_payload=None, tags=None, role=None)

Return a Transformer that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator.

Parameters:
  • instance_count (int) – Number of EC2 instances to use.
  • instance_type (str) – Type of EC2 instance to use, for example, ‘ml.c4.xlarge’.
  • strategy (str) – The strategy used to decide how to batch records in a single request (default: None). Valid values: ‘MultiRecord’ and ‘SingleRecord’.
  • assemble_with (str) – How the output is assembled (default: None). Valid values: ‘Line’ or ‘None’.
  • output_path (str) – S3 location for saving the transform result. If not specified, results are stored to a default bucket.
  • output_kms_key (str) – Optional. KMS key ID for encrypting the transform output (default: None).
  • accept (str) – The content type accepted by the endpoint deployed during the transform job.
  • env (dict) – Environment variables to be set for use during the transform job (default: None).
  • max_concurrent_transforms (int) – The maximum number of HTTP requests to be made to each individual transform container at one time.
  • max_payload (int) – Maximum size of the payload in a single HTTP request to the container in MB.
  • tags (list[dict]) – List of tags for labeling a transform job. If none are specified, the tags used for the training job are used for the transform job.
  • role (str) – The ExecutionRoleArn IAM Role ARN for the Model, which is also used during transform jobs. If not specified, the role from the Estimator will be used.
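Two of the documented transformer() rules, the allowed assemble_with values and the fallback to the training job's tags, can be captured in a small helper. transform_settings is hypothetical and not part of the SDK; tags use the standard AWS shape of {"Key": ..., "Value": ...} dictionaries, and treating an empty tag list the same as no tags is an assumption of this sketch.

```python
from typing import Dict, List, Optional

Tag = Dict[str, str]  # AWS tag shape: {"Key": ..., "Value": ...}

# Documented values for assemble_with (besides leaving it unset).
VALID_ASSEMBLE_WITH = ("Line", "None")


def transform_settings(instance_count: int, instance_type: str,
                       assemble_with: Optional[str] = None,
                       tags: Optional[List[Tag]] = None,
                       training_tags: Optional[List[Tag]] = None) -> Dict[str, object]:
    """Hypothetical helper that resolves a few documented transformer() defaults."""
    if assemble_with is not None and assemble_with not in VALID_ASSEMBLE_WITH:
        raise ValueError(f"assemble_with must be 'Line' or 'None', got {assemble_with!r}")
    return {
        "instance_count": instance_count,
        "instance_type": instance_type,
        "assemble_with": assemble_with,
        # No (or empty) tags given: reuse the training job's tags, as documented.
        "tags": list(tags) if tags else list(training_tags or []),
    }
```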