Estimators
A high-level interface for SageMaker training
class sagemaker.estimator.EstimatorBase(role, train_instance_count, train_instance_type, train_volume_size=30, train_volume_kms_key=None, train_max_run=86400, input_mode='File', output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None, subnets=None, security_group_ids=None, model_uri=None, model_channel_name='model', metric_definitions=None, encrypt_inter_container_traffic=False, train_use_spot_instances=False, train_max_wait=None, checkpoint_s3_uri=None, checkpoint_local_path=None, rules=None, debugger_hook_config=None, tensorboard_output_config=None, enable_sagemaker_metrics=None)

Bases: object
Handle end-to-end Amazon SageMaker training and deployment tasks.
For an introduction to model training and deployment, see http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html
Subclasses must define a way to determine what image to use for training, what hyperparameters to use, and how to create an appropriate predictor instance.
Initialize an EstimatorBase instance.

Parameters:
- role (str) – An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource.
- train_instance_count (int) – Number of Amazon EC2 instances to use for training.
- train_instance_type (str) – Type of EC2 instance to use for training, for example, ‘ml.c4.xlarge’.
- train_volume_size (int) – Size in GB of the EBS volume to use for storing input data during training (default: 30). Must be large enough to store training data if File Mode is used (which is the default).
- train_volume_kms_key (str) – Optional. KMS key ID for encrypting EBS volume attached to the training instance (default: None).
- train_max_run (int) – Timeout in seconds for training (default: 24 * 60 * 60). After this amount of time Amazon SageMaker terminates the job regardless of its current status.
- input_mode (str) – The input mode that the algorithm supports (default: 'File'). Valid modes: 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory. 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix named pipe. This argument can be overridden on a per-channel basis using sagemaker.session.s3_input.input_mode.
- output_path (str) – S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket. If the bucket with the specific name does not exist, the estimator creates the bucket during the fit() method execution.
- output_kms_key (str) – Optional. KMS key ID for encrypting the training output (default: None).
- base_job_name (str) – Prefix for training job name when the fit() method launches. If not specified, the estimator generates a default job name, based on the training image name and current timestamp.
- sagemaker_session (sagemaker.session.Session) – Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain.
- tags (list[dict]) – List of tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
- subnets (list[str]) – List of subnet ids. If not specified, the training job will be created without VPC config.
- security_group_ids (list[str]) – List of security group ids. If not specified, the training job will be created without VPC config.
- model_uri (str) –
URI where a pre-trained model is stored, either locally or in S3 (default: None). If specified, the estimator will create a channel pointing to the model so the training job can download it. This model can be a ‘model.tar.gz’ from a previous training job, or other artifacts coming from a different source.
In local mode, this should point to the path in which the model is located and not the file itself, as local Docker containers will try to mount the URI as a volume.
More information: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
- model_channel_name (str) – Name of the channel where ‘model_uri’ will be downloaded (default: ‘model’).
- metric_definitions (list[dict]) – A list of dictionaries that defines the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: ‘Name’ for the name of the metric, and ‘Regex’ for the regular expression used to extract the metric from the logs. This should be defined only for jobs that don’t use an Amazon algorithm.
- encrypt_inter_container_traffic (bool) – Specifies whether traffic between training containers is encrypted for the training job (default: False).
- train_use_spot_instances (bool) – Specifies whether to use SageMaker Managed Spot instances for training. If enabled, then the train_max_wait arg should also be set. More information: https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html (default: False).
- train_max_wait (int) – Timeout in seconds waiting for Spot training instances (default: None). After this amount of time, Amazon SageMaker stops waiting for Spot instances to become available.
- checkpoint_s3_uri (str) – The S3 URI in which to persist checkpoints that the algorithm persists (if any) during training (default: None).
- checkpoint_local_path (str) – The local path that the algorithm writes its checkpoints to. SageMaker persists all files under this path to checkpoint_s3_uri continually during training. On job startup, the reverse happens: data from the S3 location is downloaded to this path before the algorithm is started. If the path is unset, SageMaker assumes the checkpoints will be provided under /opt/ml/checkpoints/ (default: None).
- enable_sagemaker_metrics (bool) – Enable SageMaker Metrics Time Series. For more information, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries (default: None).
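As an illustration of how these shared constructor arguments fit together, here is a minimal sketch using the generic Estimator subclass documented below, with Managed Spot Training and checkpointing enabled; the image URI, role ARN, and S3 paths are placeholders:

>>> from sagemaker.estimator import Estimator
>>> estimator = Estimator(
>>>     image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image:latest',  # placeholder image
>>>     role='arn:aws:iam::123456789012:role/MySageMakerRole',  # placeholder role
>>>     train_instance_count=1,
>>>     train_instance_type='ml.c4.xlarge',
>>>     train_use_spot_instances=True,
>>>     train_max_run=3600,    # stop the job after one hour of training
>>>     train_max_wait=7200,   # wait up to two hours for Spot capacity
>>>     checkpoint_s3_uri='s3://my-bucket/checkpoints/',  # placeholder bucket
>>>     output_path='s3://my-bucket/output/')  # placeholder bucket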
train_image()

Return the Docker image to use for training.

The fit() method, which does the model training, calls this method to find the image to use for model training.

Returns: The URI of the Docker image.
Return type: str

hyperparameters()

Return the hyperparameters as a dictionary to use for training.

The fit() method, which trains the model, calls this method to find the hyperparameters.

Returns: The hyperparameters.
Return type: dict[str, str]

enable_network_isolation()

Return True if this Estimator will need network isolation to run.

Returns: Whether this Estimator needs network isolation or not.
Return type: bool

prepare_workflow_for_training(job_name=None)

Calls _prepare_for_training. Used when setting up a workflow.

Parameters: job_name (str) – Name of the training job to be created. If not specified, one is generated, using the base name given to the constructor if applicable.

latest_job_debugger_artifacts_path()

Gets the path to the DebuggerHookConfig output artifacts.

Returns: An S3 path to the output artifacts.
Return type: str

latest_job_tensorboard_artifacts_path()

Gets the path to the TensorBoardOutputConfig output artifacts.

Returns: An S3 path to the output artifacts.
Return type: str
fit(inputs=None, wait=True, logs='All', job_name=None, experiment_config=None)

Train a model using the input training dataset.

The API calls the Amazon SageMaker CreateTrainingJob API to start model training. The API uses the configuration you provided to create the estimator and the specified input training data to send the CreateTrainingJob request to Amazon SageMaker.

This is a synchronous operation. After the model training successfully completes, you can call the deploy() method to host the model using the Amazon SageMaker hosting services.

Parameters:
- inputs (str or dict or sagemaker.session.s3_input) – Information about the training data. This can be one of four types:
  - (str) the S3 location where training data is saved.
  - (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple channels for training data, you can specify a dict mapping channel names to strings or s3_input() objects.
  - (sagemaker.session.s3_input) Channel configuration for S3 data sources that can provide additional information as well as the path to the training dataset. See sagemaker.session.s3_input() for full details.
  - (sagemaker.session.FileSystemInput) Channel configuration for a file system data source that can provide additional information as well as the path to the training dataset.
- wait (bool) – Whether the call should wait until the job completes (default: True).
- logs ([str]) – A list of strings specifying which logs to print. Acceptable strings are “All”, “None”, “Training”, or “Rules”. To maintain backwards compatibility, boolean values are also accepted and converted to strings. Only meaningful when wait is True.
- job_name (str) – Training job name. If not specified, the estimator generates a default job name, based on the training image name and current timestamp.
- experiment_config (dict[str, str]) – Experiment management configuration. Dictionary contains three optional keys, ‘ExperimentName’, ‘TrialName’, and ‘TrialComponentDisplayName’.
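For example, a sketch of both input forms; the bucket names, prefixes, content type, and channel names are placeholders (channel names depend on your algorithm):

>>> from sagemaker.session import s3_input
>>> # single channel: pass an S3 prefix directly
>>> estimator.fit('s3://my-bucket/training-data/')
>>> # multiple channels: map channel names to S3 prefixes or s3_input objects
>>> train_channel = s3_input('s3://my-bucket/train/', content_type='text/csv')
>>> validation_channel = s3_input('s3://my-bucket/validation/', content_type='text/csv')
>>> estimator.fit({'train': train_channel, 'validation': validation_channel})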
compile_model(target_instance_family, input_shape, output_path, framework=None, framework_version=None, compile_max_run=300, tags=None, **kwargs)

Compile a Neo model using the input model.
Parameters: - target_instance_family (str) – Identifies the device that you want to run your model after compilation, for example: ml_c5. For allowed strings see https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
- input_shape (dict) – Specifies the name and shape of the expected inputs for your trained model in json dictionary form, for example: {‘data’:[1,3,1024,1024]}, or {‘var1’: [1,1,28,28], ‘var2’:[1,1,28,28]}
- output_path (str) – Specifies where to store the compiled model
- framework (str) – The framework that is used to train the original model. Allowed values: ‘mxnet’, ‘tensorflow’, ‘keras’, ‘pytorch’, ‘onnx’, ‘xgboost’
- framework_version (str) – The version of the framework
- compile_max_run (int) – Timeout in seconds for compilation (default: 3 * 60). After this amount of time Amazon SageMaker Neo terminates the compilation job regardless of its current status.
- tags (list[dict]) – List of tags for labeling a compilation job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
- **kwargs – Passed to invocation of create_model(). Implementations may customize create_model() to accept **kwargs to customize model creation during deploy. For more, see the implementation docs.

Returns: A SageMaker Model object. See Model() for full details.
Return type: sagemaker.model.Model
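For instance, a hedged sketch of compiling a trained model for ml_c5 hosts; the output path, framework, and framework version are illustrative assumptions:

>>> compiled_model = estimator.compile_model(
>>>     target_instance_family='ml_c5',
>>>     input_shape={'data': [1, 3, 224, 224]},
>>>     output_path='s3://my-bucket/compiled/',  # placeholder bucket
>>>     framework='mxnet',
>>>     framework_version='1.4.1')  # illustrative version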
classmethod attach(training_job_name, sagemaker_session=None, model_channel_name='model')

Attach to an existing training job.

Create an Estimator bound to an existing training job. Each subclass is responsible for implementing _prepare_init_params_from_job_description() as this method delegates the actual conversion of a training job description to the arguments that the class constructor expects. After attaching, if the training job has a Complete status, it can be deploy()ed to create a SageMaker Endpoint and return a Predictor.

If the training job is in progress, attach will block and display log messages from the training job, until the training job completes.
Examples
>>> my_estimator.fit(wait=False)
>>> training_job_name = my_estimator.latest_training_job.name

Later on:

>>> attached_estimator = Estimator.attach(training_job_name)
>>> attached_estimator.deploy()
Parameters: - training_job_name (str) – The name of the training job to attach to.
- sagemaker_session (sagemaker.session.Session) – Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain.
- model_channel_name (str) – Name of the channel where pre-trained model data will be downloaded (default: ‘model’). If no channel with the same name exists in the training job, this option will be ignored.
Returns: Instance of the calling Estimator class with the attached training job.
deploy(initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, use_compiled_model=False, update_endpoint=False, wait=True, model_name=None, kms_key=None, data_capture_config=None, tags=None, **kwargs)

Deploy the trained model to an Amazon SageMaker endpoint and return a sagemaker.RealTimePredictor object.

More information: http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html
Parameters: - initial_instance_count (int) – Minimum number of EC2 instances to deploy to an endpoint for prediction.
- instance_type (str) – Type of EC2 instance to deploy to an endpoint for prediction, for example, ‘ml.c4.xlarge’.
- accelerator_type (str) – Type of Elastic Inference accelerator to attach to an endpoint for model loading and inference, for example, ‘ml.eia1.medium’. If not specified, no Elastic Inference accelerator will be attached to the endpoint. For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
- endpoint_name (str) – Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of the training job is used.
- use_compiled_model (bool) – Flag to select whether to use compiled (optimized) model. Default: False.
- update_endpoint (bool) – Flag to update the model in an existing Amazon SageMaker endpoint. If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources corresponding to the previous EndpointConfig. Default: False
- wait (bool) – Whether the call should wait until the deployment of model completes (default: True).
- model_name (str) – Name to use for creating an Amazon SageMaker model. If not specified, the name of the training job is used.
- kms_key (str) – The ARN of the KMS key that is used to encrypt the data on the storage volume attached to the instance hosting the endpoint.
- data_capture_config (sagemaker.model_monitor.DataCaptureConfig) – Specifies configuration related to Endpoint data capture for use with Amazon SageMaker Model Monitoring. Default: None.
- tags (List[dict[str, str]]) – Optional. The list of tags to attach to this specific endpoint. Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] For more information about tags, see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
- **kwargs – Passed to invocation of create_model(). Implementations may customize create_model() to accept **kwargs to customize model creation during deploy. For more, see the implementation docs.

Returns: A predictor that provides a predict() method, which can be used to send requests to the Amazon SageMaker endpoint and obtain inferences.
Return type: sagemaker.predictor.RealTimePredictor
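As a sketch, deploying and invoking the resulting predictor; the endpoint name is a placeholder and the request payload shown is illustrative (its format depends on the serving container):

>>> predictor = estimator.deploy(
>>>     initial_instance_count=1,
>>>     instance_type='ml.m4.xlarge',
>>>     endpoint_name='my-endpoint')  # placeholder endpoint name
>>> result = predictor.predict(b'1.0,2.0,3.0')  # request body format depends on the container
>>> predictor.delete_endpoint()  # clean up the endpoint when finished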
model_data

str – The model location in S3. Only set if Estimator has been fit().

create_model(**kwargs)

Create a SageMaker Model object that can be deployed to an Endpoint.

Parameters: **kwargs – Keyword arguments used by the implemented method for creating the Model.

Returns: A SageMaker Model object. See Model() for full details.
Return type: sagemaker.model.Model

delete_endpoint()

Delete an Amazon SageMaker Endpoint.

Raises: botocore.exceptions.ClientError – If the endpoint does not exist.

transformer(instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, max_payload=None, tags=None, role=None, volume_kms_key=None, vpc_config_override='VPC_CONFIG_DEFAULT')

Return a Transformer that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator.

Parameters:
- instance_count (int) – Number of EC2 instances to use.
- instance_type (str) – Type of EC2 instance to use, for example, ‘ml.c4.xlarge’.
- strategy (str) – The strategy used to decide how to batch records in a single request (default: None). Valid values: ‘MULTI_RECORD’ and ‘SINGLE_RECORD’.
- assemble_with (str) – How the output is assembled (default: None). Valid values: ‘Line’ or ‘None’.
- output_path (str) – S3 location for saving the transform result. If not specified, results are stored to a default bucket.
- output_kms_key (str) – Optional. KMS key ID for encrypting the transform output (default: None).
- accept (str) – The accept header passed by the client to the inference endpoint. If it is supported by the endpoint, it will be the format of the batch transform output.
- env (dict) – Environment variables to be set for use during the transform job (default: None).
- max_concurrent_transforms (int) – The maximum number of HTTP requests to be made to each individual transform container at one time.
- max_payload (int) – Maximum size of the payload in a single HTTP request to the container in MB.
- tags (list[dict]) – List of tags for labeling a transform job. If none specified, then the tags used for the training job are used for the transform job.
- role (str) – The ExecutionRoleArn IAM Role ARN for the Model, which is also used during transform jobs. If not specified, the role from the Estimator will be used.
- volume_kms_key (str) – Optional. KMS key ID for encrypting the volume attached to the ML compute instance (default: None).
- vpc_config_override (dict[str, list[str]]) – Optional override for the VpcConfig set on the model. Default: use subnets and security groups from this Estimator. * ‘Subnets’ (list[str]): List of subnet ids. * ‘SecurityGroupIds’ (list[str]): List of security group ids.
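For example, a sketch of running a batch transform with the returned Transformer; the buckets, content type, and split type are placeholders:

>>> transformer = estimator.transformer(
>>>     instance_count=1,
>>>     instance_type='ml.m4.xlarge',
>>>     assemble_with='Line',
>>>     output_path='s3://my-bucket/transform-output/')  # placeholder bucket
>>> transformer.transform('s3://my-bucket/batch-input/',  # placeholder input prefix
>>>                       content_type='text/csv',
>>>                       split_type='Line')
>>> transformer.wait()  # block until the transform job completes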
training_job_analytics

Return a TrainingJobAnalytics object for the current training job.
get_vpc_config(vpc_config_override='VPC_CONFIG_DEFAULT')

Returns the VpcConfig dict built from this Estimator's subnets and security groups, or else validates and returns an optional override value.

Parameters: vpc_config_override –
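For example, a sketch of passing an explicit override; the subnet and security group IDs are placeholders:

>>> vpc_override = {'Subnets': ['subnet-0123456789abcdef0'],
>>>                 'SecurityGroupIds': ['sg-0123456789abcdef0']}
>>> estimator.get_vpc_config(vpc_config_override=vpc_override)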
class sagemaker.estimator.Estimator(image_name, role, train_instance_count, train_instance_type, train_volume_size=30, train_volume_kms_key=None, train_max_run=86400, input_mode='File', output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, hyperparameters=None, tags=None, subnets=None, security_group_ids=None, model_uri=None, model_channel_name='model', metric_definitions=None, encrypt_inter_container_traffic=False, train_use_spot_instances=False, train_max_wait=None, checkpoint_s3_uri=None, checkpoint_local_path=None, enable_network_isolation=False, rules=None, debugger_hook_config=None, tensorboard_output_config=None, enable_sagemaker_metrics=None)

Bases: sagemaker.estimator.EstimatorBase
A generic Estimator to train using any supplied algorithm. This class is designed for use with algorithms that don’t have their own, custom class.
Initialize an Estimator instance.

Parameters:
- image_name (str) – The container image to use for training.
- role (str) – An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource.
- train_instance_count (int) – Number of Amazon EC2 instances to use for training.
- train_instance_type (str) – Type of EC2 instance to use for training, for example, ‘ml.c4.xlarge’.
- train_volume_size (int) – Size in GB of the EBS volume to use for storing input data during training (default: 30). Must be large enough to store training data if File Mode is used (which is the default).
- train_volume_kms_key (str) – Optional. KMS key ID for encrypting EBS volume attached to the training instance (default: None).
- train_max_run (int) – Timeout in seconds for training (default: 24 * 60 * 60). After this amount of time Amazon SageMaker terminates the job regardless of its current status.
- input_mode (str) – The input mode that the algorithm supports (default: 'File'). Valid modes:
  - 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory.
  - 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix named pipe.
  This argument can be overridden on a per-channel basis using sagemaker.session.s3_input.input_mode.
- output_path (str) – S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket. If the bucket with the specific name does not exist, the estimator creates the bucket during the fit() method execution.
- output_kms_key (str) – Optional. KMS key ID for encrypting the training output (default: None).
- base_job_name (str) – Prefix for training job name when the fit() method launches. If not specified, the estimator generates a default job name, based on the training image name and current timestamp.
- sagemaker_session (sagemaker.session.Session) – Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain.
- hyperparameters (dict) – Dictionary containing the hyperparameters to initialize this estimator with.
- tags (list[dict]) – List of tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
- subnets (list[str]) – List of subnet ids. If not specified, the training job will be created without VPC config.
- security_group_ids (list[str]) – List of security group ids. If not specified, the training job will be created without VPC config.
- model_uri (str) –
URI where a pre-trained model is stored, either locally or in S3 (default: None). If specified, the estimator will create a channel pointing to the model so the training job can download it. This model can be a ‘model.tar.gz’ from a previous training job, or other artifacts coming from a different source.
In local mode, this should point to the path in which the model is located and not the file itself, as local Docker containers will try to mount the URI as a volume.
More information: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
- model_channel_name (str) – Name of the channel where ‘model_uri’ will be downloaded (default: ‘model’).
- metric_definitions (list[dict]) – A list of dictionaries that defines the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: ‘Name’ for the name of the metric, and ‘Regex’ for the regular expression used to extract the metric from the logs. This should be defined only for jobs that don’t use an Amazon algorithm.
- encrypt_inter_container_traffic (bool) – Specifies whether traffic between training containers is encrypted for the training job (default: False).
- train_use_spot_instances (bool) – Specifies whether to use SageMaker Managed Spot instances for training. If enabled, then the train_max_wait arg should also be set. More information: https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html (default: False).
- train_max_wait (int) – Timeout in seconds waiting for Spot training instances (default: None). After this amount of time, Amazon SageMaker stops waiting for Spot instances to become available.
- checkpoint_s3_uri (str) – The S3 URI in which to persist checkpoints that the algorithm persists (if any) during training (default: None).
- checkpoint_local_path (str) – The local path that the algorithm writes its checkpoints to. SageMaker persists all files under this path to checkpoint_s3_uri continually during training. On job startup, the reverse happens: data from the S3 location is downloaded to this path before the algorithm is started. If the path is unset, SageMaker assumes the checkpoints will be provided under /opt/ml/checkpoints/ (default: None).
- enable_network_isolation (bool) – Specifies whether the container will run in network isolation mode. Network isolation mode restricts the container access to outside networks (such as the Internet). The container does not make any inbound or outbound network calls. If True, a channel named "code" will be created for any user entry script for training. The user entry script, files in source_dir (if specified), and dependencies will be uploaded in a tar to S3. Also known as internet-free mode (default: False).
- enable_sagemaker_metrics (bool) – Enable SageMaker Metrics Time Series. For more information, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries (default: None).
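For example, a minimal sketch of a bring-your-own-container training job with this class; the ECR image, role, and buckets are placeholders:

>>> from sagemaker.estimator import Estimator
>>> byo_estimator = Estimator(
>>>     image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algorithm:latest',  # placeholder image
>>>     role='arn:aws:iam::123456789012:role/MySageMakerRole',  # placeholder role
>>>     train_instance_count=1,
>>>     train_instance_type='ml.m5.xlarge',
>>>     hyperparameters={'epochs': 10, 'batch_size': 64},
>>>     output_path='s3://my-bucket/output/')  # placeholder bucket
>>> byo_estimator.fit('s3://my-bucket/training-data/')  # placeholder prefix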
enable_network_isolation()

Return True if this Estimator can use network isolation when running.

Returns: Whether this Estimator can use network isolation or not.
Return type: bool

train_image()

Returns the Docker image to use for training.

The fit() method, which does the model training, calls this method to find the image to use for model training.

set_hyperparameters(**kwargs)

Parameters: **kwargs –
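Continuing the sketch above, hyperparameters can also be set (or overridden) after construction; the names and values are illustrative:

>>> byo_estimator.set_hyperparameters(learning_rate=0.1, epochs=20)
>>> byo_estimator.hyperparameters()  # returns the dictionary that will be passed to the training job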
hyperparameters()

Returns the hyperparameters as a dictionary to use for training.

The fit() method, which does the model training, calls this method to find the hyperparameters you specified.
create_model(role=None, image=None, predictor_cls=None, serializer=None, deserializer=None, content_type=None, accept=None, vpc_config_override='VPC_CONFIG_DEFAULT', **kwargs)

Create a model to deploy.
The serializer, deserializer, content_type, and accept arguments are only used to define a default RealTimePredictor. They are ignored if an explicit predictor class is passed in. Other arguments are passed through to the Model class.
Parameters:
- role (str) – The ExecutionRoleArn IAM Role ARN for the Model, which is also used during transform jobs. If not specified, the role from the Estimator will be used.
- image (str) – A container image to use for deploying the model. Defaults to the image used for training.
- predictor_cls (RealTimePredictor) – The predictor class to use when deploying the model.
- serializer (callable) – Should accept a single argument, the input data, and return a sequence of bytes. May provide a content_type attribute that defines the endpoint request content type
- deserializer (callable) – Should accept two arguments, the result data and the response content type, and return a sequence of bytes. May provide a content_type attribute that defines the endpoint response Accept content type.
- content_type (str) – The invocation ContentType, overriding any content_type from the serializer
- accept (str) – The invocation Accept, overriding any accept from the deserializer.
- vpc_config_override (dict[str, list[str]]) – Optional override for VpcConfig set on the model. Default: use subnets and security groups from this Estimator. * ‘Subnets’ (list[str]): List of subnet ids. * ‘SecurityGroupIds’ (list[str]): List of security group ids.
- **kwargs – Additional parameters passed to Model.

Tip: You can find additional parameters for using this method at Model.

Returns: (sagemaker.model.Model) a Model ready for deployment.
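Continuing the sketch above, a hedged example that builds a deployable model whose default predictor sends and receives JSON; json_serializer and json_deserializer are the helpers from sagemaker.predictor, and the instance type is a placeholder:

>>> from sagemaker.predictor import json_serializer, json_deserializer
>>> model = byo_estimator.create_model(
>>>     serializer=json_serializer,
>>>     deserializer=json_deserializer,
>>>     content_type='application/json',
>>>     accept='application/json')
>>> predictor = model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')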
classmethod attach(training_job_name, sagemaker_session=None, model_channel_name='model')

Attach to an existing training job.

Create an Estimator bound to an existing training job. Each subclass is responsible for implementing _prepare_init_params_from_job_description() as this method delegates the actual conversion of a training job description to the arguments that the class constructor expects. After attaching, if the training job has a Complete status, it can be deploy()ed to create a SageMaker Endpoint and return a Predictor.

If the training job is in progress, attach will block and display log messages from the training job, until the training job completes.
Examples
>>> my_estimator.fit(wait=False)
>>> training_job_name = my_estimator.latest_training_job.name

Later on:

>>> attached_estimator = Estimator.attach(training_job_name)
>>> attached_estimator.deploy()
Parameters: - training_job_name (str) – The name of the training job to attach to.
- sagemaker_session (sagemaker.session.Session) – Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain.
- model_channel_name (str) – Name of the channel where pre-trained model data will be downloaded (default: ‘model’). If no channel with the same name exists in the training job, this option will be ignored.
Returns: Instance of the calling Estimator class with the attached training job.
compile_model(target_instance_family, input_shape, output_path, framework=None, framework_version=None, compile_max_run=300, tags=None, **kwargs)

Compile a Neo model using the input model.
Parameters: - target_instance_family (str) – Identifies the device that you want to run your model after compilation, for example: ml_c5. For allowed strings see https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
- input_shape (dict) – Specifies the name and shape of the expected inputs for your trained model in json dictionary form, for example: {‘data’:[1,3,1024,1024]}, or {‘var1’: [1,1,28,28], ‘var2’:[1,1,28,28]}
- output_path (str) – Specifies where to store the compiled model
- framework (str) – The framework that is used to train the original model. Allowed values: ‘mxnet’, ‘tensorflow’, ‘keras’, ‘pytorch’, ‘onnx’, ‘xgboost’
- framework_version (str) – The version of the framework
- compile_max_run (int) – Timeout in seconds for compilation (default: 3 * 60). After this amount of time Amazon SageMaker Neo terminates the compilation job regardless of its current status.
- tags (list[dict]) – List of tags for labeling a compilation job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
- **kwargs – Passed to invocation of create_model(). Implementations may customize create_model() to accept **kwargs to customize model creation during deploy. For more, see the implementation docs.

Returns: A SageMaker Model object. See Model() for full details.
Return type: sagemaker.model.Model
delete_endpoint()

Delete an Amazon SageMaker Endpoint.

Raises: botocore.exceptions.ClientError – If the endpoint does not exist.

deploy(initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, use_compiled_model=False, update_endpoint=False, wait=True, model_name=None, kms_key=None, data_capture_config=None, tags=None, **kwargs)

Deploy the trained model to an Amazon SageMaker endpoint and return a sagemaker.RealTimePredictor object.

More information: http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html
Parameters: - initial_instance_count (int) – Minimum number of EC2 instances to deploy to an endpoint for prediction.
- instance_type (str) – Type of EC2 instance to deploy to an endpoint for prediction, for example, ‘ml.c4.xlarge’.
- accelerator_type (str) – Type of Elastic Inference accelerator to attach to an endpoint for model loading and inference, for example, ‘ml.eia1.medium’. If not specified, no Elastic Inference accelerator will be attached to the endpoint. For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
- endpoint_name (str) – Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of the training job is used.
- use_compiled_model (bool) – Flag to select whether to use compiled (optimized) model. Default: False.
- update_endpoint (bool) – Flag to update the model in an existing Amazon SageMaker endpoint. If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources corresponding to the previous EndpointConfig. Default: False
- wait (bool) – Whether the call should wait until the deployment of model completes (default: True).
- model_name (str) – Name to use for creating an Amazon SageMaker model. If not specified, the name of the training job is used.
- kms_key (str) – The ARN of the KMS key that is used to encrypt the data on the storage volume attached to the instance hosting the endpoint.
- data_capture_config (sagemaker.model_monitor.DataCaptureConfig) – Specifies configuration related to Endpoint data capture for use with Amazon SageMaker Model Monitoring. Default: None.
- tags (List[dict[str, str]]) – Optional. The list of tags to attach to this specific endpoint. Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] For more information about tags, see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
- **kwargs – Passed to invocation of create_model(). Implementations may customize create_model() to accept **kwargs to customize model creation during deploy. For more, see the implementation docs.

Returns: A predictor that provides a predict() method, which can be used to send requests to the Amazon SageMaker endpoint and obtain inferences.
Return type: sagemaker.predictor.RealTimePredictor
fit(inputs=None, wait=True, logs='All', job_name=None, experiment_config=None)

Train a model using the input training dataset.

The API calls the Amazon SageMaker CreateTrainingJob API to start model training. The API uses the configuration you provided to create the estimator and the specified input training data to send the CreateTrainingJob request to Amazon SageMaker.

This is a synchronous operation. After the model training successfully completes, you can call the deploy() method to host the model using the Amazon SageMaker hosting services.

Parameters:
- inputs (str or dict or sagemaker.session.s3_input) – Information about the training data. This can be one of four types:
  - (str) the S3 location where training data is saved.
  - (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple channels for training data, you can specify a dict mapping channel names to strings or s3_input() objects.
  - (sagemaker.session.s3_input) Channel configuration for S3 data sources that can provide additional information as well as the path to the training dataset. See sagemaker.session.s3_input() for full details.
  - (sagemaker.session.FileSystemInput) Channel configuration for a file system data source that can provide additional information as well as the path to the training dataset.
- wait (bool) – Whether the call should wait until the job completes (default: True).
- logs ([str]) – A list of strings specifying which logs to print. Acceptable strings are “All”, “None”, “Training”, or “Rules”. To maintain backwards compatibility, boolean values are also accepted and converted to strings. Only meaningful when wait is True.
- job_name (str) – Training job name. If not specified, the estimator generates a default job name, based on the training image name and current timestamp.
- experiment_config (dict[str, str]) – Experiment management configuration. Dictionary contains three optional keys, ‘ExperimentName’, ‘TrialName’, and ‘TrialComponentDisplayName’.
get_vpc_config(vpc_config_override='VPC_CONFIG_DEFAULT')

Returns the VpcConfig dict built from this Estimator's subnets and security groups, or else validates and returns an optional override value.

Parameters: vpc_config_override –
latest_job_debugger_artifacts_path()

Gets the path to the DebuggerHookConfig output artifacts.

Returns: An S3 path to the output artifacts.
Return type: str

latest_job_tensorboard_artifacts_path()

Gets the path to the TensorBoardOutputConfig output artifacts.

Returns: An S3 path to the output artifacts.
Return type: str

model_data

str – The model location in S3. Only set if Estimator has been fit().

prepare_workflow_for_training(job_name=None)

Calls _prepare_for_training. Used when setting up a workflow.

Parameters: job_name (str) – Name of the training job to be created. If not specified, one is generated, using the base name given to the constructor if applicable.

training_job_analytics

Return a TrainingJobAnalytics object for the current training job.
transformer(instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, max_payload=None, tags=None, role=None, volume_kms_key=None, vpc_config_override='VPC_CONFIG_DEFAULT')

Return a Transformer that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator.

Parameters:
- instance_count (int) – Number of EC2 instances to use.
- instance_type (str) – Type of EC2 instance to use, for example, ‘ml.c4.xlarge’.
- strategy (str) – The strategy used to decide how to batch records in a single request (default: None). Valid values: ‘MULTI_RECORD’ and ‘SINGLE_RECORD’.
- assemble_with (str) – How the output is assembled (default: None). Valid values: ‘Line’ or ‘None’.
- output_path (str) – S3 location for saving the transform result. If not specified, results are stored to a default bucket.
- output_kms_key (str) – Optional. KMS key ID for encrypting the transform output (default: None).
- accept (str) – The accept header passed by the client to the inference endpoint. If it is supported by the endpoint, it will be the format of the batch transform output.
- env (dict) – Environment variables to be set for use during the transform job (default: None).
- max_concurrent_transforms (int) – The maximum number of HTTP requests to be made to each individual transform container at one time.
- max_payload (int) – Maximum size of the payload in a single HTTP request to the container in MB.
- tags (list[dict]) – List of tags for labeling a transform job. If none specified, then the tags used for the training job are used for the transform job.
- role (str) – The ExecutionRoleArn IAM Role ARN for the Model, which is also used during transform jobs. If not specified, the role from the Estimator will be used.
- volume_kms_key (str) – Optional. KMS key ID for encrypting the volume attached to the ML compute instance (default: None).
- vpc_config_override (dict[str, list[str]]) – Optional override for the VpcConfig set on the model. Default: use subnets and security groups from this Estimator. * ‘Subnets’ (list[str]): List of subnet ids. * ‘SecurityGroupIds’ (list[str]): List of security group ids.
class sagemaker.estimator.Framework(entry_point, source_dir=None, hyperparameters=None, enable_cloudwatch_metrics=False, container_log_level=20, code_location=None, image_name=None, dependencies=None, enable_network_isolation=False, git_config=None, checkpoint_s3_uri=None, checkpoint_local_path=None, enable_sagemaker_metrics=None, **kwargs)

Bases: sagemaker.estimator.EstimatorBase
Base class that cannot be instantiated directly.
Subclasses define functionality pertaining to specific ML frameworks, such as training/deployment images and predictor instances.
Base class initializer. Subclasses which override __init__ should invoke super().

Parameters:
- entry_point (str) – Path (absolute or relative) to the local Python source file which should be executed as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5. If 'git_config' is provided, 'entry_point' should be a relative location to the Python source file in the Git repo.

  Example: With the following GitHub repo directory structure:

  >>> |----- README.md
  >>> |----- src
  >>>          |----- train.py
  >>>          |----- test.py

  you can assign entry_point='src/train.py'.
- source_dir (str) – Path (absolute, relative, or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). The structure within this directory is preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo.

  Example: With the following GitHub repo directory structure:

  >>> |----- README.md
  >>> |----- src
  >>>          |----- train.py
  >>>          |----- test.py

  if you need 'train.py' as the entry point and 'test.py' as the training source code as well, you can assign entry_point='train.py', source_dir='src'.
- hyperparameters (dict) – Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker. For convenience, this accepts other types for keys and values, but str() will be called to convert them before training.
- enable_cloudwatch_metrics (bool) – [DEPRECATED] Now there are CloudWatch metrics emitted by all SageMaker training jobs. This will be ignored for now and removed in a further release.
- container_log_level (int) – Log level to use within the container (default: logging.INFO). Valid values are defined in the Python logging module.
- code_location (str) – The S3 prefix URI where custom code will be uploaded (default: None). Don't include a trailing slash, since a string prepended with a "/" is appended to code_location. The code file uploaded to S3 is 'code_location/job-name/source/sourcedir.tar.gz'. If not specified, the default code location is s3://default_bucket/job-name/.
- image_name (str) – An alternate image name to use instead of the official SageMaker image for the framework. This is useful to run one of the SageMaker supported frameworks with an image containing custom dependencies.
- dependencies (list[str]) – A list of paths to directories (absolute or relative) with any additional libraries that will be exported to the container (default: []). The library folders will be copied to SageMaker in the same folder where the entry point is copied. If 'git_config' is provided, 'dependencies' should be a list of relative locations to directories with any additional libraries needed in the Git repo.

  Example: The following call

  >>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])

  results in the following inside the container:

  >>> $ ls
  >>> opt/ml/code
  >>> |------ train.py
  >>> |------ common
  >>> |------ virtual-env
- enable_network_isolation (bool) – Specifies whether container will run in network isolation mode. Network isolation mode restricts the container access to outside networks (such as the internet). The container does not make any inbound or outbound network calls. If True, a channel named “code” will be created for any user entry script for training. The user entry script, files in source_dir (if specified), and dependencies will be uploaded in a tar to S3. Also known as internet-free mode (default: False ).
- git_config (dict[str, str]) – Git configurations used for cloning files, including repo, branch, commit, 2FA_enabled, username, password and token. The repo field is required. All other fields are optional. repo specifies the Git repository where your training script is stored. If you don't provide branch, the default value 'master' is used. If you don't provide commit, the latest commit in the specified branch is used.

  Example: The following config:

  >>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git',
  >>>               'branch': 'test-branch-git-config',
  >>>               'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}

  results in cloning the repo specified in 'repo', then checking out the 'test-branch-git-config' branch, and checking out the specified commit.

  2FA_enabled, username, password and token are used for authentication. For GitHub (or other Git) accounts, set 2FA_enabled to 'True' if two-factor authentication is enabled for the account, otherwise set it to 'False'. If you do not provide a value for 2FA_enabled, a default value of 'False' is used. CodeCommit does not support two-factor authentication, so do not provide "2FA_enabled" with CodeCommit repositories.

  For GitHub and other Git repos, when SSH URLs are provided, it doesn't matter whether 2FA is enabled or disabled; you should either have no passphrase for the SSH key pairs, or have the ssh-agent configured so that you will not be prompted for the SSH passphrase when you run the 'git clone' command with SSH URLs. When HTTPS URLs are provided: if 2FA is disabled, then either token or username+password will be used for authentication if provided (token prioritized); if 2FA is enabled, only token will be used for authentication if provided. If the required authentication info is not provided, the Python SDK will try to use local credentials storage to authenticate. If that fails too, an error message will be thrown.
  For CodeCommit repos, 2FA is not supported, so '2FA_enabled' should not be provided. There is no token in CodeCommit, so 'token' should not be provided either. When 'repo' is an SSH URL, the requirements are the same as for GitHub-like repos. When 'repo' is an HTTPS URL, username+password will be used for authentication if they are provided; otherwise, the Python SDK will try to use either the CodeCommit credential helper or local credential storage for authentication.
- checkpoint_s3_uri (str) – The S3 URI in which to persist checkpoints that the algorithm persists (if any) during training (default: None).
- checkpoint_local_path (str) – The local path that the algorithm writes its checkpoints to. SageMaker persists all files under this path to checkpoint_s3_uri continually during training. On job startup, the reverse happens: data from the S3 location is downloaded to this path before the algorithm is started. If the path is unset, SageMaker assumes the checkpoints will be provided under /opt/ml/checkpoints/ (default: None).
- enable_sagemaker_metrics (bool) – Enable SageMaker Metrics Time Series. For more information, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries (default: None).
- **kwargs – Additional kwargs passed to the EstimatorBase constructor.
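For example, a hedged sketch using one of the framework subclasses (sagemaker.pytorch.PyTorch is shown purely as an illustration) with a script entry point and source directory; the role, versions, channel name, and S3 paths are placeholders:

>>> from sagemaker.pytorch import PyTorch  # a Framework subclass, shown for illustration
>>> framework_estimator = PyTorch(
>>>     entry_point='train.py',
>>>     source_dir='src',
>>>     role='arn:aws:iam::123456789012:role/MySageMakerRole',  # placeholder role
>>>     framework_version='1.3.1',  # illustrative version
>>>     train_instance_count=1,
>>>     train_instance_type='ml.p2.xlarge',
>>>     hyperparameters={'epochs': 10, 'lr': 0.01})
>>> framework_estimator.fit({'training': 's3://my-bucket/train/'})  # placeholder channel and bucket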
Tip: You can find additional parameters for initializing this class at EstimatorBase.

LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'

LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'

MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'

MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'

CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = '/opt/ml/input/data/code/sourcedir.tar.gz'
enable_network_isolation()

Return True if this Estimator can use network isolation to run.

Returns: Whether this Estimator can use network isolation or not.
Return type: bool

hyperparameters()

Return the hyperparameters as a dictionary to use for training.

The fit() method, which trains the model, calls this method to find the hyperparameters.

Returns: The hyperparameters.
Return type: dict[str, str]

train_image()

Return the Docker image to use for training.

The fit() method, which does the model training, calls this method to find the image to use for model training.

Returns: The URI of the Docker image.
Return type: str
classmethod attach(training_job_name, sagemaker_session=None, model_channel_name='model')

Attach to an existing training job.

Create an Estimator bound to an existing training job. Each subclass is responsible for implementing _prepare_init_params_from_job_description() as this method delegates the actual conversion of a training job description to the arguments that the class constructor expects. After attaching, if the training job has a Complete status, it can be deploy()ed to create a SageMaker Endpoint and return a Predictor.

If the training job is in progress, attach will block and display log messages from the training job, until the training job completes.
Examples
>>> my_estimator.fit(wait=False)
>>> training_job_name = my_estimator.latest_training_job.name

Later on:

>>> attached_estimator = Estimator.attach(training_job_name)
>>> attached_estimator.deploy()
Parameters: - training_job_name (str) – The name of the training job to attach to.
- sagemaker_session (sagemaker.session.Session) – Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain.
- model_channel_name (str) – Name of the channel where pre-trained model data will be downloaded (default: ‘model’). If no channel with the same name exists in the training job, this option will be ignored.
Returns: Instance of the calling Estimator class with the attached training job.
transformer(instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, max_payload=None, tags=None, role=None, model_server_workers=None, volume_kms_key=None, entry_point=None, vpc_config_override='VPC_CONFIG_DEFAULT')

Return a Transformer that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator.

Parameters:
- instance_count (int) – Number of EC2 instances to use.
- instance_type (str) – Type of EC2 instance to use, for example, ‘ml.c4.xlarge’.
- strategy (str) – The strategy used to decide how to batch records in a single request (default: None). Valid values: ‘MULTI_RECORD’ and ‘SINGLE_RECORD’.
- assemble_with (str) – How the output is assembled (default: None). Valid values: ‘Line’ or ‘None’.
- output_path (str) – S3 location for saving the transform result. If not specified, results are stored to a default bucket.
- output_kms_key (str) – Optional. KMS key ID for encrypting the transform output (default: None).
- accept (str) – The accept header passed by the client to the inference endpoint. If it is supported by the endpoint, it will be the format of the batch transform output.
- env (dict) – Environment variables to be set for use during the transform job (default: None).
- max_concurrent_transforms (int) – The maximum number of HTTP requests to be made to each individual transform container at one time.
- max_payload (int) – Maximum size of the payload in a single HTTP request to the container in MB.
- tags (list[dict]) – List of tags for labeling a transform job. If none specified, then the tags used for the training job are used for the transform job.
- role (str) – The ExecutionRoleArn IAM Role ARN for the Model, which is also used during transform jobs. If not specified, the role from the Estimator will be used.
- model_server_workers (int) – Optional. The number of worker processes used by the inference server. If None, server will use one worker per vCPU.
- volume_kms_key (str) – Optional. KMS key ID for encrypting the volume attached to the ML compute instance (default: None).
- entry_point (str) – Path (absolute or relative) to the local Python source file which should be executed as the entry point to training. If not specified, the training entry point is used.
- vpc_config_override (dict[str, list[str]]) – Optional override for the VpcConfig set on the model. Default: use subnets and security groups from this Estimator. * ‘Subnets’ (list[str]): List of subnet ids. * ‘SecurityGroupIds’ (list[str]): List of security group ids.
Returns: a Transformer object that can be used to start a SageMaker Batch Transform job.
Return type: sagemaker.transformer.Transformer
compile_model(target_instance_family, input_shape, output_path, framework=None, framework_version=None, compile_max_run=300, tags=None, **kwargs)

Compile a Neo model using the input model.
Parameters: - target_instance_family (str) – Identifies the device that you want to run your model after compilation, for example: ml_c5. For allowed strings see https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
- input_shape (dict) – Specifies the name and shape of the expected inputs for your trained model in json dictionary form, for example: {‘data’:[1,3,1024,1024]}, or {‘var1’: [1,1,28,28], ‘var2’:[1,1,28,28]}
- output_path (str) – Specifies where to store the compiled model
- framework (str) – The framework that is used to train the original model. Allowed values: ‘mxnet’, ‘tensorflow’, ‘keras’, ‘pytorch’, ‘onnx’, ‘xgboost’
- framework_version (str) – The version of the framework
- compile_max_run (int) – Timeout in seconds for compilation (default: 3 * 60). After this amount of time Amazon SageMaker Neo terminates the compilation job regardless of its current status.
- tags (list[dict]) – List of tags for labeling a compilation job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
- **kwargs – Passed to invocation of create_model(). Implementations may customize create_model() to accept **kwargs to customize model creation during deploy. For more, see the implementation docs.

Returns: A SageMaker Model object. See Model() for full details.
Return type: sagemaker.model.Model
create_model(**kwargs)

Create a SageMaker Model object that can be deployed to an Endpoint.

Parameters: **kwargs – Keyword arguments used by the implemented method for creating the Model.

Returns: A SageMaker Model object. See Model() for full details.
Return type: sagemaker.model.Model

delete_endpoint()

Delete an Amazon SageMaker Endpoint.

Raises: botocore.exceptions.ClientError – If the endpoint does not exist.

deploy(initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, use_compiled_model=False, update_endpoint=False, wait=True, model_name=None, kms_key=None, data_capture_config=None, tags=None, **kwargs)

Deploy the trained model to an Amazon SageMaker endpoint and return a sagemaker.RealTimePredictor object.

More information: http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html
Parameters: - initial_instance_count (int) – Minimum number of EC2 instances to deploy to an endpoint for prediction.
- instance_type (str) – Type of EC2 instance to deploy to an endpoint for prediction, for example, ‘ml.c4.xlarge’.
- accelerator_type (str) – Type of Elastic Inference accelerator to attach to an endpoint for model loading and inference, for example, ‘ml.eia1.medium’. If not specified, no Elastic Inference accelerator will be attached to the endpoint. For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
- endpoint_name (str) – Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of the training job is used.
- use_compiled_model (bool) – Flag to select whether to use compiled (optimized) model. Default: False.
- update_endpoint (bool) – Flag to update the model in an existing Amazon SageMaker endpoint. If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources corresponding to the previous EndpointConfig. Default: False
- wait (bool) – Whether the call should wait until the deployment of model completes (default: True).
- model_name (str) – Name to use for creating an Amazon SageMaker model. If not specified, the name of the training job is used.
- kms_key (str) – The ARN of the KMS key that is used to encrypt the data on the storage volume attached to the instance hosting the endpoint.
- data_capture_config (sagemaker.model_monitor.DataCaptureConfig) – Specifies configuration related to Endpoint data capture for use with Amazon SageMaker Model Monitoring. Default: None.
- tags (List[dict[str, str]]) – Optional. The list of tags to attach to this specific endpoint. Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] For more information about tags, see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
- **kwargs – Passed to invocation of create_model(). Implementations may customize create_model() to accept **kwargs to customize model creation during deploy. For more, see the implementation docs.

Returns: A predictor that provides a predict() method, which can be used to send requests to the Amazon SageMaker endpoint and obtain inferences.
Return type: sagemaker.predictor.RealTimePredictor
fit(inputs=None, wait=True, logs='All', job_name=None, experiment_config=None)

Train a model using the input training dataset.

The API calls the Amazon SageMaker CreateTrainingJob API to start model training. The API uses the configuration you provided to create the estimator and the specified input training data to send the CreateTrainingJob request to Amazon SageMaker.

This is a synchronous operation. After the model training successfully completes, you can call the deploy() method to host the model using the Amazon SageMaker hosting services.

Parameters:
- inputs (str or dict or sagemaker.session.s3_input) – Information about the training data. This can be one of four types:
  - (str) the S3 location where training data is saved.
  - (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple channels for training data, you can specify a dict mapping channel names to strings or s3_input() objects.
  - (sagemaker.session.s3_input) Channel configuration for S3 data sources that can provide additional information as well as the path to the training dataset. See sagemaker.session.s3_input() for full details.
  - (sagemaker.session.FileSystemInput) Channel configuration for a file system data source that can provide additional information as well as the path to the training dataset.
- wait (bool) – Whether the call should wait until the job completes (default: True).
- logs ([str]) – A list of strings specifying which logs to print. Acceptable strings are “All”, “None”, “Training”, or “Rules”. To maintain backwards compatibility, boolean values are also accepted and converted to strings. Only meaningful when wait is True.
- job_name (str) – Training job name. If not specified, the estimator generates a default job name, based on the training image name and current timestamp.
- experiment_config (dict[str, str]) – Experiment management configuration. Dictionary contains three optional keys, ‘ExperimentName’, ‘TrialName’, and ‘TrialComponentDisplayName’.
get_vpc_config(vpc_config_override='VPC_CONFIG_DEFAULT')

Returns the VpcConfig dict built from this Estimator's subnets and security groups, or else validates and returns an optional override value.

Parameters: vpc_config_override –
latest_job_debugger_artifacts_path()

Gets the path to the DebuggerHookConfig output artifacts.

Returns: An S3 path to the output artifacts.
Return type: str

latest_job_tensorboard_artifacts_path()

Gets the path to the TensorBoardOutputConfig output artifacts.

Returns: An S3 path to the output artifacts.
Return type: str

model_data

str – The model location in S3. Only set if Estimator has been fit().

prepare_workflow_for_training(job_name=None)

Calls _prepare_for_training. Used when setting up a workflow.

Parameters: job_name (str) – Name of the training job to be created. If not specified, one is generated, using the base name given to the constructor if applicable.

training_job_analytics

Return a TrainingJobAnalytics object for the current training job.