import json
import logging
from typing import Any, Dict, List, Optional, Union
from sagemaker.train.base_trainer import BaseTrainer
from sagemaker.train.common import TrainingType, CustomizationTechnique, JOB_TYPE
from sagemaker.core.resources import TrainingJob, ModelPackageGroup, MlflowTrackingServer, ModelPackage
from sagemaker.core.shapes import VpcConfig
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name, _get_jumpstart_tags
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.evaluator import Evaluator
from sagemaker.core.training.configs import TrainingJobCompute, HyperPodCompute
from sagemaker.train.configs import StoppingCondition
from sagemaker.core.training.configs import TrainingJobCompute, HyperPodCompute
from sagemaker.train.common_utils.finetune_utils import (
_get_fine_tuning_options_and_model_arn,
_validate_and_resolve_model_package_group,
_extract_evaluator_arn,
_is_lambda_arn,
_is_nova_model,
_resolve_model_and_name,
_resolve_model_with_checkpoint,
_create_input_data_config,
_convert_input_data_to_channels,
_create_output_config,
_create_serverless_config,
_create_mlflow_config,
_create_model_package_config,
_validate_eula_for_gated_model,
_validate_hyperparameter_values
)
from sagemaker.train.common_utils.data_utils import is_multimodal_data, load_file_content
from sagemaker.train.common_utils.rlvr_reward_verifier import verify_reward_function
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train.constants import get_sagemaker_hub_name
logger = logging.getLogger(__name__)
# [docs] -- documentation-link marker left over from HTML extraction; not Python code
class RLVRTrainer(BaseTrainer):
    """Class that performs Reinforcement Learning from Verifiable Rewards (RLVR) fine-tuning on foundation models using AWS SageMaker.

    Example:

    .. code:: python

        from sagemaker.train import RLVRTrainer
        from sagemaker.train.common import TrainingType

        trainer = RLVRTrainer(
            model="meta-llama/Llama-2-7b-hf",
            training_type=TrainingType.LORA,
            model_package_group="my-model-group",
            custom_reward_function="arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/my-evaluator/1.0",
            training_dataset="s3://bucket/rlvr_data.jsonl"
        )
        trainer.train()

        # Using a Lambda ARN directly (Evaluator is auto-created):
        trainer = RLVRTrainer(
            model="meta-llama/Llama-2-7b-hf",
            training_type=TrainingType.LORA,
            model_package_group="my-model-group",
            custom_reward_function="arn:aws:lambda:us-east-1:123456789012:function:my-reward-fn",
            training_dataset="s3://bucket/rlvr_data.jsonl"
        )
        trainer.train()

        # Complete workflow: create -> wait -> get model package ARN
        trainer = RLVRTrainer(
            model="meta-llama/Llama-2-7b-hf",
            model_package_group="my-rlvr-models",
            custom_reward_function="arn:aws:sagemaker:us-east-1:123456789012:hub-content/SageMakerPublicHub/JsonDoc/my-evaluator/1.0"
        )

        # Create training job (non-blocking)
        training_job = trainer.train(
            training_dataset="s3://bucket/rlvr_data.jsonl",
            wait=False
        )

        # Wait for completion
        training_job.wait()

        # Refresh job status
        training_job.refresh()

        # Get the fine-tuned model package ARN
        model_package_arn = training_job.output_model_package_arn

    Parameters:
        model (Union[str, ModelPackage]):
            The foundation model to fine-tune. Can be a model name string, model package ARN,
            or ModelPackage object.
        training_type (Union[TrainingType, str]):
            The fine-tuning approach. Valid values are TrainingType.LORA (default),
            TrainingType.FULL.
        model_package_group (Optional[Union[str, ModelPackageGroup]]):
            The model package group for storing the fine-tuned model. Can be a group name,
            ARN, or ModelPackageGroup object. Required when model is not a ModelPackage.
            It is validated and resolved only when ``compute`` is None; otherwise it is
            stored as given.
        custom_reward_function (Optional[Union[str, Evaluator]]):
            The custom reward function evaluator. Can be an evaluator ARN string, a Lambda
            function ARN string, or an Evaluator object. If a Lambda ARN is provided
            (e.g., "arn:aws:lambda:us-east-1:123456789012:function:my-reward"), an Evaluator
            will be automatically created in the AI Registry and used for training.
            Required for RLVR training to provide reward signals.
        compute (Optional[Union[TrainingJobCompute, HyperPodCompute]]):
            The compute configuration. Must be a TrainingJobCompute or HyperPodCompute
            instance; any other non-None value raises TypeError. If None (default),
            ``train()`` submits the job on serverless compute.
        mlflow_resource_arn (Optional[Union[str, MlflowTrackingServer]]):
            The MLflow tracking server ARN for experiment tracking.
            If not specified, uses default MLflow experience.
        mlflow_experiment_name (Optional[str]):
            The MLflow experiment name for organizing runs.
        mlflow_run_name (Optional[str]):
            The MLflow run name for this training job.
        training_dataset (Optional[Union[str, DataSet]]):
            The training dataset. Can be an S3 URI, dataset ARN, or DataSet object.
        validation_dataset (Optional[Union[str, DataSet]]):
            The validation dataset. Can be an S3 URI, dataset ARN, or DataSet object.
        s3_output_path (Optional[str]):
            The S3 path for training job outputs.
            If not specified, defaults to s3://sagemaker-<region>-<account>/output.
        kms_key_id (Optional[str]):
            The KMS key ID for encrypting training job outputs.
        networking (Optional[VpcConfig]):
            The VPC configuration for the training job.
        accept_eula (bool):
            EULA acceptance flag. It is validated against whether the resolved model is
            gated. Defaults to False.
        stopping_condition (Optional[StoppingCondition]):
            The stopping condition to override training runtime limit.
            If not specified, uses SageMaker service default (24 hours for serverless training).
        recipe (Optional[str]):
            Recipe applied on top of the Hub default hyperparameters at train time.
            Presumably a path to a recipe file (it is stored as ``_recipe_path``) -- confirm
            the accepted formats against BaseTrainer.
        overrides (Optional[dict]):
            Hyperparameter overrides applied at train time. Precedence is
            overrides > recipe > Hub defaults.
        is_multimodal (Optional[bool]):
            Whether the training dataset contains multimodal data. If None (default),
            auto-detected from the training dataset at train time.
        base_model_name (Optional[str]):
            Forwarded to BaseTrainer; the resulting ``self.base_model_name`` is used when
            resolving ``model``.
        disable_output_compression (Optional[bool]):
            Forwarded to BaseTrainer and passed to the output configuration at train time.
            Defaults to False.
        **kwargs:
            Additional keyword arguments forwarded unchanged to BaseTrainer.
    """

    # Customization technique identifier for this trainer (the RLVR enum value).
    # NOTE(review): presumably read by BaseTrainer helpers -- confirm against the base class.
    _customization_technique = CustomizationTechnique.RLVR.value
