# Source code for sagemaker.train.cpt_trainer

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""CPTTrainer — Continued Pre-Training on foundation models using SageMaker HyperPod."""

from typing import Optional, Union
import logging

from sagemaker.train.base_trainer import BaseTrainer
from sagemaker.train.common import TrainingType, CustomizationTechnique
from sagemaker.core.resources import ModelPackageGroup, ModelPackage
from sagemaker.core.shapes import VpcConfig
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.train.configs import StoppingCondition
from sagemaker.train.data_mixing_config import DataMixingConfig
from sagemaker.core.training.configs import HyperPodCompute
from sagemaker.train.common_utils.finetune_utils import (
    _validate_and_resolve_model_package_group,
    _resolve_model_and_name,
    _resolve_model_with_checkpoint,
    _validate_eula_for_gated_model,
)
from sagemaker.train.common_utils.data_mixing_utils import (
    validate_data_mixing_model,
    validate_data_mixing_categories,
    resolve_hyperpod_datamix_context,
    build_hyperpod_datamix_recipe_from_context,
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature

# Module-level logger named after this module; its level is pinned to INFO
# as a side effect of importing the module.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class CPTTrainer(BaseTrainer):
    """Performs Continued Pre-Training (CPT) on foundation models using SageMaker HyperPod.

    CPT extends a foundation model's knowledge by further pre-training on
    domain-specific unlabeled text data. This is useful for adapting models to
    specialized domains (legal, medical, finance, etc.) before applying
    task-specific fine-tuning.

    CPT is only supported on HyperPod compute.

    Example:

    .. code:: python

        from sagemaker.train import CPTTrainer
        from sagemaker.core.training.configs import HyperPodCompute

        trainer = CPTTrainer(
            model="amazon.nova-lite-v2",
            model_package_group="my-cpt-models",
            training_dataset="s3://bucket/domain_corpus.jsonl",
            s3_output_path="s3://bucket/output/",
            compute=HyperPodCompute(
                cluster_name="my-cluster",
                instance_type="ml.p5.48xlarge",
                node_count=4,
            ),
            recipe="training/nova/nova_2_0/nova_lite/CPT/nova_lite_2_0_p5x8_gpu_pretrain",
            overrides={"recipes.training_config.trainer.max_steps": 100},
        )
        training_job = trainer.train(wait=False)

    Parameters:
        model (Union[str, ModelPackage]):
            The foundation model to continue pre-training.
        model_package_group (Optional[Union[str, ModelPackageGroup]]):
            The model package group for storing the trained model. It is
            validated/resolved at construction time only when ``compute`` is
            ``None``; otherwise it is stored as given.
        compute (Optional[HyperPodCompute]):
            HyperPod compute configuration. Required — CPT only runs on HyperPod.
            ``None`` is accepted by the constructor, but ``train()`` raises
            unless a ``HyperPodCompute`` instance is configured.
        mlflow_resource_arn (Optional[str]):
            The MLflow tracking server ARN for experiment tracking.
        mlflow_experiment_name (Optional[str]):
            The MLflow experiment name.
        mlflow_run_name (Optional[str]):
            The MLflow run name.
        training_dataset (Optional[Union[str, DataSet]]):
            S3 URI or DataSet object pointing to unlabeled text data.
        validation_dataset (Optional[Union[str, DataSet]]):
            Validation dataset for computing validation loss during training.
        s3_output_path (Optional[str]):
            S3 path for training job outputs.
        kms_key_id (Optional[str]):
            KMS key ID for encrypting outputs.
        networking (Optional[VpcConfig]):
            VPC configuration for the training job.
        accept_eula (Optional[bool]):
            EULA acceptance flag, checked through
            ``_validate_eula_for_gated_model``; the value that helper returns
            is stored as ``self.accept_eula``. Defaults to False.
        stopping_condition (Optional[StoppingCondition]):
            Stopping condition to override training runtime limit.
        recipe (Optional[str]):
            Path to a user recipe YAML file or HyperPod recipe name. If not
            provided, the recipe is auto-resolved from SageMaker Hub based on
            model and training type.
        overrides (Optional[dict]):
            Programmatic overrides dict.
        training_image (Optional[str]):
            Custom training container image URI. If not provided, auto-resolved
            from Hub.
        data_mixing_config (Optional[DataMixingConfig]):
            Data mixing configuration. When set, ``train()`` builds a HyperPod
            data-mixing recipe and uses it in place of ``recipe``.
        base_model_name (Optional[str]):
            Forwarded to ``BaseTrainer``; the resulting ``self.base_model_name``
            is passed to model/checkpoint resolution.
        disable_output_compression (Optional[bool]):
            Forwarded to ``BaseTrainer`` and stored on the instance.
            Defaults to False.
        **kwargs:
            Additional keyword arguments forwarded to ``BaseTrainer.__init__``.

    Raises:
        TypeError: If ``compute`` is given but is not a ``HyperPodCompute``.
    """

    _customization_technique = CustomizationTechnique.CPT.value

    def __init__(
        self,
        model: Union[str, ModelPackage],
        model_package_group: Optional[Union[str, ModelPackageGroup]] = None,
        compute: Optional[HyperPodCompute] = None,
        mlflow_resource_arn: Optional[str] = None,
        mlflow_experiment_name: Optional[str] = None,
        mlflow_run_name: Optional[str] = None,
        training_dataset: Optional[Union[str, DataSet]] = None,
        validation_dataset: Optional[Union[str, DataSet]] = None,
        s3_output_path: Optional[str] = None,
        kms_key_id: Optional[str] = None,
        networking: Optional[VpcConfig] = None,
        accept_eula: Optional[bool] = False,
        stopping_condition: Optional[StoppingCondition] = None,
        recipe: Optional[str] = None,
        overrides: Optional[dict] = None,
        training_image: Optional[str] = None,
        data_mixing_config: Optional[DataMixingConfig] = None,
        base_model_name: Optional[str] = None,
        disable_output_compression: Optional[bool] = False,
        **kwargs,
    ):
        super().__init__(
            training_image=training_image,
            base_model_name=base_model_name,
            disable_output_compression=disable_output_compression,
            **kwargs,
        )

        # Reject an unsupported compute type before it is handed to model
        # resolution below, so a misconfiguration fails fast with a clear error.
        if compute is not None and not isinstance(compute, HyperPodCompute):
            raise TypeError(
                f"CPT only supports HyperPod compute. Got {type(compute).__name__}. "
                f"Pass a HyperPodCompute instance with cluster_name and recipe."
            )

        self.model, self._model_name, self.model_source = _resolve_model_with_checkpoint(
            model,
            self.base_model_name,
            compute,
            self.sagemaker_session,
            resolve_fn=_resolve_model_and_name,
        )
        self.training_type = TrainingType.FULL

        # The model package group is only validated/resolved when no compute is
        # supplied; with HyperPod compute it is stored exactly as passed in.
        if compute is None:
            self.model_package_group = _validate_and_resolve_model_package_group(
                model, model_package_group
            )
        else:
            self.model_package_group = model_package_group

        self.compute = compute
        self.data_mixing_config = data_mixing_config
        self.mlflow_resource_arn = mlflow_resource_arn
        self.mlflow_experiment_name = mlflow_experiment_name
        self.mlflow_run_name = mlflow_run_name
        self.training_dataset = training_dataset
        self.validation_dataset = validation_dataset
        self.s3_output_path = s3_output_path
        self.kms_key_id = kms_key_id
        self.networking = networking
        self.stopping_condition = stopping_condition
        self._recipe_path = recipe
        self._overrides = overrides
        self._recipe_resolver = None
        self._resolved_recipe_cache = None
        self.disable_output_compression = disable_output_compression

        # CPT is HyperPod-only and the recipe is auto-resolved from Hub if not
        # provided by the user. No Hub lookup for hyperparameters is needed —
        # they are managed entirely by the HyperPod recipe and overrides.
        self.hyperparameters = None
        self._model_arn = None

        self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, False)

    @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="CPTTrainer.train")
    def train(
        self,
        training_dataset: Optional[Union[str, DataSet]] = None,
        validation_dataset: Optional[Union[str, DataSet]] = None,
        wait: bool = True,
        wait_timeout: Optional[int] = None,
        poll: int = 5,
    ):
        """Execute the CPT training job on HyperPod.

        When ``data_mixing_config`` is set on the trainer, a HyperPod
        data-mixing recipe is built first. That recipe replaces any recipe
        supplied to ``__init__``, and the image URI returned with it is used
        only if no ``training_image`` is already set.

        Parameters:
            training_dataset (Optional[Union[str, DataSet]]):
                Training dataset. Overrides the dataset specified in __init__.
            validation_dataset (Optional[Union[str, DataSet]]):
                Validation dataset. Overrides the dataset specified in __init__.
            wait (bool):
                Whether to wait for the job to complete. Defaults to True.
            wait_timeout (Optional[int]):
                Maximum time in seconds to wait.
            poll (int):
                Polling interval in seconds. Defaults to 5.

        Returns:
            str: The HyperPod job name (the value returned by
            ``_train_hyperpod``).

        Raises:
            ValueError: If ``self.compute`` is not a ``HyperPodCompute``
                instance. Errors raised by the data-mixing validators or by
                ``_train_hyperpod`` propagate unchanged.
        """
        if not isinstance(self.compute, HyperPodCompute):
            raise ValueError(
                "CPT requires HyperPod compute. Pass compute=HyperPodCompute(...) "
                "when creating the CPTTrainer."
            )

        # The guard above already guarantees HyperPod compute, so the
        # data-mixing branch needs no second compute-type check.
        if self.data_mixing_config is not None:
            validate_data_mixing_model(self._model_name)
            is_multimodal = getattr(self, "is_multimodal", False) or False

            # Imported inside the method, as in the original implementation.
            from sagemaker.train.defaults import TrainDefaults

            sagemaker_session = TrainDefaults.get_sagemaker_session(
                sagemaker_session=self.sagemaker_session
            )
            context = resolve_hyperpod_datamix_context(
                model_name=self._model_name,
                is_multimodal=is_multimodal,
                sagemaker_session=sagemaker_session,
                training_type="FULL",
                customization_technique="CPT",
            )
            validated_config = validate_data_mixing_categories(
                self.data_mixing_config, context.categories
            )
            recipe_path, hp_image_uri = build_hyperpod_datamix_recipe_from_context(
                context, validated_config
            )

            # The data-mixing recipe takes the place of any user-provided recipe.
            self._recipe_path = recipe_path
            # An explicitly configured training image is never overwritten.
            if hp_image_uri and not self.training_image:
                self.training_image = hp_image_uri

        return self._train_hyperpod(
            training_dataset=training_dataset,
            validation_dataset=validation_dataset,
            wait=wait,
            wait_timeout=wait_timeout,
            poll=poll,
        )