sagemaker.train.cpt_trainer
CPTTrainer — Continued Pre-Training on foundation models using SageMaker HyperPod.
Classes
| CPTTrainer | Performs Continued Pre-Training (CPT) on foundation models using SageMaker HyperPod. |
- class sagemaker.train.cpt_trainer.CPTTrainer(model: str | ModelPackage, model_package_group: str | ModelPackageGroup | None = None, compute: HyperPodCompute | None = None, mlflow_resource_arn: str | None = None, mlflow_experiment_name: str | None = None, mlflow_run_name: str | None = None, training_dataset: str | DataSet | None = None, validation_dataset: str | DataSet | None = None, s3_output_path: str | None = None, kms_key_id: str | None = None, networking: VpcConfig | None = None, accept_eula: bool | None = False, stopping_condition: StoppingCondition | None = None, recipe: str | None = None, overrides: dict | None = None, training_image: str | None = None, data_mixing_config: DataMixingConfig | None = None, base_model_name: str | None = None, disable_output_compression: bool | None = False, **kwargs)
Bases: BaseTrainer
Performs Continued Pre-Training (CPT) on foundation models using SageMaker HyperPod.
CPT extends a foundation model’s knowledge by further pre-training on domain-specific unlabeled text data. This is useful for adapting models to specialized domains (legal, medical, finance, etc.) before applying task-specific fine-tuning.
CPT is only supported on HyperPod compute.
Example:
    from sagemaker.train import CPTTrainer
    from sagemaker.core.training.configs import HyperPodCompute

    trainer = CPTTrainer(
        model="amazon.nova-lite-v2",
        model_package_group="my-cpt-models",
        training_dataset="s3://bucket/domain_corpus.jsonl",
        s3_output_path="s3://bucket/output/",
        compute=HyperPodCompute(
            cluster_name="my-cluster",
            instance_type="ml.p5.48xlarge",
            node_count=4,
        ),
        recipe="training/nova/nova_2_0/nova_lite/CPT/nova_lite_2_0_p5x8_gpu_pretrain",
        overrides={"recipes.training_config.trainer.max_steps": 100},
    )

    training_job = trainer.train(wait=False)
- Parameters:
model (Union[str, ModelPackage]) – The foundation model to continue pre-training.
model_package_group (Optional[Union[str, ModelPackageGroup]]) – The model package group for storing the trained model.
compute (Optional[HyperPodCompute]) – HyperPod compute configuration. Required in practice despite the Optional type: CPT only runs on HyperPod, and train() raises ValueError if compute is not configured.
mlflow_resource_arn (Optional[str]) – The MLflow tracking server ARN for experiment tracking.
mlflow_experiment_name (Optional[str]) – The MLflow experiment name.
mlflow_run_name (Optional[str]) – The MLflow run name.
training_dataset (Optional[Union[str, DataSet]]) – S3 URI or DataSet object pointing to unlabeled text data.
validation_dataset (Optional[Union[str, DataSet]]) – Validation dataset for computing validation loss during training.
s3_output_path (Optional[str]) – S3 path for training job outputs.
kms_key_id (Optional[str]) – KMS key ID for encrypting outputs.
networking (Optional[VpcConfig]) – VPC configuration for the training job.
stopping_condition (Optional[StoppingCondition]) – Stopping condition that overrides the default training runtime limit.
recipe (Optional[str]) – Path to a user recipe YAML file or HyperPod recipe name. If not provided, the recipe is auto-resolved from SageMaker Hub based on model and training type.
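overrides (Optional[dict]) – Dictionary of programmatic overrides for recipe values, keyed by dotted path as in the example above (for instance, {"recipes.training_config.trainer.max_steps": 100}).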
training_image (Optional[str]) – Custom training container image URI. If not provided, auto-resolved from Hub.
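The overrides key in the example above ("recipes.training_config.trainer.max_steps") reads like a dotted path into a nested recipe. This reference does not describe how the trainer applies such keys, so the following is only a minimal sketch of one plausible interpretation. The helper apply_overrides and the recipe fragment are hypothetical and are not part of the SageMaker SDK.

```python
from copy import deepcopy
from typing import Any


def apply_overrides(config: dict, overrides: dict[str, Any]) -> dict:
    """Return a copy of `config` with each dotted-path override applied.

    Hypothetical helper for illustration; not a SageMaker SDK function.
    """
    result = deepcopy(config)  # leave the caller's recipe untouched
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            # Create intermediate mappings when the path does not exist yet.
            node = node.setdefault(part, {})
        node[leaf] = value
    return result


# Assumed recipe fragment; the real recipe schema is not shown in this reference.
recipe = {"recipes": {"training_config": {"trainer": {"max_steps": 1000}}}}
merged = apply_overrides(recipe, {"recipes.training_config.trainer.max_steps": 100})
print(merged["recipes"]["training_config"]["trainer"]["max_steps"])
```

Working on a deep copy keeps the base recipe reusable across several runs that differ only in their overrides.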
- train(training_dataset: str | DataSet | None = None, validation_dataset: str | DataSet | None = None, wait: bool = True, wait_timeout: int | None = None, poll: int = 5)
Execute the CPT training job on HyperPod.
- Parameters:
training_dataset (Optional[Union[str, DataSet]]) – Training dataset. Overrides the dataset specified in __init__.
validation_dataset (Optional[Union[str, DataSet]]) – Validation dataset. Overrides the dataset specified in __init__.
wait (bool) – Whether to wait for the job to complete. Defaults to True.
wait_timeout (Optional[int]) – Maximum time in seconds to wait.
poll (int) – Polling interval in seconds. Defaults to 5.
- Returns:
The HyperPod job name.
- Return type:
str
- Raises:
ValueError – If compute is not configured or recipe is missing.
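The wait, wait_timeout, and poll parameters describe a conventional polling pattern: check the job status, sleep for poll seconds, and give up after wait_timeout seconds. The sketch below illustrates that pattern in isolation and is not the trainer's actual implementation. The wait_for_job helper, the get_status callback, and the status names in TERMINAL_STATES are all assumptions made for this example.

```python
import time
from typing import Callable, Optional

# Assumed terminal status names; the real HyperPod job states may differ.
TERMINAL_STATES = {"Completed", "Failed", "Stopped"}


def wait_for_job(
    get_status: Callable[[], str],
    poll: int = 5,
    wait_timeout: Optional[int] = None,
) -> str:
    """Poll `get_status` every `poll` seconds until a terminal state is seen.

    Raises TimeoutError if `wait_timeout` seconds pass first.
    Hypothetical helper for illustration; not a SageMaker SDK function.
    """
    deadline = None if wait_timeout is None else time.monotonic() + wait_timeout
    while True:
        status = get_status()
        if status in TERMINAL_STATES:
            return status
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"job still {status!r} after {wait_timeout} seconds")
        time.sleep(poll)


# Simulated status source standing in for a real job-status lookup.
statuses = iter(["InProgress", "InProgress", "Completed"])
final = wait_for_job(lambda: next(statuses), poll=0)
print(final)
```

Calling train(wait=False), as in the class example, skips this kind of blocking loop and returns immediately, so the caller can track the job separately using the returned job name.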