def __init__(
    self,
    model: Union[str, ModelPackage],
    training_type: Union[TrainingType, str] = TrainingType.LORA,
    model_package_group: Optional[Union[str, ModelPackageGroup]] = None,
    custom_reward_function: Optional[Union[str, Evaluator]] = None,
    compute: Optional[Union[TrainingJobCompute, HyperPodCompute]] = None,
    mlflow_resource_arn: Optional[Union[str, MlflowTrackingServer]] = None,
    mlflow_experiment_name: Optional[str] = None,
    mlflow_run_name: Optional[str] = None,
    training_dataset: Optional[Union[str, DataSet]] = None,
    validation_dataset: Optional[Union[str, DataSet]] = None,
    s3_output_path: Optional[str] = None,
    kms_key_id: Optional[str] = None,
    networking: Optional[VpcConfig] = None,
    accept_eula: bool = False,
    stopping_condition: Optional[StoppingCondition] = None,
    recipe: Optional[str] = None,
    overrides: Optional[dict] = None,
    is_multimodal: Optional[bool] = None,
    base_model_name: Optional[str] = None,
    disable_output_compression: Optional[bool] = False,
    **kwargs,
):
    """Validate inputs, resolve the model, and load the default RLVR hyperparameters.

    Args:
        model: Foundation model to fine-tune (model name, ARN, or ModelPackage).
        training_type: Fine-tuning approach; defaults to TrainingType.LORA.
        model_package_group: Destination model package group. Validated and resolved
            only when ``compute`` is None; otherwise stored as given.
        custom_reward_function: Reward function (ARN string or Evaluator); stored as given.
        compute: TrainingJobCompute, HyperPodCompute, or None.
        mlflow_resource_arn: MLflow tracking server; stored as given.
        mlflow_experiment_name: MLflow experiment name; stored as given.
        mlflow_run_name: MLflow run name; stored as given.
        training_dataset: Default training dataset; stored as given.
        validation_dataset: Default validation dataset; stored as given.
        s3_output_path: S3 output location; stored as given.
        kms_key_id: KMS key for output encryption; stored as given.
        networking: VPC configuration; stored as given.
        accept_eula: EULA acceptance flag, validated against the model's gated status.
        stopping_condition: Optional runtime limit override; stored as given.
        recipe: Recipe reference, stored as ``_recipe_path``.
        overrides: Hyperparameter overrides, stored as ``_overrides``.
        is_multimodal: Explicit multimodal flag, or None to leave it unset.
        base_model_name: Forwarded to BaseTrainer.
        disable_output_compression: Forwarded to BaseTrainer.
        **kwargs: Forwarded to BaseTrainer.

    Raises:
        TypeError: If ``compute`` is not None and is neither a TrainingJobCompute nor a
            HyperPodCompute instance.
    """
    super().__init__(
        base_model_name=base_model_name,
        disable_output_compression=disable_output_compression,
        **kwargs,
    )

    # Reject an unsupported compute object up front, before it is handed to the
    # model-resolution helper below.
    if compute is not None and not isinstance(compute, (TrainingJobCompute, HyperPodCompute)):
        raise TypeError(
            f"compute must be a TrainingJobCompute or HyperPodCompute instance, got {type(compute).__name__}"
        )

    self.model, self._model_name, self.model_source = _resolve_model_with_checkpoint(
        model, self.base_model_name, compute, self.sagemaker_session,
        resolve_fn=_resolve_model_and_name,
    )

    self.training_type = training_type
    self.custom_reward_function = custom_reward_function
    self.compute = compute

    # Only the serverless path (compute is None) validates/resolves the group here;
    # with an explicit compute the caller's value is kept untouched.
    if compute is None:
        self.model_package_group = _validate_and_resolve_model_package_group(
            model, model_package_group
        )
    else:
        self.model_package_group = model_package_group

    self.mlflow_resource_arn = mlflow_resource_arn
    self.mlflow_experiment_name = mlflow_experiment_name
    self.mlflow_run_name = mlflow_run_name
    self.training_dataset = training_dataset
    self.validation_dataset = validation_dataset
    self.s3_output_path = s3_output_path
    self.kms_key_id = kms_key_id
    self.networking = networking
    self.stopping_condition = stopping_condition
    self._recipe_path = recipe
    self._overrides = overrides
    self._recipe_resolver = None
    self._resolved_recipe_cache = None
    self.is_multimodal = is_multimodal

    # Fetch the default hyperparameters, model ARN and gated-model flag for this
    # model/technique, using the trainer's session when set and the default otherwise.
    session = self.sagemaker_session or TrainDefaults.get_sagemaker_session(
        sagemaker_session=self.sagemaker_session
    )
    self.hyperparameters, self._model_arn, is_gated_model = _get_fine_tuning_options_and_model_arn(
        self._model_name,
        CustomizationTechnique.RLVR.value,
        self.training_type,
        session,
        compute=self.compute,
    )

    # Remove constructor-handled hyperparameters
    self._process_hyperparameters()

    # Validate and set EULA acceptance
    self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model)
