sagemaker.train.sft_trainer

Classes

SFTTrainer(model[, training_type, ...])

Class that performs Supervised Fine-Tuning (SFT) on foundation models using Amazon SageMaker.

class sagemaker.train.sft_trainer.SFTTrainer(model: str | ModelPackage, training_type: TrainingType | str = TrainingType.LORA, model_package_group: str | ModelPackageGroup | None = None, compute: Compute | HyperPodCompute | None = None, mlflow_resource_arn: str | None = None, mlflow_experiment_name: str | None = None, mlflow_run_name: str | None = None, training_dataset: str | DataSet | None = None, validation_dataset: str | DataSet | None = None, s3_output_path: str | None = None, kms_key_id: str | None = None, networking: VpcConfig | None = None, accept_eula: bool | None = False, stopping_condition: StoppingCondition | None = None, recipe: str | None = None, overrides: dict | None = None, is_multimodal: bool | None = None, data_mixing_config: DataMixingConfig | None = None, base_model_name: str | None = None, disable_output_compression: bool | None = False, **kwargs)

Bases: BaseTrainer

Class that performs Supervised Fine-Tuning (SFT) on foundation models using Amazon SageMaker.

Example:

from sagemaker.train import SFTTrainer
from sagemaker.train.common import TrainingType

trainer = SFTTrainer(
    model="meta-llama/Llama-2-7b-hf",
    training_type=TrainingType.LORA,
    model_package_group="my-model-group",
    training_dataset="s3://bucket/train.jsonl",
    validation_dataset="s3://bucket/val.jsonl"
)

trainer.train()

# Complete workflow:
trainer = SFTTrainer(
    model="meta-llama/Llama-2-7b-hf",
    model_package_group="my-fine-tuned-models"
)

# Create training job (non-blocking)
training_job = trainer.train(
    training_dataset="s3://bucket/train.jsonl",
    wait=False
)

# Wait for completion
training_job.wait()

# Refresh job status
training_job.refresh()

# Get the ARN of the model package that holds the fine-tuned model
model_package_arn = training_job.output_model_package_arn
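The parameter descriptions below state a few rules about how model, training_type, model_package_group, and base_model_name combine. The following is a minimal, hypothetical sketch of those rules as a plain Python check. check_sft_args and FakeModelPackage are illustrative names, not part of the SDK; TrainingType members are represented by their names as strings (an assumption); and the SDK's own validation may differ in order, wording, and scope.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class FakeModelPackage:
    """Hypothetical stand-in for a ModelPackage object (not the SDK class)."""

    arn: str


# Assumption: TrainingType.LORA / TrainingType.FULL represented by their names.
VALID_TRAINING_TYPES = {"LORA", "FULL"}


def check_sft_args(
    model: Union[str, FakeModelPackage],
    training_type: str = "LORA",
    model_package_group: Optional[str] = None,
    base_model_name: Optional[str] = None,
) -> None:
    """Raise ValueError when the arguments break a documented SFTTrainer rule."""
    if training_type not in VALID_TRAINING_TYPES:
        raise ValueError(
            f"training_type must be one of {sorted(VALID_TRAINING_TYPES)}, "
            f"got {training_type!r}"
        )
    # Documented rule: model_package_group is required unless model is a ModelPackage.
    if not isinstance(model, FakeModelPackage) and model_package_group is None:
        raise ValueError("model_package_group is required when model is not a ModelPackage")
    # Documented rule: an s3:// checkpoint path also needs base_model_name.
    if isinstance(model, str) and model.startswith("s3://") and base_model_name is None:
        raise ValueError("base_model_name is required when model is an s3:// path")


# Usage: a model name string plus a group passes; a ModelPackage needs no group.
check_sft_args("meta-llama/Llama-2-7b-hf", model_package_group="my-model-group")
check_sft_args(FakeModelPackage(arn="example-model-package-arn"))
```

For instance, check_sft_args("s3://bucket/checkpoint/", model_package_group="my-model-group") raises ValueError in this sketch because base_model_name is missing.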

Parameters:
  • model (Union[str, ModelPackage]) – The foundation model to fine-tune. Can be a model name string, a model package ARN, an S3 checkpoint path (which also requires base_model_name), or a ModelPackage object.

  • training_type (Union[TrainingType, str]) – The fine-tuning approach. Valid values are TrainingType.LORA (default) and TrainingType.FULL.

  • model_package_group (Optional[Union[str, ModelPackageGroup]]) – The model package group for storing the fine-tuned model. Can be a group name, ARN, or ModelPackageGroup object. Required when model is not a ModelPackage.

  • mlflow_resource_arn (Optional[str]) – The MLflow tracking server ARN for experiment tracking. If not specified, the default MLflow experience is used.

  • mlflow_experiment_name (Optional[str]) – The MLflow experiment name for organizing runs.

  • mlflow_run_name (Optional[str]) – The MLflow run name for this training job.

  • training_dataset (Optional[Union[str, DataSet]]) – The training dataset. Can be an S3 URI, a dataset ARN, or a DataSet object.

  • validation_dataset (Optional[Union[str, DataSet]]) – The validation dataset. Can be an S3 URI, a dataset ARN, or a DataSet object.

  • s3_output_path (Optional[str]) – The S3 path for training job outputs. If not specified, defaults to s3://sagemaker-<region>-<account>/output.

  • kms_key_id (Optional[str]) – The KMS key ID for encrypting training job outputs.

  • networking (Optional[VpcConfig]) – The VPC configuration for the training job.

  • stopping_condition (Optional[StoppingCondition]) – The stopping condition that overrides the training runtime limit. If not specified, the SageMaker service default is used (24 hours for serverless training).

  • recipe (Optional[str]) – Path to a user recipe YAML file (local path or S3 URI). When provided, enables three-level recipe resolution, listed from lowest to highest precedence: Hub defaults < recipe file < overrides dict. The recipe file can contain any training parameters in nested YAML format.

  • overrides (Optional[dict]) – Programmatic overrides dict with nested structure matching the recipe layout (e.g., {"training_config": {"learning_rate": 2e-5}}). Takes highest precedence. When provided, resolved recipe values override matching hyperparameters at train() time. Use get_resolved_recipe() to inspect the final merged config.

  • is_multimodal (Optional[bool]) – Whether the training dataset contains multimodal data. If None (default), auto-detected from the training dataset at train time.

  • base_model_name (Optional[str]) – Base model name for recipe lookup when model is an S3 checkpoint path. Required when model starts with s3:// so the SDK knows which recipe, container image, and validation spec to use. Example: "amazon.nova-2-lite-v1".

  • disable_output_compression (Optional[bool]) – Whether to disable compression of model output artifacts. When True, model artifacts are stored uncompressed in S3 (compression_type="NONE"). Recommended for large model outputs. Defaults to False (gzip compression).
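The recipe and overrides parameters describe layered resolution in which a higher layer wins key by key. One way to picture this, assuming each layer has already been parsed into a nested dict and that nested keys are merged recursively, is the sketch below. deep_merge and resolve_recipe are hypothetical helpers, the epochs key is an invented example, and get_resolved_recipe() remains the way to inspect the merged config the SDK actually produces.

```python
import copy
from typing import Any, Dict, Optional


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict in which keys from `override` win; nested dicts merge recursively."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def resolve_recipe(
    hub_defaults: Dict[str, Any],
    recipe_file: Optional[Dict[str, Any]] = None,
    overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Apply the documented precedence: Hub defaults < recipe file < overrides."""
    resolved = deep_merge(hub_defaults, recipe_file or {})
    return deep_merge(resolved, overrides or {})


# Usage with hypothetical layer contents.
hub_defaults = {"training_config": {"learning_rate": 1e-4, "epochs": 3}}
recipe_file = {"training_config": {"epochs": 5}}
overrides = {"training_config": {"learning_rate": 2e-5}}
print(resolve_recipe(hub_defaults, recipe_file, overrides))
# → {'training_config': {'learning_rate': 2e-05, 'epochs': 5}}
```

Here learning_rate comes from the overrides dict and epochs from the recipe file, while any key neither layer mentions would keep its Hub default.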

train(training_dataset: str | DataSet | None = None, validation_dataset: str | DataSet | None = None, wait: bool = True, wait_timeout: int | None = None, poll: int = 5)

Execute the SFT training job.

Parameters:
  • training_dataset (Optional[Union[str, DataSet]]) – The training dataset for this job. Overrides the dataset specified in __init__. Can be an S3 URI, dataset ARN, or DataSet object.

  • validation_dataset (Optional[Union[str, DataSet]]) – The validation dataset for this job. Overrides the dataset specified in __init__. Can be an S3 URI, dataset ARN, or DataSet object.

  • wait (bool) – Whether to wait for the training job to complete. Defaults to True.

  • wait_timeout (Optional[int]) – Maximum time in seconds to wait for the training job to complete. Only used when wait=True. If None, uses the default timeout from the wait utility.

  • poll (int) – Polling interval in seconds for checking training job status. Defaults to 5.
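With wait=True, train() checks the job status every poll seconds, optionally bounded by wait_timeout. A minimal sketch of such a polling loop follows, assuming the status is one of the SageMaker TrainingJobStatus strings (InProgress, Completed, Failed, Stopping, Stopped). wait_for_status and its get_status callable are hypothetical. Note one deliberate difference: the documented behavior falls back to the wait utility's default timeout when wait_timeout is None, whereas this sketch treats None as no limit.

```python
import time
from typing import Callable, Optional

# TrainingJobStatus values after which the status no longer changes.
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_status(
    get_status: Callable[[], str],
    poll: float = 5,
    wait_timeout: Optional[float] = None,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> str:
    """Poll get_status() every `poll` seconds until a terminal status is seen.

    Raises TimeoutError once `wait_timeout` seconds have elapsed without a
    terminal status. In this sketch, wait_timeout=None means no limit.
    """
    start = clock()
    while True:
        status = get_status()
        if status in TERMINAL_STATUSES:
            return status
        if wait_timeout is not None and clock() - start >= wait_timeout:
            raise TimeoutError(f"job still {status!r} after {wait_timeout} s")
        sleep(poll)


# Usage with a hypothetical status sequence and no real delay.
statuses = iter(["InProgress", "InProgress", "Completed"])
print(wait_for_status(lambda: next(statuses), poll=0))  # → Completed
```

The sleep and clock parameters are injectable only so the loop can be exercised without real waiting; a production wait would simply use the defaults.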

Returns:

The SageMaker training job object.

Return type:

TrainingJob