sagemaker.train.multi_turn_rl_trainer

MultiTurnRLTrainer — trainer for Agentic Reinforcement Fine-Tuning (Multi-Turn RL) jobs.

Classes

MultiTurnRLTrainer(model, agent_env[, ...])

Trainer for Agentic Reinforcement Fine-Tuning (Multi-Turn RL) jobs.

class sagemaker.train.multi_turn_rl_trainer.MultiTurnRLTrainer(model: str | ModelPackage, agent_env: str | CustomAgentLambda, training_dataset: str | DataSet | None = None, mlflow_app_arn: str | MlflowApp | None = None, s3_output_path: str | None = None, output_model_package_group: str | ModelPackageGroup | None = None, intermediate_checkpoint_model_package_group: str | ModelPackageGroup | None = None, validation_dataset: str | DataSet | None = None, bedrock_agentcore_qualifier: str = 'DEFAULT', mlflow_experiment_name: str | None = None, mlflow_run_name: str | None = None, networking: VpcConfig | None = None, kms_key_arn: str | None = None, accept_eula: bool = False, **kwargs)

Bases: BaseTrainer

Trainer for Agentic Reinforcement Fine-Tuning (Multi-Turn RL) jobs.

Launches jobs through the CreateJob API (not CreateTrainingJob), passing the job configuration as a JobConfigDocument JSON string.

Example:

from sagemaker.train.multi_turn_rl_trainer import MultiTurnRLTrainer

# Configure a Multi-Turn RL job against a Bedrock AgentCore runtime.
trainer = MultiTurnRLTrainer(
    model="huggingface-reasoning-qwen3-32b",
    agent_env="arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/AGENTID",
    training_dataset="s3://my-bucket/",
    output_model_package_group="arn:aws:sagemaker:us-west-2:123456789012:model-package-group/grp",
    mlflow_app_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-app/srv",
    s3_output_path="s3://my-bucket/output/",
    accept_eula=True,
)
# Blocks until the job reaches a terminal status (wait=True by default).
job = trainer.train()

Parameters:
  • model – JumpStart model ID string, JumpStart hub content model ARN, or ModelPackage object.

  • agent_env – Bedrock AgentCore ARN, agent runtime ID, Lambda ARN, or CustomAgentLambda. When a bare agent runtime ID is provided (e.g. "myRuntime-aBcDeFgHiJ"), it is resolved to the full ARN via GetAgentRuntime.

  • training_dataset – S3 URI, DataSet object, or DataSet ARN string (optional). A dataset must be provided either to __init__ or to train().

  • mlflow_app_arn – MLflow app ARN or MlflowApp object (optional). If not specified, uses the default MLflow experience.

  • s3_output_path – S3 path for output artifacts (optional). If not specified, defaults to s3://sagemaker-<region>-<account>/output.

  • output_model_package_group – ModelPackageGroup object or ARN string for the output model (optional).

  • intermediate_checkpoint_model_package_group – ModelPackageGroup object or ARN string for intermediate checkpoints (optional). If not provided, a group named {model_name}-mtrl-checkpoint-mpg is created automatically. Must differ from output_model_package_group.

  • validation_dataset – S3 URI, DataSet object, or DataSet ARN string (optional).

  • bedrock_agentcore_qualifier – Bedrock AgentCore qualifier (default: "DEFAULT").

  • mlflow_experiment_name – MLflow experiment name (optional).

  • mlflow_run_name – MLflow run name (optional).

  • networking – VpcConfig for the job (optional).

  • kms_key_arn – KMS key ARN used to encrypt output (optional).

  • accept_eula – Set to True to accept the model's end-user license agreement (EULA) (optional, default False).

  • **kwargs – Passed to BaseTrainer (sagemaker_session, role, base_job_name, tags).
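Several of the parameter rules above (the accepted agent_env forms, the default s3_output_path, the auto-created checkpoint group name, and the requirement that the two model package groups differ) can be checked client-side before a job is submitted. The following is a minimal sketch under stated assumptions: every function name is hypothetical and not part of the SageMaker SDK, and the SDK's own validation may differ.

```python
import re
from typing import Optional

# Hypothetical pre-flight helpers; none of these names are part of the SageMaker SDK.
_ARN_RE = re.compile(
    r"^arn:(?P<partition>[^:]+):(?P<service>[^:]+):(?P<region>[^:]*):"
    r"(?P<account>[^:]*):(?P<resource>.+)$"
)


def classify_agent_env(agent_env: str) -> str:
    """Return 'agentcore', 'lambda', or 'runtime_id' for a string agent_env."""
    match = _ARN_RE.match(agent_env)
    if match is None:
        # Not an ARN: assume a bare agent runtime ID that must be resolved to an ARN.
        return "runtime_id"
    service = match.group("service")
    if service == "bedrock-agentcore":
        return "agentcore"
    if service == "lambda":
        return "lambda"
    raise ValueError(f"unsupported service in agent_env ARN: {service!r}")


def default_s3_output_path(region: str, account: str) -> str:
    # Mirrors the documented default s3://sagemaker-<region>-<account>/output.
    return f"s3://sagemaker-{region}-{account}/output"


def checkpoint_group_name(model_name: str) -> str:
    # Mirrors the documented auto-created name {model_name}-mtrl-checkpoint-mpg.
    return f"{model_name}-mtrl-checkpoint-mpg"


def _group_name(group: str) -> str:
    # Reduce a model package group ARN to its name so a name and an ARN compare equal.
    if group.startswith("arn:"):
        return group.split("model-package-group/", 1)[-1]
    return group


def check_groups_differ(output_group: Optional[str], checkpoint_group: str) -> None:
    # The checkpoint group must not be the same group as the output group.
    if output_group is not None and _group_name(output_group) == _group_name(checkpoint_group):
        raise ValueError(
            "intermediate_checkpoint_model_package_group must differ "
            "from output_model_package_group"
        )
```

Per the agent_env description, the trainer itself resolves a bare runtime ID to a full ARN via GetAgentRuntime; the sketch only decides which form was passed, so it makes no AWS calls.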

classmethod attach(job_name: str, session=None) → AgentRFTJob

Attach to an existing Agentic RFT job by name.

Parameters:
  • job_name – The name of the job.

  • session – Optional boto3 session.

Returns:

AgentRFTJob wrapping the existing job.

static list_bedrock_agentcore_runtimes(session=None) → list[dict]

List Bedrock AgentCore runtimes.

Parameters:

session – Optional boto3 session.

Returns:

List of dicts, each with keys name, runtime_id, arn, and status.
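Because the returned entries are plain dicts with name, runtime_id, arn, and status keys, picking an agent_env value from them is a simple filter. A minimal sketch: find_runtime_arn is a hypothetical helper, and treating "READY" as the usable status is an assumption about the status values AgentCore reports.

```python
from typing import Dict, List, Optional


def find_runtime_arn(
    runtimes: List[Dict[str, str]], name: str, ready_status: str = "READY"
) -> Optional[str]:
    """Return the ARN of the runtime called `name` if its status is ready_status, else None."""
    for runtime in runtimes:
        # Each entry is assumed to carry the documented keys: name, runtime_id, arn, status.
        if runtime["name"] == name and runtime["status"] == ready_status:
            return runtime["arn"]
    return None
```

For example, find_runtime_arn(MultiTurnRLTrainer.list_bedrock_agentcore_runtimes(), "myRuntime") would yield a value suitable for agent_env, or None if no matching runtime is in the assumed ready state.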

static list_supported_models(session=None) → list[str]

Return the list of supported model names.

Queries SageMakerPublicHub to discover all models whose RecipeCollection contains an MTRL (Multi-Turn RL) recipe.

Parameters:

session – Optional boto3 session.

Returns:

List of hub content model names supporting MTRL.
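The returned name list can back a client-side guard that fails fast, with suggestions, before any job is created. A sketch under assumptions: require_supported_model is a hypothetical helper, and it checks only plain model names, not hub content ARNs.

```python
import difflib
from typing import List


def require_supported_model(model_id: str, supported: List[str]) -> str:
    """Return model_id if it is supported; otherwise raise ValueError with close matches."""
    if model_id in supported:
        return model_id
    # Suggest up to three similar names to catch typos such as a wrong size suffix.
    suggestions = difflib.get_close_matches(model_id, supported, n=3)
    hint = f" Did you mean: {', '.join(suggestions)}?" if suggestions else ""
    raise ValueError(f"{model_id!r} is not in the supported model list.{hint}")
```

In use, supported would come from MultiTurnRLTrainer.list_supported_models(); passing the list in keeps the check testable without AWS access.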

property output_model_package_arn: str | None

The output model package ARN from the latest completed training job.
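Since this property is typed str | None, code that consumes it should handle None. The sketch below splits a versioned model package ARN, whose resource part has the form model-package/<group>/<version>, into group name and version number; split_model_package_arn is a hypothetical helper.

```python
from typing import Optional, Tuple


def split_model_package_arn(arn: Optional[str]) -> Optional[Tuple[str, int]]:
    """Return (group_name, version) for a versioned model package ARN, or None for None."""
    if arn is None:
        return None
    fields = arn.split(":", 5)  # arn, partition, service, region, account, resource
    if len(fields) != 6:
        raise ValueError(f"not an ARN: {arn}")
    parts = fields[5].split("/")  # e.g. ["model-package", "my-group", "3"]
    if len(parts) != 3 or parts[0] != "model-package":
        raise ValueError(f"not a versioned model package ARN: {arn}")
    return parts[1], int(parts[2])
```

For example, split_model_package_arn(trainer.output_model_package_arn) returns None until a job has produced an output model package ARN.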

train(training_dataset: str | DataSet | None = None, wait: bool = True) → AgentRFTJob

Launch an Agentic RFT job.

Parameters:
  • training_dataset – Training dataset override.

  • wait – If True (default), block until the job reaches a terminal status.

Returns:

AgentRFTJob instance for tracking the job.
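With wait=False, train() returns immediately and the caller decides how to wait. The polling loop below illustrates what blocking until a terminal status amounts to. It is a sketch under assumptions: wait_for_terminal is hypothetical, get_status stands in for whatever status accessor AgentRFTJob exposes (not documented on this page), and the terminal status names are assumed.

```python
import time
from typing import Callable, Iterable, Optional

# Assumed terminal status names; confirm against the statuses the job actually reports.
ASSUMED_TERMINAL_STATUSES = frozenset({"Completed", "Failed", "Stopped"})


def wait_for_terminal(
    get_status: Callable[[], str],
    poll_seconds: float = 30.0,
    timeout_seconds: Optional[float] = None,
    terminal: Iterable[str] = ASSUMED_TERMINAL_STATUSES,
) -> str:
    """Poll get_status() until it returns a terminal status, then return that status."""
    terminal_set = frozenset(terminal)
    deadline = None if timeout_seconds is None else time.monotonic() + timeout_seconds
    while True:
        status = get_status()
        if status in terminal_set:
            return status
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"job still {status!r} after {timeout_seconds} seconds")
        time.sleep(poll_seconds)
```

Injecting get_status as a callable keeps the loop independent of the job object's API and makes it testable with canned status sequences.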