sagemaker.train.cpt_trainer

CPTTrainer: Continued Pre-Training (CPT) on foundation models using SageMaker HyperPod.

Classes

CPTTrainer(model[, model_package_group, ...])

Performs Continued Pre-Training (CPT) on foundation models using SageMaker HyperPod.

class sagemaker.train.cpt_trainer.CPTTrainer(model: str | ModelPackage, model_package_group: str | ModelPackageGroup | None = None, compute: HyperPodCompute | None = None, mlflow_resource_arn: str | None = None, mlflow_experiment_name: str | None = None, mlflow_run_name: str | None = None, training_dataset: str | DataSet | None = None, validation_dataset: str | DataSet | None = None, s3_output_path: str | None = None, kms_key_id: str | None = None, networking: VpcConfig | None = None, accept_eula: bool | None = False, stopping_condition: StoppingCondition | None = None, recipe: str | None = None, overrides: dict | None = None, training_image: str | None = None, data_mixing_config: DataMixingConfig | None = None, base_model_name: str | None = None, disable_output_compression: bool | None = False, **kwargs)

Bases: BaseTrainer

Performs Continued Pre-Training (CPT) on foundation models using SageMaker HyperPod.

CPT extends a foundation model’s knowledge by further pre-training on domain-specific unlabeled text data. This is useful for adapting models to specialized domains (legal, medical, finance, etc.) before applying task-specific fine-tuning.

CPT is only supported on HyperPod compute.
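Because CPT consumes raw, unlabeled text, preparing data mostly means collecting domain documents into a file that can be uploaded to S3. The following is a minimal sketch that writes one JSON object per line, in line with the .jsonl file used in the example below. The {"text": ...} record schema and the helper name write_cpt_corpus are assumptions for illustration only; check the data requirements of the recipe you use before relying on this layout.

```python
import json
from pathlib import Path
from typing import Iterable


def write_cpt_corpus(documents: Iterable[str], path: Path, min_chars: int = 1) -> int:
    """Write raw documents as JSON Lines and return the number of records written.

    Hypothetical helper. ASSUMPTION: each record is {"text": <document>}; the
    field name expected by a given CPT recipe may differ.
    """
    count = 0
    with path.open("w", encoding="utf-8") as f:
        for doc in documents:
            text = doc.strip()
            if len(text) < min_chars:
                continue  # skip empty or too-short documents
            # ensure_ascii=False keeps non-ASCII domain text readable in the file
            f.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
            count += 1
    return count
```

The resulting file would then be uploaded to S3 (for example with the AWS CLI or boto3), and its S3 URI passed as training_dataset.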

Example:

from sagemaker.train import CPTTrainer
from sagemaker.core.training.configs import HyperPodCompute

trainer = CPTTrainer(
    model="amazon.nova-lite-v2",
    model_package_group="my-cpt-models",
    training_dataset="s3://bucket/domain_corpus.jsonl",
    s3_output_path="s3://bucket/output/",
    compute=HyperPodCompute(
        cluster_name="my-cluster",
        instance_type="ml.p5.48xlarge",
        node_count=4,
    ),
    recipe="training/nova/nova_2_0/nova_lite/CPT/nova_lite_2_0_p5x8_gpu_pretrain",
    overrides={"recipes.training_config.trainer.max_steps": 100},
)

job_name = trainer.train(wait=False)  # train() returns the HyperPod job name (str)

Parameters:
  • model (Union[str, ModelPackage]) – The foundation model to continue pre-training.

  • model_package_group (Optional[Union[str, ModelPackageGroup]]) – The model package group for storing the trained model.

  • compute (Optional[HyperPodCompute]) – HyperPod compute configuration. Although typed as optional, it is required: CPT runs only on HyperPod, and train() raises ValueError if compute is not configured.

  • mlflow_resource_arn (Optional[str]) – The MLflow tracking server ARN for experiment tracking.

  • mlflow_experiment_name (Optional[str]) – The MLflow experiment name.

  • mlflow_run_name (Optional[str]) – The MLflow run name.

  • training_dataset (Optional[Union[str, DataSet]]) – S3 URI or DataSet object pointing to unlabeled text data.

  • validation_dataset (Optional[Union[str, DataSet]]) – Validation dataset for computing validation loss during training.

  • s3_output_path (Optional[str]) – S3 path for training job outputs.

  • kms_key_id (Optional[str]) – KMS key ID for encrypting outputs.

  • networking (Optional[VpcConfig]) – VPC configuration for the training job.

  • stopping_condition (Optional[StoppingCondition]) – Stopping condition that overrides the default limit on training runtime.

  • recipe (Optional[str]) – Path to a user recipe YAML file or HyperPod recipe name. If not provided, the recipe is auto-resolved from SageMaker Hub based on model and training type.

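  • overrides (Optional[dict]) – Recipe parameter overrides applied programmatically, keyed by dotted path (for example, {"recipes.training_config.trainer.max_steps": 100}).

  • training_image (Optional[str]) – Custom training container image URI. If not provided, the image is auto-resolved from SageMaker Hub.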
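The overrides keys in the example are dotted paths such as "recipes.training_config.trainer.max_steps". As a rough mental model only, the sketch below applies such keys to a nested dict, assuming each dot-separated segment names one level of the recipe's nested configuration. The helper apply_overrides is hypothetical and is not the SDK's implementation; how CPTTrainer actually merges overrides into a recipe is not documented here.

```python
import copy
from typing import Any, Dict


def apply_overrides(recipe: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a nested recipe dict with dotted-path overrides applied.

    Hypothetical helper. ASSUMPTION: each dot-separated segment of a key is one
    level of nesting in the recipe.
    """
    result = copy.deepcopy(recipe)  # leave the caller's recipe untouched
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            # create missing intermediate levels as empty dicts
            node = node.setdefault(part, {})
            if not isinstance(node, dict):
                raise ValueError(f"{dotted_key!r}: {part!r} is not a mapping")
        node[leaf] = value
    return result
```

Copying the recipe first keeps a shared base recipe reusable across several jobs with different overrides.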

train(training_dataset: str | DataSet | None = None, validation_dataset: str | DataSet | None = None, wait: bool = True, wait_timeout: int | None = None, poll: int = 5)

Execute the CPT training job on HyperPod.

Parameters:
  • training_dataset (Optional[Union[str, DataSet]]) – Training dataset. Overrides the dataset specified in __init__.

  • validation_dataset (Optional[Union[str, DataSet]]) – Validation dataset. Overrides the dataset specified in __init__.

  • wait (bool) – Whether to wait for the job to complete. Defaults to True.

  • wait_timeout (Optional[int]) – Maximum time in seconds to wait.

  • poll (int) – Polling interval in seconds. Defaults to 5.
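The wait, wait_timeout, and poll parameters describe a conventional polling loop. The following generic sketch shows how those three settings typically interact; it is not the SDK's code. The get_status callable, the status names in TERMINAL_STATES, and the choice to raise TimeoutError are all assumptions made for illustration.

```python
import time
from typing import Callable, Optional

# ASSUMPTION: illustrative terminal status names, not taken from the SDK.
TERMINAL_STATES = {"Completed", "Failed", "Stopped"}


def wait_for_job(
    get_status: Callable[[], str],
    poll: int = 5,
    wait_timeout: Optional[int] = None,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> str:
    """Check get_status() every `poll` seconds until a terminal state is reached.

    Raises TimeoutError once `wait_timeout` seconds have elapsed; with
    wait_timeout=None the loop waits indefinitely.
    """
    deadline = None if wait_timeout is None else clock() + wait_timeout
    while True:
        status = get_status()
        if status in TERMINAL_STATES:
            return status
        if deadline is not None and clock() >= deadline:
            raise TimeoutError(f"job still {status!r} after {wait_timeout}s")
        sleep(poll)
```

Injecting sleep and clock keeps the loop testable without real delays. Because the deadline is checked only between sleeps, the actual wait can exceed wait_timeout by up to one poll interval.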

Returns:

The HyperPod job name.

Return type:

str

Raises:

ValueError – If compute is not configured or recipe is missing.
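The documented behavior of train() amounts to two steps: datasets passed to train() take precedence over those given to __init__, and a ValueError is raised when compute or a recipe is missing. The following library-free sketch mirrors those rules so they can be checked before submitting a job. CptJobPlan, resolve_and_check, and the hub_recipe argument are hypothetical. In particular, treating "recipe is missing" as "no recipe was passed and none was resolved from SageMaker Hub" is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CptJobPlan:
    """Hypothetical stand-in for the settings CPTTrainer.train() relies on."""
    compute: Optional[object]
    recipe: Optional[str]
    training_dataset: Optional[str]
    validation_dataset: Optional[str] = None


def resolve_and_check(
    plan: CptJobPlan,
    training_dataset: Optional[str] = None,
    validation_dataset: Optional[str] = None,
    hub_recipe: Optional[str] = None,
) -> CptJobPlan:
    """Apply train()-time dataset overrides and check the documented preconditions."""
    if plan.compute is None:
        raise ValueError("compute is required; CPT only runs on HyperPod")
    # ASSUMPTION: an explicit recipe wins; otherwise fall back to a resolved one
    recipe = plan.recipe or hub_recipe
    if recipe is None:
        raise ValueError("no recipe was provided and none could be resolved")
    return CptJobPlan(
        compute=plan.compute,
        recipe=recipe,
        # datasets passed to train() override those given to __init__
        training_dataset=training_dataset or plan.training_dataset,
        validation_dataset=validation_dataset or plan.validation_dataset,
    )
```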