def _process_hyperparameters(self):
    """Drop hyperparameters whose values come from constructor inputs instead.

    For each reserved key that exists as an attribute on ``self.hyperparameters``,
    the attribute is deleted and the matching entry is removed from the object's
    ``_specs`` mapping. Nothing happens when ``self.hyperparameters`` is falsy.
    """
    hyperparams = self.hyperparameters
    if not hyperparams:
        return

    constructor_managed_keys = (
        'data_s3_path',
        'reward_lambda_arn',
        'preset_reward_function',
        'data_path',
        'validation_data_path',
        'output_path',
    )
    for key in constructor_managed_keys:
        if not hasattr(hyperparams, key):
            continue
        delattr(hyperparams, key)
        # Keep the spec table in sync with the attributes that remain.
        hyperparams._specs.pop(key, None)
def _verify_reward_function(
    self,
    sample_count: int = 3,
    training_dataset: Optional[Union[str, DataSet]] = None,
) -> Dict[str, Any]:
    """Verifies the reward function by invoking it with sample data from the training dataset.

    Reads up to ``sample_count`` non-blank JSONL records from the training dataset and
    passes them, together with the resolved reward function, to ``verify_reward_function``
    before a full training job is submitted.

    Args:
        sample_count: Number of samples to read from the training dataset for verification.
            Must be at least 1. Defaults to 3.
        training_dataset: Training dataset to read samples from. Can be an S3 URI, dataset
            ARN, or DataSet object. If not provided, uses the dataset configured on the
            trainer instance.

    Returns:
        The verification result dict returned by ``verify_reward_function``. It is also
        logged.

    Raises:
        ValueError: If the reward function is not configured, ``sample_count`` is below 1,
            no training dataset is available, the evaluator cannot be resolved to a
            Lambda ARN, samples cannot be read or parsed, or verification reports failure.
    """
    # Resolve the reward function
    reward_function = self.custom_reward_function
    if reward_function is None:
        raise ValueError(
            "Cannot verify reward function: 'custom_reward_function' is not set. "
            "Please provide custom_reward_function when initializing RLVRTrainer."
        )

    # Cheap argument checks first, so bad input fails before any remote lookup.
    if sample_count < 1:
        raise ValueError(f"sample_count must be at least 1, got {sample_count}.")

    # Fall back to the dataset configured on the trainer, as the contract promises.
    if training_dataset is None:
        training_dataset = self.training_dataset
    if training_dataset is None:
        raise ValueError(
            "Cannot verify reward function: no training dataset is available. "
            "Please provide training_dataset or set it when initializing RLVRTrainer."
        )

    is_nova = _is_nova_model(self._model_name)

    # If it's an Evaluator object, extract the Lambda ARN (reference)
    if isinstance(reward_function, Evaluator):
        reward_function = reward_function.reference
        if not reward_function:
            raise ValueError(
                "Cannot verify reward function: Evaluator object does not have a "
                "Lambda ARN reference. Verification requires a Lambda ARN or local file path."
            )
    elif isinstance(reward_function, str) and not _is_lambda_arn(reward_function):
        # It's a string but not a Lambda ARN: treat it as an evaluator ARN, fetch the
        # Evaluator object and take the Lambda ARN from its reference.
        try:
            # ARN format: arn:aws:sagemaker:region:account:hub-content/hub/type/name/version
            # so the evaluator name is the second-to-last path segment.
            evaluator_name = reward_function.split("/")[-2]
            evaluator = Evaluator.get(
                evaluator_name,
                sagemaker_session=TrainDefaults.get_sagemaker_session(
                    sagemaker_session=self.sagemaker_session
                ),
            )
            reward_function = evaluator.reference
            if not reward_function:
                raise ValueError(
                    f"Evaluator '{evaluator_name}' does not have a Lambda ARN reference. "
                    "Verification requires a Lambda ARN."
                )
            logger.info(f"Resolved evaluator ARN to Lambda ARN: {reward_function}")
        except ValueError:
            # Already carries a specific message; do not re-wrap.
            raise
        except Exception as e:
            raise ValueError(
                f"Failed to resolve evaluator ARN '{self.custom_reward_function}' "
                f"to a Lambda ARN for verification: {str(e)}"
            ) from e

    # Resolve DataSet object to S3 URI
    if isinstance(training_dataset, DataSet):
        data_s3_path = training_dataset.source
    else:
        data_s3_path = training_dataset

    # Read sample data from the training dataset
    samples: List[Dict[str, Any]] = []
    try:
        lines = load_file_content(data_s3_path, extension=".jsonl", encoding="utf-8-sig")
        # line_number is the 1-based position in the stream yielded by load_file_content,
        # blank lines included, so the error points at the offending record.
        for line_number, line in enumerate(lines, start=1):
            if len(samples) >= sample_count:
                break
            line = line.strip()
            if not line:
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Failed to parse JSON from line {line_number} in "
                    f"{data_s3_path}: {str(e)}"
                ) from e
    except ValueError:
        raise
    except Exception as e:
        raise ValueError(
            f"Failed to read samples from {data_s3_path}: {str(e)}\n"
            "Please verify the S3 path is correct and you have read permissions."
        ) from e

    if not samples:
        raise ValueError(
            f"No samples found in {data_s3_path}. "
            "Please ensure the data file contains valid JSONL data."
        )

    logger.info(f"Verifying reward function with {len(samples)} sample(s)...")
    result = verify_reward_function(
        reward_function=reward_function,
        sample_data=samples,
        validate_format=True,
        compute=self.compute,
        is_nova=is_nova,
    )
    logger.info(
        f"Reward function verification result: {json.dumps(result, indent=2, default=str)}"
    )

    if not result.get("success"):
        raise ValueError(
            f"Reward function verification failed: "
            f"Details: {json.dumps(result, default=str)}"
        )

    # .get() keeps a missing counter from turning a successful verification into a KeyError.
    logger.info(
        f"Reward function verification successful: "
        f"{result.get('successful_samples')}/{result.get('total_samples')} sample(s) passed"
    )
    return result
