Source code for sagemaker.train.rlaif_trainer

from typing import Optional, Union
import logging
from sagemaker.train.base_trainer import BaseTrainer
from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE
from sagemaker.core.resources import TrainingJob, ModelPackageGroup, MlflowTrackingServer, ModelPackage
from sagemaker.core.shapes import VpcConfig
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name, _get_studio_tags
from sagemaker.train.common_utils.recipe_utils import _get_hub_content_metadata
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.evaluator import Evaluator
from sagemaker.train.configs import StoppingCondition
from sagemaker.train.common_utils.finetune_utils import (
    _get_beta_session,
    _get_fine_tuning_options_and_model_arn,
    _validate_and_resolve_model_package_group,
    _extract_evaluator_arn,
    _resolve_model_and_name,
    _create_input_data_config,
    _convert_input_data_to_channels,
    _create_output_config,
    _create_serverless_config,
    _create_mlflow_config,
    _create_model_package_config,
    _validate_eula_for_gated_model,
    _validate_hyperparameter_values
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train.constants import get_sagemaker_hub_name, _ALLOWED_REWARD_MODEL_IDS

logger = logging.getLogger(__name__)


class RLAIFTrainer(BaseTrainer):
    """Class that performs Reinforcement Learning from AI Feedback (RLAIF) fine-tuning on foundation models using AWS SageMaker.

    Example:

    .. code:: python

        from sagemaker.train import RLAIFTrainer
        from sagemaker.train.common import TrainingType

        trainer = RLAIFTrainer(
            model="meta-llama/Llama-2-7b-hf",
            training_type=TrainingType.LORA,
            model_package_group="my-model-group",
            reward_model_id="reward-model-id",
            reward_prompt="Builtin.summarize",
            training_dataset="s3://bucket/rlaif_data.jsonl"
        )

        trainer.train()

        # Complete workflow: create -> wait -> get model package ARN
        trainer = RLAIFTrainer(
            model="meta-llama/Llama-2-7b-hf",
            model_package_group="my-rlaif-models",
            reward_model_id="reward-model-id",
            reward_prompt="Builtin.summarize"
        )

        # Create training job (non-blocking)
        training_job = trainer.train(
            training_dataset="s3://bucket/rlaif_data.jsonl",
            wait=False
        )

        # Wait for completion
        training_job.wait()

        # Refresh job status
        training_job.refresh()

        # Get the fine-tuned model package ARN
        model_package_arn = training_job.output_model_package_arn

    Parameters:
        model (Union[str, ModelPackage]):
            The foundation model to fine-tune. Can be a model name string,
            model package ARN, or ModelPackage object.
        training_type (Union[TrainingType, str]):
            The fine-tuning approach. Valid values are TrainingType.LORA (default),
            TrainingType.FULL.
        model_package_group (Optional[Union[str, ModelPackageGroup]]):
            The model package group for storing the fine-tuned model. Can be a group name,
            ARN, or ModelPackageGroup object. Required when model is not a ModelPackage.
        reward_model_id (Optional[str]):
            Bedrock model identifier for generating LLM feedback. Required for RLAIF
            training to provide reward signals. When given, it must be one of the allowed
            reward model ids and be available in the session's region, otherwise
            ``ValueError`` is raised.
        reward_prompt (Optional[Union[str, Evaluator]]):
            The reward prompt or evaluator for AI feedback generation. Accepted forms:

            * a string starting with ``Builtin`` (for example ``"Builtin.summarize"``),
              which selects one of the judge prompt templates offered by the model's
              hyperparameter specs;
            * a string starting with ``arn:aws:sagemaker:``, treated as an evaluator ARN;
            * any other string, looked up by name as ``JsonDoc`` hub content;
            * an Evaluator object.

            For Builtin metric prompts refer:
            https://docs.aws.amazon.com/bedrock/latest/userguide/model-evaluation-metrics.html
        mlflow_resource_arn (Optional[Union[str, MlflowTrackingServer]]):
            The MLflow tracking server ARN for experiment tracking.
            If not specified, uses default MLflow experience.
        mlflow_experiment_name (Optional[str]):
            The MLflow experiment name for organizing runs.
        mlflow_run_name (Optional[str]):
            The MLflow run name for this training job.
        training_dataset (Optional[Union[str, DataSet]]):
            The training dataset. Can be an S3 URI, dataset ARN, or DataSet object.
        validation_dataset (Optional[Union[str, DataSet]]):
            The validation dataset. Can be an S3 URI, dataset ARN, or DataSet object.
        s3_output_path (Optional[str]):
            The S3 path for training job outputs.
            If not specified, defaults to s3://sagemaker-<region>-<account>/output.
        kms_key_id (Optional[str]):
            The KMS key ID for encrypting training job outputs.
        networking (Optional[VpcConfig]):
            The VPC configuration for the training job.
        accept_eula (bool):
            Whether the end-user license agreement is accepted. The value is validated
            against whether the selected model is gated. Defaults to False.
        stopping_condition (Optional[StoppingCondition]):
            The stopping condition to override training runtime limit.
            If not specified, uses SageMaker service default (24 hours for serverless training).
        **kwargs:
            Forwarded unchanged to ``BaseTrainer.__init__``.
    """

    def __init__(
        self,
        model: Union[str, ModelPackage],
        training_type: Union[TrainingType, str] = TrainingType.LORA,
        model_package_group: Optional[Union[str, ModelPackageGroup]] = None,
        reward_model_id: Optional[str] = None,
        reward_prompt: Optional[Union[str, Evaluator]] = None,
        mlflow_resource_arn: Optional[Union[str, MlflowTrackingServer]] = None,
        mlflow_experiment_name: Optional[str] = None,
        mlflow_run_name: Optional[str] = None,
        training_dataset: Optional[Union[str, DataSet]] = None,
        validation_dataset: Optional[Union[str, DataSet]] = None,
        s3_output_path: Optional[str] = None,
        # Additional OutputDataConfig parameters
        kms_key_id: Optional[str] = None,
        # vpc config
        networking: Optional[VpcConfig] = None,
        accept_eula: bool = False,
        stopping_condition: Optional[StoppingCondition] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Resolve model and model name
        self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session)

        self.training_type = training_type
        self.model_package_group = _validate_and_resolve_model_package_group(
            model, model_package_group
        )
        self.reward_model_id = self._validate_reward_model_id(reward_model_id)
        self.reward_prompt = reward_prompt
        self.mlflow_resource_arn = mlflow_resource_arn
        self.mlflow_experiment_name = mlflow_experiment_name
        self.mlflow_run_name = mlflow_run_name
        self.training_dataset = training_dataset
        self.validation_dataset = validation_dataset
        self.s3_output_path = s3_output_path
        self.kms_key_id = kms_key_id
        self.networking = networking
        self.stopping_condition = stopping_condition

        # Filled in by _process_hyperparameters() when reward_prompt resolves to an
        # evaluator; initialised here so the attribute always exists.
        self._evaluator_arn = None

        # Initialize fine-tuning options with beta session fallback
        self.hyperparameters, self._model_arn, is_gated_model = (
            _get_fine_tuning_options_and_model_arn(
                self._model_name,
                CustomizationTechnique.RLAIF.value,
                self.training_type,
                self.sagemaker_session
                or TrainDefaults.get_sagemaker_session(
                    sagemaker_session=self.sagemaker_session
                ),
            )
        )

        # Validate and set EULA acceptance
        self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)

        # Process reward_prompt parameter
        self._process_hyperparameters()

    def _validate_reward_model_id(self, reward_model_id):
        """Validate reward_model_id is one of the allowed values.

        Parameters:
            reward_model_id (Optional[str]): The reward model identifier to check.

        Returns:
            Optional[str]: ``None`` when ``reward_model_id`` is falsy, otherwise the
            identifier unchanged.

        Raises:
            ValueError: If the identifier is not a key of ``_ALLOWED_REWARD_MODEL_IDS``
                or is not listed for the session's region.
        """
        if not reward_model_id:
            return None

        if reward_model_id not in _ALLOWED_REWARD_MODEL_IDS:
            raise ValueError(
                f"Invalid reward_model_id '{reward_model_id}'. "
                f"Available models are: {list(_ALLOWED_REWARD_MODEL_IDS.keys())}"
            )

        # Check region compatibility
        session = (
            getattr(self, "sagemaker_session", None) or TrainDefaults.get_sagemaker_session()
        )
        current_region = session.boto_region_name
        allowed_regions = _ALLOWED_REWARD_MODEL_IDS[reward_model_id]
        if current_region not in allowed_regions:
            raise ValueError(
                f"Reward model '{reward_model_id}' is not available in region '{current_region}'. "
                f"Available regions for this model: {allowed_regions}"
            )

        return reward_model_id

    @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")
    def train(
        self,
        training_dataset: Optional[Union[str, DataSet]] = None,
        validation_dataset: Optional[Union[str, DataSet]] = None,
        wait: bool = True,
        wait_timeout: Optional[int] = None,
        poll: int = 5,
    ):
        """Execute the RLAIF training job.

        Parameters:
            training_dataset (Optional[Union[str, DataSet]]):
                The training dataset for this job. Overrides the dataset specified in
                __init__. Can be an S3 URI, dataset ARN, or DataSet object.
            validation_dataset (Optional[Union[str, DataSet]]):
                The validation dataset for this job. Overrides the dataset specified in
                __init__. Can be an S3 URI, dataset ARN, or DataSet object.
            wait (bool):
                Whether to wait for the training job to complete. Defaults to True.
            wait_timeout (Optional[int]):
                Maximum time in seconds to wait for the training job to complete.
                Only used when wait=True. If None, uses the default timeout from
                the wait utility.
            poll (int):
                Polling interval in seconds for checking training job status.
                Defaults to 5.

        Returns:
            TrainingJob: The SageMaker training job object. It is also stored on the
            trainer as ``_latest_training_job``.

        Raises:
            Exception: Any error raised by ``TrainingJob.create`` is logged and re-raised.
        """
        sagemaker_session = TrainDefaults.get_sagemaker_session(
            sagemaker_session=self.sagemaker_session
        )
        role = TrainDefaults.get_role(role=self.role, sagemaker_session=sagemaker_session)

        current_training_job_name = _get_unique_name(
            self.base_job_name or f"{self._model_name}-rlaif"
        )
        logger.info("Training Job Name: %s", current_training_job_name)

        # data
        input_data_config = _create_input_data_config(
            training_dataset or self.training_dataset,
            validation_dataset or self.validation_dataset,
        )
        channels = _convert_input_data_to_channels(input_data_config)

        output_config = _create_output_config(
            s3_output_path=self.s3_output_path,
            sagemaker_session=sagemaker_session,
            kms_key_id=self.kms_key_id,
        )

        # getattr keeps this safe for instances that were built without __init__.
        evaluator_arn = getattr(self, "_evaluator_arn", None)
        serverless_config = _create_serverless_config(
            model_arn=self._model_arn,
            customization_technique=CustomizationTechnique.RLAIF.value,
            training_type=self.training_type,
            accept_eula=self.accept_eula,
            evaluator_arn=evaluator_arn,
            job_type=JOB_TYPE,
        )

        mlflow_config = _create_mlflow_config(
            sagemaker_session,
            mlflow_resource_arn=self.mlflow_resource_arn,
            mlflow_experiment_name=self.mlflow_experiment_name,
            mlflow_run_name=self.mlflow_run_name,
        )

        final_hyperparameters = self.hyperparameters.to_dict()
        _validate_hyperparameter_values(final_hyperparameters)

        model_package_config = _create_model_package_config(
            model_package_group_name=self.model_package_group,
            model=self.model,
            sagemaker_session=sagemaker_session,
        )

        # Any falsy networking value is normalised to None.
        vpc_config = self.networking or None

        tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())

        # Build TrainingJob.create() arguments
        create_args = {
            "training_job_name": current_training_job_name,
            "role_arn": role,
            "input_data_config": channels,
            "output_data_config": output_config,
            "serverless_job_config": serverless_config,
            "mlflow_config": mlflow_config,
            "hyper_parameters": final_hyperparameters,
            "model_package_config": model_package_config,
            "vpc_config": vpc_config,
            "session": sagemaker_session.boto_session,
            "region": sagemaker_session.boto_session.region_name,
            "tags": tags,
        }

        # Only pass stopping_condition if explicitly provided by user
        if self.stopping_condition is not None:
            create_args["stopping_condition"] = self.stopping_condition

        try:
            training_job = TrainingJob.create(**create_args)
        except Exception as e:
            logger.error("Error: %s", e)
            # Bare raise keeps the original traceback intact.
            raise

        if wait:
            from sagemaker.train.common_utils.trainer_wait import wait as _wait
            from sagemaker.core.utils.exceptions import TimeoutExceededError

            wait_kwargs = {"poll": poll}
            if wait_timeout is not None:
                wait_kwargs["timeout"] = wait_timeout
            try:
                _wait(training_job, **wait_kwargs)
            except TimeoutExceededError as e:
                # A wait timeout is logged but not re-raised, so the caller still
                # receives the job handle and can keep polling it.
                logger.error("Error: %s", e)

        self._latest_training_job = training_job
        return training_job

    def _remove_hyperparameter(self, name):
        """Drop hyperparameter ``name`` and its spec entry if the attribute is present."""
        if hasattr(self.hyperparameters, name):
            delattr(self.hyperparameters, name)
            self.hyperparameters._specs.pop(name, None)

    def _process_hyperparameters(self):
        """Update hyperparameters based on constructor inputs and process reward_prompt.

        Does nothing when there are no hyperparameters or they carry no ``_specs``.
        Otherwise it removes the path hyperparameters that the constructor inputs
        supersede, sets ``judge_model_id`` from ``reward_model_id``, and resolves
        ``reward_prompt`` into either a builtin judge prompt template or an
        evaluator ARN stored on ``self._evaluator_arn``.
        """
        # NOTE(review): this early return also skips reward_prompt handling, so an
        # evaluator is not resolved when no specs are available - confirm intended.
        if (
            not self.hyperparameters
            or not hasattr(self.hyperparameters, "_specs")
            or not self.hyperparameters._specs
        ):
            return

        # Remove keys that are handled by constructor inputs
        for handled_key in ("output_path", "data_path", "validation_data_path"):
            self._remove_hyperparameter(handled_key)

        # Update judge_model_id if reward_model_id is provided
        reward_model_id = getattr(self, "reward_model_id", None)
        if reward_model_id:
            self.hyperparameters.judge_model_id = f"bedrock/{reward_model_id}"

        # Process reward_prompt parameter
        reward_prompt = getattr(self, "reward_prompt", None)
        if not reward_prompt:
            return

        if not isinstance(reward_prompt, str):
            # Handle evaluator object
            self._remove_hyperparameter("judge_prompt_template")
            self._evaluator_arn = _extract_evaluator_arn(reward_prompt, "reward_prompt")
        elif reward_prompt.startswith("Builtin"):
            # Handle builtin reward prompts
            self._update_judge_prompt_template_direct(reward_prompt)
        else:
            # Handle evaluator ARN or hub content name
            self._process_non_builtin_reward_prompt()

    def _process_non_builtin_reward_prompt(self):
        """Process non-builtin reward prompt (ARN or hub content name).

        Stores the resolved evaluator ARN on ``self._evaluator_arn``.

        Raises:
            ValueError: If the hub content lookup for the prompt name fails; the
                underlying error is chained as the cause.
        """
        # Remove judge_prompt_template for non-builtin prompts
        self._remove_hyperparameter("judge_prompt_template")

        if self.reward_prompt.startswith("arn:aws:sagemaker:"):
            # Validate and assign ARN
            self._evaluator_arn = _extract_evaluator_arn(self.reward_prompt, "reward_prompt")
            return

        try:
            session = TrainDefaults.get_sagemaker_session(
                sagemaker_session=self.sagemaker_session
            )
            hub_content = _get_hub_content_metadata(
                hub_name=get_sagemaker_hub_name(),
                hub_content_type="JsonDoc",
                hub_content_name=self.reward_prompt,
                session=session.boto_session,
                region=session.boto_session.region_name,
            )
            # Store ARN for evaluator_arn
            self._evaluator_arn = hub_content.hub_content_arn
        except Exception as e:
            raise ValueError(
                f"Custom prompt '{self.reward_prompt}' not found in HubContent: {e}"
            ) from e

    @staticmethod
    def _builtin_template_label(template):
        """Return the template's file name without directories or the ``.jinja`` suffix."""
        return template.split("/")[-1].replace(".jinja", "")

    def _update_judge_prompt_template_direct(self, reward_prompt):
        """Update judge_prompt_template based on Builtin reward function.

        Parameters:
            reward_prompt (str): A value of the form ``Builtin.<name>``. ``<name>`` is
                compared case-insensitively with the file name (minus ``.jinja``) of
                each available template.

        Raises:
            ValueError: If no available template matches, including when
                ``reward_prompt`` has no ``.<name>`` part.
        """
        # Get available templates from hyperparameters specs
        judge_prompt_spec = self.hyperparameters._specs.get("judge_prompt_template", {})
        available_templates = judge_prompt_spec.get("enum", [])

        if not available_templates:
            # If no enum found, use the current value as the only available option
            current_value = getattr(self.hyperparameters, "judge_prompt_template", None)
            if not current_value:
                return
            available_templates = [current_value]

        # Extract template name after "Builtin." and convert to lowercase.
        # partition() yields an empty name when there is no dot, so a malformed value
        # ends in the descriptive ValueError below instead of an IndexError.
        _, _, requested_name = reward_prompt.partition(".")
        template_name = requested_name.lower()

        # Find matching template by extracting filename without extension
        matching_template = next(
            (
                template
                for template in available_templates
                if self._builtin_template_label(template).lower() == template_name
            ),
            None,
        )

        if matching_template:
            self.hyperparameters.judge_prompt_template = matching_template
            return

        available_options = [
            f"Builtin.{self._builtin_template_label(t)}" for t in available_templates
        ]
        raise ValueError(
            f"Selected reward function option '{reward_prompt}' is not available. "
            f"Choose one from the available options: {available_options}. "
            f"Example: reward_prompt='Builtin.summarize'"
        )