sagemaker.train.multi_turn_rl_trainer
MultiTurnRLTrainer — trainer for Agentic Reinforcement Fine-Tuning (Multi-Turn RL) jobs.
Classes
MultiTurnRLTrainer: Trainer for Agentic Reinforcement Fine-Tuning (Multi-Turn RL) jobs.
- class sagemaker.train.multi_turn_rl_trainer.MultiTurnRLTrainer(model: str | ModelPackage, agent_env: str | CustomAgentLambda, training_dataset: str | DataSet | None = None, mlflow_app_arn: str | MlflowApp | None = None, s3_output_path: str | None = None, output_model_package_group: str | ModelPackageGroup | None = None, intermediate_checkpoint_model_package_group: str | ModelPackageGroup | None = None, validation_dataset: str | DataSet | None = None, bedrock_agentcore_qualifier: str = 'DEFAULT', mlflow_experiment_name: str | None = None, mlflow_run_name: str | None = None, networking: VpcConfig | None = None, kms_key_arn: str | None = None, accept_eula: bool = False, **kwargs)
Bases: BaseTrainer
Trainer for Agentic Reinforcement Fine-Tuning (Multi-Turn RL) jobs.
Uses the CreateJob API (not CreateTrainingJob) and passes the job configuration as a JobConfigDocument JSON string.
Example:
    from sagemaker.train.multi_turn_rl_trainer import MultiTurnRLTrainer

    trainer = MultiTurnRLTrainer(
        model="huggingface-reasoning-qwen3-32b",
        agent_env="arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/AGENTID",
        training_dataset="s3://my-bucket/",
        output_model_package_group="arn:aws:sagemaker:us-west-2:123456789012:model-package-group/grp",
        mlflow_app_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-app/srv",
        s3_output_path="s3://my-bucket/output/",
        accept_eula=True,
    )
    # train() blocks until the job reaches a terminal status (wait=True by default).
    job = trainer.train()
- Parameters:
model – JumpStart model ID string, JumpStart hub content model ARN, or ModelPackage object.
agent_env – Bedrock AgentCore ARN, agent runtime ID, Lambda ARN, or CustomAgentLambda. When a bare agent runtime ID is provided (e.g. "myRuntime-aBcDeFgHiJ"), it is resolved to the full ARN via GetAgentRuntime.
training_dataset – S3 URI, DataSet object, or DataSet ARN string (optional). Must be provided either at __init__ or at train() time.
mlflow_app_arn – MLflow app ARN or MlflowApp object (optional). If not specified, the default MLflow experience is used.
s3_output_path – S3 path for output artifacts (optional). If not specified, defaults to s3://sagemaker-<region>-<account>/output.
output_model_package_group – ModelPackageGroup object or ARN string (optional).
intermediate_checkpoint_model_package_group – ModelPackageGroup object or ARN string for intermediate checkpoints (optional). If not provided, a group named {model_name}-mtrl-checkpoint-mpg is created automatically. Must differ from output_model_package_group.
validation_dataset – S3 URI, DataSet object, or DataSet ARN string (optional).
bedrock_agentcore_qualifier – Bedrock AgentCore qualifier (default: "DEFAULT").
mlflow_experiment_name – MLflow experiment name (optional).
mlflow_run_name – MLflow run name (optional).
networking – VpcConfig for the job (optional).
kms_key_arn – Identifier (ARN) of the KMS key used to encrypt job output (optional).
accept_eula – Set to True to accept the model's end-user license agreement (EULA) (optional; default: False).
**kwargs – Additional keyword arguments forwarded to BaseTrainer, such as sagemaker_session, role, base_job_name, and tags.
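The input rules documented above (how agent_env strings are interpreted, the default s3_output_path, and the auto-created checkpoint group name that must differ from output_model_package_group) can be sketched locally in plain Python. This is a hypothetical illustration, not the SDK's code: classify_agent_env, default_s3_output_path, and default_checkpoint_group are made-up names, the ARN parsing assumes the standard arn:partition:service:region:account:resource layout, model_name is assumed to be the model ID string, and group names are compared as plain strings. Resolution of a bare runtime ID actually goes through GetAgentRuntime, which the sketch does not call.

```python
from typing import Optional


def classify_agent_env(agent_env: str) -> str:
    """Hypothetical helper: report which kind of agent_env string was given.

    Assumes the standard ARN layout arn:partition:service:region:account:resource.
    """
    if agent_env.startswith("arn:"):
        parts = agent_env.split(":", 5)
        if len(parts) != 6:
            raise ValueError(f"Malformed ARN: {agent_env!r}")
        service = parts[2]
        if service == "bedrock-agentcore":
            return "agentcore_arn"
        if service == "lambda":
            return "lambda_arn"
        raise ValueError(f"Unexpected service in ARN: {service!r}")
    # Not an ARN: treat it as a bare runtime ID, which the trainer resolves
    # to a full ARN via GetAgentRuntime (that call is not made here).
    return "runtime_id"


def default_s3_output_path(region: str, account: str) -> str:
    """Documented fallback when s3_output_path is not specified."""
    return f"s3://sagemaker-{region}-{account}/output"


def default_checkpoint_group(
    model_name: str,
    output_group: Optional[str] = None,
    checkpoint_group: Optional[str] = None,
) -> str:
    """Pick the checkpoint group name and enforce the 'must differ' rule.

    Compares plain strings only; a fuller check would probably also need to
    treat a group name and its ARN as the same group.
    """
    group = checkpoint_group or f"{model_name}-mtrl-checkpoint-mpg"
    if output_group is not None and group == output_group:
        raise ValueError(
            "intermediate_checkpoint_model_package_group must differ from "
            "output_model_package_group"
        )
    return group
```

For example, default_checkpoint_group("huggingface-reasoning-qwen3-32b") returns "huggingface-reasoning-qwen3-32b-mtrl-checkpoint-mpg".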
- classmethod attach(job_name: str, session=None) → AgentRFTJob
Attach to an existing Agentic RFT job by name.
- Parameters:
job_name – The name of the job.
session – Optional boto3 session.
- Returns:
AgentRFTJob wrapping the existing job.
- static list_bedrock_agentcore_runtimes(session=None) → list[dict]
List Bedrock AgentCore runtimes.
- Parameters:
session – Optional boto3 session.
- Returns:
List of dicts, each with keys name, runtime_id, arn, and status.
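Because the return value is a plain list of dicts, choosing a runtime for agent_env takes only a few lines. The sketch below is a hypothetical helper (pick_runtime_arn is not part of the SDK) that relies only on the documented keys; the default status value "READY" is an assumption about how runtime status is reported.

```python
from typing import Iterable, Mapping


def pick_runtime_arn(
    runtimes: Iterable[Mapping[str, str]],
    name: str,
    status: str = "READY",
) -> str:
    """Hypothetical helper: return the arn of the runtime called `name`.

    `runtimes` is expected to have the documented shape of
    list_bedrock_agentcore_runtimes(): dicts with name, runtime_id, arn,
    and status keys. The default status "READY" is an assumption.
    """
    matches = [r for r in runtimes if r["name"] == name and r["status"] == status]
    if not matches:
        raise LookupError(f"No runtime named {name!r} with status {status!r}")
    if len(matches) > 1:
        # Ambiguous name: make the caller choose by runtime_id instead.
        ids = ", ".join(r["runtime_id"] for r in matches)
        raise LookupError(f"Several runtimes named {name!r} ({ids}); pick one by runtime_id")
    return matches[0]["arn"]
```

With the SDK installed and credentials configured, the result could be passed on as agent_env=pick_runtime_arn(MultiTurnRLTrainer.list_bedrock_agentcore_runtimes(), "myRuntime").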
- static list_supported_models(session=None) → list[str]
Return the list of supported model names.
Queries SageMakerPublicHub to discover all models with MTRL recipes in their RecipeCollection.
- Parameters:
session – Optional boto3 session.
- Returns:
List of hub content model names supporting MTRL.
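One use of this list is a pre-flight check before constructing the trainer. The sketch below is hypothetical (check_model_supported is not an SDK function) and assumes the model ID you pass as model uses the same naming as the hub content names returned here; it uses the standard-library difflib.get_close_matches to suggest near misses for a mistyped ID.

```python
import difflib
from typing import Sequence


def check_model_supported(model_id: str, supported: Sequence[str]) -> None:
    """Hypothetical pre-flight check against list_supported_models() output.

    Raises ValueError with up to three close matches when model_id is not in
    `supported`, which helps catch typos before a job is submitted.
    """
    if model_id in supported:
        return
    hints = difflib.get_close_matches(model_id, list(supported), n=3)
    hint_text = f" Did you mean: {', '.join(hints)}?" if hints else ""
    raise ValueError(f"{model_id!r} is not in the supported MTRL model list.{hint_text}")
```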
- property output_model_package_arn: str | None
The output model package ARN from the latest completed training job.
- train(training_dataset: str | DataSet | None = None, wait: bool = True) → AgentRFTJob
Launch an Agentic RFT job.
- Parameters:
training_dataset – Training dataset override; takes the place of any training_dataset passed to __init__.
wait – If True (default), block until the job reaches a terminal status.
- Returns:
AgentRFTJob instance for tracking the job.
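With wait=True, train() blocks until the job reaches a terminal status. A generic polling loop of that kind is sketched below. It is not the SDK's implementation: wait_for_terminal is a hypothetical name, get_status stands in for whatever status lookup you use, and the terminal status names Completed, Failed, and Stopped are assumptions.

```python
import time
from typing import Callable, Collection, Optional

# Assumed terminal status names; the real AgentRFTJob statuses may differ.
TERMINAL_STATUSES = frozenset({"Completed", "Failed", "Stopped"})


def wait_for_terminal(
    get_status: Callable[[], str],
    poll_seconds: float = 30.0,
    timeout_seconds: Optional[float] = None,
    terminal: Collection[str] = TERMINAL_STATUSES,
) -> str:
    """Poll get_status() until it returns a terminal status, then return it.

    Raises TimeoutError if timeout_seconds elapses first.
    """
    start = time.monotonic()
    while True:
        status = get_status()
        if status in terminal:
            return status
        if timeout_seconds is not None and time.monotonic() - start >= timeout_seconds:
            raise TimeoutError(f"Job still {status!r} after {timeout_seconds} s")
        time.sleep(poll_seconds)
```

With wait=False, train() returns the AgentRFTJob immediately, so any blocking of this kind becomes the caller's choice.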