Source code for sagemaker.train.multi_turn_rl_trainer

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""MultiTurnRLTrainer — trainer for Agentic Reinforcement Fine-Tuning (Multi-Turn RL) jobs."""
from __future__ import annotations

import json
import logging
import re
from typing import Optional, Union

import boto3

from sagemaker.ai_registry.dataset import DataSet
from sagemaker.core.resources import Job, ModelPackageGroup, ModelPackage, MlflowApp
from sagemaker.core.shapes import VpcConfig
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train.custom_agent_lambda import CustomAgentLambda
from sagemaker.train.agent_rft_job import AgentRFTJob
from sagemaker.train.base_trainer import BaseTrainer
from sagemaker.train.common import CustomizationTechnique
from sagemaker.train.common_utils.finetune_utils import (
    _get_default_s3_output_path,
    _get_fine_tuning_options_and_model_arn,
    _resolve_mlflow_resource_arn,
    _resolve_model_and_name,
    _resolve_model_package_arn,
    _validate_eula_for_gated_model,
    _validate_hyperparameter_values,
    _validate_s3_path_exists,
)
from sagemaker.train.common_utils.constants import MIN_MLFLOW_VERSION
from sagemaker.train.common_utils.recipe_utils import _list_hub_models_by_recipe, _is_nova_model
from sagemaker.train.constants import get_sagemaker_hub_name
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name, _get_studio_tags

logger = logging.getLogger(__name__)



# ARN patterns
BEDROCK_AGENT_CORE_ARN_PATTERN = re.compile(
    r"^arn:aws[a-z-]*:bedrock-agentcore:[a-z0-9-]+:[0-9]{12}:runtime/[a-zA-Z0-9_-]+$"
)
LAMBDA_ARN_PATTERN = re.compile(
    r"^arn:aws[a-z-]*:lambda:[a-z0-9-]+:[0-9]{12}:function:[a-zA-Z0-9-_.]+"
    r"(:\$LATEST|:[a-zA-Z0-9-_]+)?$"
)
S3_URI_PATTERN = re.compile(r"^s3://[^/]+(/.*)?$")
MLFLOW_APP_ARN_PATTERN = re.compile(
    r"^arn:[a-z0-9-.]+:sagemaker:[^:]+:[^:]+:mlflow-app/.+$"
)

# Pattern for bare Bedrock AgentCore runtime IDs (not full ARNs).
AGENT_RUNTIME_ID_PATTERN = re.compile(r"^[a-zA-Z][a-zA-Z0-9_]{0,99}-[a-zA-Z0-9]{10}$")

MAX_HYPERPARAMETERS = 50
# Intentionlly hardcode this version for each PySDK version. 
# If we need upgrade the schema version, it should upgrade PySDK version as well.
JOB_CONFIG_SCHEMA_VERSION = "1.0.0" 
JOB_CATEGORY = "AgentRFT"
MTRL_TECHNIQUE = "MTRL"


def _resolve_agent_runtime_arn(agent_runtime_id: str, session=None) -> str:
    """Resolve a bare agent runtime ID to its full ARN via GetAgentRuntime.

    Args:
        agent_runtime_id: The agent runtime ID (e.g. ``"myRuntime-aBcDeFgHiJ"``).
        session: Optional boto3 session.

    Returns:
        The full Bedrock AgentCore runtime ARN.

    Raises:
        ValueError: If the runtime cannot be found or the API call fails.
    """
    try:
        client = (session or boto3.Session()).client("bedrock-agentcore-control")
        response = client.get_agent_runtime(agentRuntimeId=agent_runtime_id)
        arn = response.get("agentRuntimeArn")
        if not arn:
            raise ValueError(
                f"GetAgentRuntime returned no ARN for runtime ID '{agent_runtime_id}'."
            )
        logger.info("Resolved agent runtime ID '%s' to ARN '%s'", agent_runtime_id, arn)
        return arn
    except Exception as e:
        if "agentRuntimeArn" not in str(e):
            raise ValueError(
                f"Failed to resolve agent runtime ID '{agent_runtime_id}': {e}"
            ) from e
        raise


def _list_all_mtrl_models(session=None) -> list[str]:
    """List all models in SageMakerPublicHub that support the MTRL technique.

    Delegates to :func:`_list_hub_models_by_recipe` with
    ``recipe_type="FineTuning"`` and ``technique="MTRL"``.

    Args:
        session: Optional boto3 session.

    Returns:
        Sorted list of hub content model names supporting MTRL.
    """
    return _list_hub_models_by_recipe(
        recipe_type="FineTuning", technique=MTRL_TECHNIQUE, session=session
    )


[docs] class MultiTurnRLTrainer(BaseTrainer): """Trainer for Agentic Reinforcement Fine-Tuning (Multi-Turn RL) jobs. Uses CreateJob API (not CreateTrainingJob) with a JobConfigDocument JSON string. Example: .. code:: python from sagemaker.train.multi_turn_rl_trainer import MultiTurnRLTrainer trainer = MultiTurnRLTrainer( model="huggingface-reasoning-qwen3-32b", agent_env="arn:aws:bedrock-agentcore::us-west-2:123456789012:runtime/AGENTID", training_dataset="s3://my-bucket/", output_model_package_group="arn:aws:sagemaker:us-west-2:123456789012:model-package-group/grp", mlflow_app_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-app/srv", s3_output_path="s3://my-bucket/output/", accept_eula=True, ) job = trainer.train() Args: model: JumpStart model ID string or JumpStart hub content Model ARN. agent_env: Bedrock AgentCore ARN, agent runtime ID, Lambda ARN, or CustomAgentLambda. When a bare agent runtime ID is provided (e.g. ``"myRuntime-aBcDeFgHiJ"``), it is resolved to the full ARN via ``GetAgentRuntime``. training_dataset: S3 URI, DataSet object, or DataSet ARN string (optional). Must be provided at ``__init__`` or ``train()`` time. mlflow_app_arn: MLflow app ARN or MlflowApp object (optional). If not specified, uses the default MLflow experience. s3_output_path: S3 path for output artifacts (optional). If not specified, defaults to ``s3://sagemaker-<region>-<account>/output``. output_model_package_group: ModelPackageGroup object or ARN string (optional). intermediate_checkpoint_model_package_group: ModelPackageGroup object or ARN string for intermediate checkpoints (optional). If not provided, auto-creates ``{model_name}-mtrl-checkpoint-mpg``. Must differ from ``output_model_package_group``. validation_dataset: S3 URI, DataSet object, or DataSet ARN string (optional). bedrock_agentcore_qualifier: Bedrock AgentCore qualifier (default: ``"DEFAULT"``). mlflow_experiment_name: MLflow experiment name (optional). mlflow_run_name: MLflow run name (optional). networking: VpcConfig for the job (optional). kms_key_arn: KMS key ID for output encryption (optional). accept_eula: Boolean for EULA acceptance (optional). **kwargs: Passed to BaseTrainer (sagemaker_session, role, base_job_name, tags). """ def __init__( self, model: Union[str, ModelPackage], agent_env: Union[str, CustomAgentLambda], training_dataset: Optional[Union[str, DataSet]] = None, mlflow_app_arn: Optional[Union[str, MlflowApp]] = None, s3_output_path: Optional[str] = None, output_model_package_group: Optional[Union[str, ModelPackageGroup]] = None, intermediate_checkpoint_model_package_group: Optional[Union[str, ModelPackageGroup]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, bedrock_agentcore_qualifier: str = "DEFAULT", mlflow_experiment_name: Optional[str] = None, mlflow_run_name: Optional[str] = None, networking: Optional[VpcConfig] = None, kms_key_arn: Optional[str] = None, accept_eula: bool = False, **kwargs, ): super().__init__(**kwargs) self._validate_agent_config(agent_env) self._validate_networking(networking) # Resolve bare agent runtime ID to full ARN if ( isinstance(agent_env, str) and not agent_env.startswith("arn:") and AGENT_RUNTIME_ID_PATTERN.match(agent_env) ): agent_env = _resolve_agent_runtime_arn(agent_env) self.model, self._model_name = _resolve_model_and_name(model, self.sagemaker_session) self.agent_env = agent_env self.bedrock_agentcore_qualifier = bedrock_agentcore_qualifier self.training_dataset = training_dataset self.validation_dataset = validation_dataset self.output_model_package_group = output_model_package_group self.mlflow_app_arn = mlflow_app_arn if isinstance(mlflow_app_arn, str) and not MLFLOW_APP_ARN_PATTERN.match(mlflow_app_arn): raise ValueError( f"Invalid mlflow_app_arn: '{mlflow_app_arn}'. " "Must match pattern: arn:<partition>:sagemaker:<region>:<account>:mlflow-app/<name>" ) self.s3_output_path = s3_output_path self.mlflow_experiment_name = mlflow_experiment_name self.mlflow_run_name = mlflow_run_name self.networking = networking self.kms_key_arn = kms_key_arn session = self.sagemaker_session or TrainDefaults.get_sagemaker_session( sagemaker_session=self.sagemaker_session ) # Resolve defaults for optional parameters if s3_output_path is None: self.s3_output_path = _get_default_s3_output_path(session) logger.info("Using default S3 output path: %s", self.s3_output_path) _validate_s3_path_exists(self.s3_output_path, session) self.output_model_package_group = self._resolve_model_package_group( model, output_model_package_group, session ) self.intermediate_checkpoint_model_package_group = ( self._resolve_intermediate_checkpoint_mpg( intermediate_checkpoint_model_package_group, session ) ) self.hyperparameters, self._model_arn, is_gated_model = ( _get_fine_tuning_options_and_model_arn( self._model_name, MTRL_TECHNIQUE, "LORA", session ) ) self.accept_eula = _validate_eula_for_gated_model(model, accept_eula, is_gated_model) self._process_hyperparameters() self._latest_job: AgentRFTJob | None = None
[docs] @_telemetry_emitter( feature=Feature.MODEL_CUSTOMIZATION, func_name="MultiTurnRLTrainer.train" ) def train( self, training_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, ) -> AgentRFTJob: """Launch an Agentic RFT job. Args: training_dataset: Training dataset override. wait: If True (default), block until job reaches terminal status. Returns: AgentRFTJob instance for tracking the job. """ sagemaker_session = TrainDefaults.get_sagemaker_session( sagemaker_session=self.sagemaker_session ) role = TrainDefaults.get_role(role=self.role, sagemaker_session=sagemaker_session) current_job_name = _get_unique_name( self.base_job_name or f"{self._model_name}-mtrl" ) logger.info(f"Job Name: {current_job_name}") self._final_hyperparameters = self.hyperparameters.to_dict() _validate_hyperparameter_values(self._final_hyperparameters) if training_dataset is not None: self.training_dataset = training_dataset job_config_doc = self._build_job_config_document() tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name()) try: job = Job.create( job_name=current_job_name, job_category=JOB_CATEGORY, role_arn=role, job_config_schema_version=JOB_CONFIG_SCHEMA_VERSION, job_config_document=job_config_doc, session=sagemaker_session.boto_session, region=sagemaker_session.boto_session.region_name, ) except Exception as e: logger.error("Error: %s", e) raise agent_rft_job = AgentRFTJob.from_job(job) logger.info(f"Created Job: {agent_rft_job.job_arn}") hp = self._final_hyperparameters agent_rft_job.description = f"Multi-turn RFT training using {self._model_name}" if wait: from sagemaker.core.utils.exceptions import TimeoutExceededError try: agent_rft_job.wait() except TimeoutExceededError as e: logger.error("Error: %s", e) self._latest_job = agent_rft_job return agent_rft_job
@property def output_model_package_arn(self) -> str | None: """The output model package ARN from the latest completed training job.""" if self._latest_job is not None: return self._latest_job.output_model_package_arn return None
[docs] @classmethod @_telemetry_emitter( feature=Feature.MODEL_CUSTOMIZATION, func_name="MultiTurnRLTrainer.attach" ) def attach(cls, job_name: str, session=None) -> AgentRFTJob: """Attach to an existing Agentic RFT job by name. Args: job_name: The name of the job. session: Optional boto3 session. Returns: AgentRFTJob wrapping the existing job. """ return AgentRFTJob.get(job_name=job_name, session=session)
# ---- Private: JobConfigDocument construction ---- def _build_job_config_document(self) -> str: """Build the JobConfigDocument JSON string conforming to v1_0_0 schema.""" config = { "AgentConfig": self._build_agent_config(), "InputDataConfig": self._build_input_data_config(), "OutputDataConfig": self._build_output_data_config(), "ModelPackageConfig": self._build_model_package_config(), "TrainingConfig": self._build_training_config(), } if self.networking: config["VpcConfig"] = { "SecurityGroupIds": self.networking.security_group_ids, "Subnets": self.networking.subnets, } doc = json.dumps(config, indent=2) logger.info(f"JobConfigDocument:\n{doc}") return doc def _build_agent_config(self) -> dict: agent_env = self.agent_env if isinstance(agent_env, CustomAgentLambda): return { "CustomAgentLambdaConfig": {"LambdaArn": agent_env.lambda_arn}, } if BEDROCK_AGENT_CORE_ARN_PATTERN.match(agent_env): config = {"AgentRuntimeArn": agent_env} if self.bedrock_agentcore_qualifier: config["Qualifier"] = self.bedrock_agentcore_qualifier return {"BedrockAgentCoreConfig": config} if LAMBDA_ARN_PATTERN.match(agent_env): return { "CustomAgentLambdaConfig": {"LambdaArn": agent_env}, } raise ValueError(f"Unrecognized agent config: {agent_env}") def _build_input_data_config(self) -> list: channels = [self._resolve_channel("train", self.training_dataset)] if self.validation_dataset is not None: channels.append(self._resolve_channel("validation", self.validation_dataset)) return channels @staticmethod def _resolve_channel(channel_name: str, data) -> dict: if isinstance(data, DataSet): return { "ChannelName": channel_name, "DataSource": {"DatasetSource": {"DatasetArn": data.arn}}, } if isinstance(data, str) and S3_URI_PATTERN.match(data): return { "ChannelName": channel_name, "DataSource": { "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": data} }, } # Assume DataSet ARN string return { "ChannelName": channel_name, "DataSource": {"DatasetSource": {"DatasetArn": data}}, } def _build_output_data_config(self) -> dict: config = {"S3OutputPath": self.s3_output_path} if self.kms_key_arn: config["KmsKeyArn"] = self.kms_key_arn return config def _build_model_package_config(self) -> dict: arn = ( self.output_model_package_group.model_package_group_arn if isinstance(self.output_model_package_group, ModelPackageGroup) else self.output_model_package_group ) config = {"OutputModelPackageGroupArn": arn} if isinstance(self.model, ModelPackage): source_arn = _resolve_model_package_arn(self.model) if source_arn: config["InputModelPackageArn"] = source_arn config["IntermediateCheckpointModelPackageGroupArn"] = ( self.intermediate_checkpoint_model_package_group ) return config def _build_training_config(self) -> dict: hyperparameters = getattr(self, "_final_hyperparameters", {}) config = { "BaseModelArn": self._model_arn, } mlflow_config = self._build_mlflow_config() if mlflow_config: config["MlflowConfig"] = mlflow_config if self.accept_eula is not None: config["AcceptEula"] = self.accept_eula if hyperparameters: # Only send hyperparameters the user explicitly changed defaults = getattr(self, "_hp_defaults", {}) user_set = {k: v for k, v in hyperparameters.items() if v != defaults.get(k)} if user_set: config["HyperParameters"] = user_set return config def _build_mlflow_config(self) -> Optional[dict]: arn = ( self.mlflow_app_arn.arn if isinstance(self.mlflow_app_arn, MlflowApp) else self.mlflow_app_arn ) if not arn: session = self.sagemaker_session or TrainDefaults.get_sagemaker_session( sagemaker_session=self.sagemaker_session ) arn = _resolve_mlflow_resource_arn(session, None, min_mlflow_version=MIN_MLFLOW_VERSION) if not arn: return None logger.info("MLflow resource ARN: %s", arn) config = {"MlflowResourceArn": arn} if self.mlflow_experiment_name: config["MlflowExperimentName"] = self.mlflow_experiment_name if self.mlflow_run_name: config["MlflowRunName"] = self.mlflow_run_name return config def _process_hyperparameters(self): """Snapshot defaults for MTRL so we only send user-changed values.""" if not self.hyperparameters or not hasattr(self.hyperparameters, "_specs"): return self._hp_defaults = self.hyperparameters.to_dict().copy() # ---- Validation ----
[docs] @staticmethod @_telemetry_emitter( feature=Feature.MODEL_CUSTOMIZATION, func_name="MultiTurnRLTrainer.list_supported_models", ) def list_supported_models(session=None) -> list[str]: """Return the list of supported model names. Queries SageMakerPublicHub to discover all models with MTRL recipes in their ``RecipeCollection``. Args: session: Optional boto3 session. Returns: List of hub content model names supporting MTRL. """ return _list_all_mtrl_models(session=session)
[docs] @staticmethod @_telemetry_emitter( feature=Feature.MODEL_CUSTOMIZATION, func_name="MultiTurnRLTrainer.list_bedrock_agentcore_runtimes", ) def list_bedrock_agentcore_runtimes(session=None) -> list[dict]: """List Bedrock AgentCore runtimes. Args: session: Optional boto3 session. Returns: List of dicts, each with keys ``name``, ``runtime_id``, ``arn``, and ``status``. """ client = (session or boto3.Session()).client("bedrock-agentcore-control") runtimes: list[dict] = [] next_token = None while True: kwargs: dict = {} if next_token: kwargs["nextToken"] = next_token response = client.list_agent_runtimes(**kwargs) for rt in response.get("agentRuntimes", []): entry = { "name": rt.get("agentRuntimeName", ""), "runtime_id": rt.get("agentRuntimeId", ""), "arn": rt["agentRuntimeArn"], "status": rt.get("status", ""), } runtimes.append(entry) next_token = response.get("nextToken") if not next_token: break return runtimes
@staticmethod def _validate_agent_config(agent_env): if isinstance(agent_env, CustomAgentLambda): return if not isinstance(agent_env, str): raise ValueError( f"agent_env must be a string ARN, agent runtime ID, or CustomAgentLambda, " f"got {type(agent_env).__name__}." ) if not ( BEDROCK_AGENT_CORE_ARN_PATTERN.match(agent_env) or LAMBDA_ARN_PATTERN.match(agent_env) or AGENT_RUNTIME_ID_PATTERN.match(agent_env) ): raise ValueError( f"Invalid agent_env: '{agent_env}'. " "Must be a Bedrock AgentCore ARN, Lambda ARN, agent runtime ID, " "or CustomAgentLambda." ) @staticmethod def _validate_networking(vpc): if vpc is None: return sg = getattr(vpc, "security_group_ids", None) subnets = getattr(vpc, "subnets", None) if not sg or not subnets: raise ValueError( "VPC config requires both non-empty 'security_group_ids' and 'subnets'." ) def _get_or_create_mpg(self, value, default_name: str, session, managed_configuration=None) -> str: """Resolve an existing ModelPackageGroup or auto-create one. If ``value`` is provided (object or string), validates it exists and returns its ARN. If ``value`` is None, creates a ModelPackageGroup with ``default_name`` (get-or-create). Returns: The ModelPackageGroup ARN. """ if value: if isinstance(value, ModelPackageGroup): return value.model_package_group_arn mpg = ModelPackageGroup.get( model_package_group_name=value, session=session.boto_session, region=session.boto_session.region_name, ) return mpg.model_package_group_arn # Auto-create (get-or-create with deterministic name) logger.info("Auto-resolving ModelPackageGroup: %s", default_name) try: mpg = ModelPackageGroup.get( model_package_group_name=default_name, session=session.boto_session, region=session.boto_session.region_name, ) except Exception: try: create_kwargs = { "model_package_group_name": default_name, "session": session.boto_session, "region": session.boto_session.region_name, } if managed_configuration: create_kwargs["managed_configuration"] = managed_configuration mpg = ModelPackageGroup.create(**create_kwargs) logger.info("Created ModelPackageGroup: %s", mpg.model_package_group_arn) except Exception as e: raise ValueError( f"Failed to create ModelPackageGroup '{default_name}': {e}" ) from e return mpg.model_package_group_arn def _resolve_model_package_group(self, model, output_model_package_group, session): """Resolve, validate, or auto-create the output ModelPackageGroup. Resolution order: 1. If ``output_model_package_group`` is provided, validates it exists. 2. If ``model`` is a ModelPackage, derives the group from it. 3. Otherwise, auto-creates ``{model_name}-mtrl-mpg`` (get-or-create). Returns: The ModelPackageGroup ARN. """ if output_model_package_group: return self._get_or_create_mpg(output_model_package_group, None, session) # Derive from ModelPackage if isinstance(model, ModelPackage): group_name = model.model_package_group_name if group_name: return self._get_or_create_mpg(group_name, None, session) managed_config = None if _is_nova_model(self._model_name): from sagemaker.core.shapes import ManagedConfiguration managed_config = ManagedConfiguration(managed_storage_type="Restricted") return self._get_or_create_mpg( None, f"{self._model_name}-mtrl-mpg", session, managed_configuration=managed_config ) def _resolve_intermediate_checkpoint_mpg(self, intermediate_checkpoint_mpg, session) -> str: """Resolve or auto-create the intermediate checkpoint ModelPackageGroup. If provided, validates it exists. Otherwise auto-creates ``{model_name}-mtrl-checkpoint-mpg`` (get-or-create). Raises ValueError if the resolved ARN is the same as ``output_model_package_group``. Returns: The ModelPackageGroup ARN. """ managed_config = None if not intermediate_checkpoint_mpg and _is_nova_model(self._model_name): from sagemaker.core.shapes import ManagedConfiguration managed_config = ManagedConfiguration(managed_storage_type="Restricted") arn = self._get_or_create_mpg( intermediate_checkpoint_mpg, f"{self._model_name}-mtrl-checkpoint-mpg", session, managed_configuration=managed_config, ) if arn == self.output_model_package_group: raise ValueError( "intermediate_checkpoint_model_package_group must differ from " "output_model_package_group." ) return arn