# [docs] -- documentation-link marker left over from HTML extraction; not Python code
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train")
def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
          validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True,
          wait_timeout: Optional[int] = None, poll: int = 5):
    """Execute the RLVR training job.

    When ``self.compute`` is a HyperPodCompute or TrainingJobCompute, the call is
    delegated to ``_train_hyperpod`` / ``_train_serverful_smtj`` and their return value
    is returned unchanged. Otherwise a serverless training job is created.

    On the serverless path, if a custom reward function and a training dataset are both
    available, the reward function is verified against the dataset before the job is
    created.

    Parameters:
        training_dataset (Optional[Union[str, DataSet]]):
            The training dataset for this job. Overrides the dataset specified in __init__.
            Can be an S3 URI, dataset ARN, or DataSet object.
        validation_dataset (Optional[Union[str, DataSet]]):
            The validation dataset for this job. Overrides the dataset specified in __init__.
            Can be an S3 URI, dataset ARN, or DataSet object.
        wait (bool):
            Whether to wait for the training job to complete. Defaults to True.
        wait_timeout (Optional[int]):
            Maximum time in seconds to wait for the training job to complete. Only used when wait=True.
            If None, uses the default timeout from the wait utility.
        poll (int):
            Polling interval in seconds for checking training job status. Defaults to 5.

    Returns:
        TrainingJob: The SageMaker training job object.

    Raises:
        ValueError: If reward function verification fails.
        Exception: Any error raised by ``TrainingJob.create`` is logged and re-raised.
    """
    # Dispatch based on compute type
    if isinstance(self.compute, HyperPodCompute):
        return self._train_hyperpod(
            training_dataset=training_dataset,
            validation_dataset=validation_dataset,
            wait=wait,
            wait_timeout=wait_timeout,
            poll=poll,
        )
    if isinstance(self.compute, TrainingJobCompute):
        return self._train_serverful_smtj(
            training_dataset=training_dataset,
            validation_dataset=validation_dataset,
            wait=wait,
            wait_timeout=wait_timeout,
            poll=poll,
        )

    # Default: serverless compute (None)
    sagemaker_session = TrainDefaults.get_sagemaker_session(
        sagemaker_session=self.sagemaker_session
    )
    role = TrainDefaults.get_role(role=self.role, sagemaker_session=sagemaker_session)
    current_training_job_name = _get_unique_name(
        self.base_job_name or f"{self._model_name}-rlvr"
    )
    logger.info(f"Training Job Name: {current_training_job_name}")

    # A per-call dataset takes precedence over the one given to the constructor.
    effective_training_dataset = training_dataset or self.training_dataset

    input_data_config = _create_input_data_config(
        effective_training_dataset,
        validation_dataset or self.validation_dataset,
    )
    channels = _convert_input_data_to_channels(input_data_config)

    output_config = _create_output_config(
        s3_output_path=self.s3_output_path,
        sagemaker_session=sagemaker_session,
        kms_key_id=self.kms_key_id,
        disable_output_compression=getattr(self, 'disable_output_compression', False),
    )

    # Extract and validate evaluator ARN
    evaluator_arn = (
        _extract_evaluator_arn(self.custom_reward_function)
        if self.custom_reward_function
        else None
    )

    serverless_config = _create_serverless_config(
        model_arn=self._model_arn,
        customization_technique=CustomizationTechnique.RLVR.value,
        training_type=self.training_type,
        accept_eula=self.accept_eula,
        evaluator_arn=evaluator_arn,
        job_type=JOB_TYPE,
    )

    mlflow_config = _create_mlflow_config(
        sagemaker_session,
        mlflow_resource_arn=self.mlflow_resource_arn,
        mlflow_experiment_name=self.mlflow_experiment_name,
        mlflow_run_name=self.mlflow_run_name,
    )

    final_hyperparameters = self.hyperparameters.to_dict()

    # Apply recipe/overrides if provided (overrides > recipe > Hub defaults)
    final_hyperparameters = self._apply_recipe_to_hyperparameters(final_hyperparameters)

    # Resolve is_multimodal: auto-detect from training dataset if not explicitly set
    if self.is_multimodal is None and effective_training_dataset is not None:
        self.is_multimodal = is_multimodal_data(effective_training_dataset)

    # Validate hyperparameter values
    _validate_hyperparameter_values(final_hyperparameters)

    model_package_config = _create_model_package_config(
        model_package_group_name=self.model_package_group,
        model=self.model,
        sagemaker_session=sagemaker_session
    )

    # Verify reward function before submitting training job; skipped when there is
    # no dataset to draw samples from.
    if self.custom_reward_function and effective_training_dataset is not None:
        logger.info("Verifying reward function before submitting training job...")
        self._verify_reward_function(training_dataset=effective_training_dataset)

    vpc_config = self.networking if self.networking else None

    tags = _get_jumpstart_tags(self._model_name, get_sagemaker_hub_name())

    # Build TrainingJob.create() arguments
    create_args = {
        "training_job_name": current_training_job_name,
        "role_arn": role,
        "input_data_config": channels,
        "output_data_config": output_config,
        "serverless_job_config": serverless_config,
        "mlflow_config": mlflow_config,
        "hyper_parameters": final_hyperparameters,
        "model_package_config": model_package_config,
        "vpc_config": vpc_config,
        "session": sagemaker_session.boto_session,
        "region": sagemaker_session.boto_session.region_name,
        "tags": tags,
    }

    # Only pass stopping_condition if explicitly provided by user
    if self.stopping_condition is not None:
        create_args["stopping_condition"] = self.stopping_condition

    try:
        training_job = TrainingJob.create(**create_args)
    except Exception as e:
        logger.error("Error: %s", e)
        raise

    # Record the job as soon as it exists, so it stays reachable from the trainer
    # even if waiting below raises something other than a timeout.
    self._latest_training_job = training_job

    if wait:
        from sagemaker.train.common_utils.trainer_wait import wait as _wait
        from sagemaker.core.utils.exceptions import TimeoutExceededError

        # poll is always forwarded; timeout only when the caller set one, so the
        # wait utility's own default applies otherwise.
        wait_kwargs = {'poll': poll}
        if wait_timeout is not None:
            wait_kwargs['timeout'] = wait_timeout
        try:
            _wait(training_job, **wait_kwargs)
        except TimeoutExceededError as e:
            # A wait timeout is logged, not raised; the job object is still returned.
            logger.error("Error: %s", e)

    return training_job
def _get_extra_smtj_hyperparameters(self) -> Dict[str, Any]:
    """Build the RLVR-specific hyperparameters for SMTJ training.

    Returns a dict holding ``reward_lambda_arn`` when a custom reward function is set
    and an ARN can be derived from it, otherwise an empty dict. A string that starts
    with ``arn:aws:lambda:`` or contains ``hub-content`` is used as-is; anything else
    goes through ``_extract_evaluator_arn``.
    """
    reward_source = self.custom_reward_function
    if not reward_source:
        return {}

    used_verbatim = isinstance(reward_source, str) and (
        reward_source.startswith("arn:aws:lambda:") or "hub-content" in reward_source
    )
    resolved_arn = reward_source if used_verbatim else _extract_evaluator_arn(reward_source)

    return {"reward_lambda_arn": resolved_arn} if resolved_arn else {}