"""MultiTurnRLEvaluator — evaluate MTRL agents on held-out prompts.
This module implements :class:`MultiTurnRLEvaluator`, the SDK surface for
evaluating Multi-Turn Reinforcement Learning (MTRL) agent models via the
AgentRFT ``CreateJob`` pipeline step. Mirrors the architecture of
:class:`sagemaker.train.evaluate.BenchMarkEvaluator`, with MTRL-specific
fields, validators, and the three-template rendering surface defined in
:mod:`sagemaker.train.evaluate.mtrl_pipeline_templates`.
"""
from __future__ import absolute_import
import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from pydantic import Field, root_validator, validator
from .base_evaluator import BaseEvaluator
from .constants import EvalType
from .mtrl_pipeline_templates import (
MTRL_TEMPLATE,
MTRL_TEMPLATE_BASE_MODEL_ONLY,
MTRL_TEMPLATE_FINE_TUNED_ONLY,
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
if TYPE_CHECKING:
from .execution import MTRLEvaluationExecution
_logger = logging.getLogger(__name__)
# Validation patterns.
# Bedrock AgentCore runtime ARN: an "aws*" partition, a region, a 12-digit
# account ID, and a resource path starting with "agent-runtime/" or "runtime/".
_BEDROCK_AGENTCORE_ARN_RE = re.compile(
r"^arn:aws[a-z\-]*:bedrock-agentcore:[a-z0-9\-]+:[0-9]{12}:(?:agent-runtime|runtime)/.+$"
)
# Lambda function ARN: an "aws*" partition, a region, a 12-digit account ID,
# and a non-empty "function:<name>" resource.
_LAMBDA_ARN_RE = re.compile(
r"^arn:aws[a-z\-]*:lambda:[a-z0-9\-]+:[0-9]{12}:function:.+$"
)
# Stopping-condition bounds (seconds): 0 < v <= 72 hours.
# 72 * 60 * 60 == 259200, the upper bound quoted in the field description.
_MAX_STOPPING_CONDITION_SECONDS = 72 * 60 * 60
[docs]
class MultiTurnRLEvaluator(BaseEvaluator):
"""Evaluate a multi-turn RL agent model against a held-out prompt dataset.
The evaluator runs rollouts of the agent against an environment
(Bedrock AgentCore runtime or a Lambda-wrapped agent) and computes
aggregate metrics (pass@k, mean reward, etc.). Execution routes through
SageMaker Pipelines using the new AgentRFT ``Job`` step type
(``JobCategory="AgentRFTEvaluation"``).
The evaluator supports three evaluation shapes, selected automatically
based on the provided inputs:
* **Base-model only** — pass a base model (JumpStart ID or ModelPackage)
with an explicit ``agent_config``.
* **Fine-tuned only** — pass a ``MultiTurnRLTrainer`` or a
fine-tuned ``ModelPackage``; the evaluator extracts the source model
package ARN and evaluates it only.
* **Base + fine-tuned comparison** — pass ``evaluate_base_model=True``
along with a fine-tuned trainer / ModelPackage; both runs land in
the same MLflow experiment for side-by-side comparison.
Attributes:
dataset (Union[str, Any]): Prompt dataset — S3 URI, hub-content
DataSet ARN, or object exposing an ``.arn`` attribute.
Required.
agent_config (Optional[Union[str, Any]]): Agent environment —
Bedrock AgentCore ARN or Lambda ARN. Auto-resolved from a
``MultiTurnRLTrainer`` when provided as ``model``.
agent_qualifier (Optional[str]): Bedrock AgentCore qualifier
(e.g. ``"PROD"``). Ignored when ``agent_config`` is a Lambda.
accept_eula (bool): Forwarded to
``JobConfigDocument.EvaluationConfig.AcceptEula``. Defaults
to ``True`` (templates emit ``true`` unconditionally; flag
kept for future backend schemas).
evaluate_base_model (bool): When ``True`` and a fine-tuned model is
present, render the comparison template (both base and
fine-tuned are evaluated). Defaults to ``False`` — fine-tuned
only.
stopping_condition (int): Maximum job duration in seconds. Default
``86400`` (24 hours); must be in ``(0, 259200]``.
tags (Optional[List[Dict[str, str]]]): Customer tags propagated to
the pipeline + step ``Tags`` list.
See :class:`BaseEvaluator` for inherited fields (``model``,
``s3_output_path``, ``mlflow_resource_arn``,
``mlflow_experiment_name``, ``networking``, ``kms_key_id``,
``model_package_group``, ``base_eval_name``, ``region``, ``role``,
``sagemaker_session``).
Example:
.. code:: python
from sagemaker.train.evaluate import MultiTurnRLEvaluator
# Evaluate a fine-tuned MTRL trainer output
evaluator = MultiTurnRLEvaluator(
model=completed_mtrl_trainer,
dataset='s3://my-bucket/eval-prompts.jsonl',
s3_output_path='s3://my-bucket/mtrl-eval-output/',
)
execution = evaluator.evaluate()
execution.wait()
execution.show_results()
"""
# --- Declared fields -------------------------------------------------
# Required; normalized by the `_resolve_dataset` validator.
dataset: Any = Field(..., description="Prompt dataset (S3 URI, ARN, or object with .arn).")
# ARN string or adapter object; strings are regex-checked by
# `_resolve_agent_config`, objects are resolved later in `_resolve_agent_arn`.
agent_config: Optional[Any] = Field(default=None, description="Agent environment.")
# Only emitted into the job config when the agent is a Bedrock AgentCore runtime.
agent_qualifier: Optional[str] = Field(default=None, description="Bedrock AgentCore qualifier.")
# NOTE(review): the job config built in `_build_template_context` hard-codes
# AcceptEula=True; this flag is not read there.
accept_eula: bool = Field(default=True, description="Accept EULA for the base model.")
# Consulted by `_select_mtrl_template` only when a fine-tuned model package is present.
evaluate_base_model: bool = Field(
default=False,
description="When True, render the base + fine-tuned comparison template.",
)
# Range-checked by `_validate_stopping_condition`.
stopping_condition: int = Field(
default=86400,
description="Maximum job duration in seconds; must be in (0, 259200].",
)
# Passed through to the template context under the "tags" key.
tags: Optional[List[Dict[str, str]]] = Field(default=None, description="Customer tags.")
# Private instance state (populated during resolution).
# The three *_cache values are filled by `_resolve_trainer_defaults` and,
# when still empty, back-filled in `evaluate()` from BaseEvaluator attributes.
_base_model_arn_cache: Optional[str] = None
_base_model_name_cache: Optional[str] = None
_source_model_package_arn_cache: Optional[str] = None
# Set together by `_resolve_agent_arn`.
_agent_arn_resolved: Optional[str] = None
_agent_kind: Optional[str] = None # "bedrock" | "lambda"
# Memoized result of the `hyperparameters` property.
_hyperparameters: Optional[Any] = None
# --- Validators ------------------------------------------------------
@validator("dataset", pre=True, always=True)
def _resolve_dataset(cls, v):
    """Require a dataset and hand it to ``BaseEvaluator`` for normalization.

    Args:
        v: Raw ``dataset`` value supplied by the caller.

    Returns:
        Whatever ``BaseEvaluator._validate_and_resolve_dataset`` returns for ``v``.

    Raises:
        ValueError: If ``v`` is ``None``.
    """
    if v is not None:
        return BaseEvaluator._validate_and_resolve_dataset(v)
    raise ValueError(
        "[PySDK Error] 'dataset' is required. Accepted: S3 URI "
        "(s3://...), hub-content DataSet ARN, or an object with an "
        "`.arn` attribute."
    )
@validator("agent_config", pre=True, always=True)
def _resolve_agent_config(cls, v):
    """Validate string agent ARNs; let everything else through untouched.

    Args:
        v: Raw ``agent_config`` value (``None``, an ARN string, or an object).

    Returns:
        ``v`` unchanged when it is ``None``, a non-string object, or a string
        matching the Bedrock AgentCore or Lambda ARN pattern.

    Raises:
        ValueError: If ``v`` is a string that matches neither ARN pattern.
    """
    # ``None`` and adapter-style objects are not inspected here; objects are
    # turned into an ARN later, at evaluate() time.
    if not isinstance(v, str):
        return v
    looks_like_bedrock = _BEDROCK_AGENTCORE_ARN_RE.match(v)
    if not looks_like_bedrock and not _LAMBDA_ARN_RE.match(v):
        raise ValueError(
            f"[PySDK Error] 'agent_config' value '{v}' is not a recognized "
            f"Bedrock AgentCore ARN or Lambda ARN."
        )
    return v
@validator("stopping_condition", always=True)
def _validate_stopping_condition(cls, v):
    """Enforce ``0 < stopping_condition <= _MAX_STOPPING_CONDITION_SECONDS``.

    Args:
        v: Requested maximum job duration in seconds.

    Returns:
        ``v`` when it is in range, or ``86400`` when ``v`` is ``None``.

    Raises:
        ValueError: If ``v`` is non-positive or exceeds the 72-hour cap.
    """
    if v is None:
        return 86400
    if 0 < v <= _MAX_STOPPING_CONDITION_SECONDS:
        return v
    if v <= 0:
        raise ValueError(
            f"[PySDK Error] 'stopping_condition' must be > 0; got {v}."
        )
    # Only the "too large" case remains.
    raise ValueError(
        f"[PySDK Error] 'stopping_condition' must be <= "
        f"{_MAX_STOPPING_CONDITION_SECONDS} seconds (72 hours); "
        f"got {v}."
    )
@root_validator(skip_on_failure=True)
def _check_agent_config_for_non_trainer_models(cls, values):
    """Insist on an explicit ``agent_config`` unless ``model`` is a trainer/job.

    An ``agent_config`` may be omitted only when ``model`` is an instance of a
    class named ``MultiTurnRLTrainer`` or ``AgentRFTJob``; in that case the
    agent is looked up from the trainer later. Every other model value
    (including ``None``) needs an explicit ``agent_config``.

    Args:
        values: Field values validated so far.

    Returns:
        ``values`` unchanged.

    Raises:
        ValueError: If ``agent_config`` is missing and ``model`` is not a
            trainer/job instance.
    """
    if values.get("agent_config") is not None:
        return values
    model = values.get("model")
    # Compare by class name so this module never has to import the trainer
    # (which would create an import cycle).
    model_cls_name = "" if model is None else type(model).__name__
    if model_cls_name in ("MultiTurnRLTrainer", "AgentRFTJob"):
        return values
    raise ValueError(
        "[PySDK Error] 'agent_config' is required when 'model' is "
        "not a MultiTurnRLTrainer. Provide a Bedrock AgentCore ARN "
        "or a Lambda ARN."
    )
# --- Trainer / model resolution -------------------------------------
def _resolve_trainer_defaults(self) -> None:
"""Pull base/source ARNs and agent defaults from a MultiTurnRLTrainer.
Idempotent; re-reading a trainer yields the same values. Customer-
provided ``agent_config`` / ``agent_qualifier`` always win over
trainer-sourced values.
"""
if type(self.model).__name__ not in ("MultiTurnRLTrainer", "AgentRFTJob"):
return
trainer = self.model
# Resolve the output model package ARN from the completed job.
# MultiTurnRLTrainer stores the job in _latest_job (AgentRFTJob),
# which exposes output_model_package_arn as a property.
source_mp = (
getattr(trainer, "output_model_package_arn", None)
or getattr(trainer, "model_package_arn", None)
)
if not source_mp and hasattr(trainer, "_latest_job") and trainer._latest_job is not None:
source_mp = getattr(trainer._latest_job, "output_model_package_arn", None)
if not source_mp:
raise ValueError(
"[PySDK Error] The provided MultiTurnRLTrainer has no "
"completed training job (output model package ARN is "
"unavailable). Run trainer.wait() and retry."
)
self._source_model_package_arn_cache = source_mp
self._base_model_arn_cache = (
getattr(trainer, "base_model_arn", None)
or getattr(trainer, "_base_model_arn", None)
or getattr(trainer, "_model_arn", None)
)
self._base_model_name_cache = (
getattr(trainer, "base_model_name", None)
or getattr(trainer, "_base_model_name", None)
or getattr(trainer, "_model_name", None)
)
# Customer values win.
if self.agent_config is None:
resolved_agent = (
getattr(trainer, "agent_config", None)
or getattr(trainer, "_agent_config", None)
or getattr(trainer, "agent_env", None)
)
# AgentRFTJob.agent_config returns a dict like {"AgentRuntimeArn": "..."}
if isinstance(resolved_agent, dict):
self.agent_config = (
resolved_agent.get("AgentRuntimeArn")
or resolved_agent.get("AgentLambdaArn")
or resolved_agent.get("LambdaArn")
)
else:
self.agent_config = resolved_agent
if self.agent_qualifier is None:
self.agent_qualifier = (
getattr(trainer, "agent_qualifier", None)
or getattr(trainer, "_agent_qualifier", None)
or getattr(trainer, "bedrock_agentcore_qualifier", None)
)
# --- Hyperparameters property ---------------------------------------
@property
@_telemetry_emitter(
    feature=Feature.MODEL_CUSTOMIZATION,
    func_name="MultiTurnRLEvaluator.hyperparameters",
)
def hyperparameters(self):
    """Return the evaluation hyperparameters, fetching them on first access.

    The first call asks the ``SageMakerPublicHub`` for the
    ``AgentRFTEvaluation`` override parameters of the base model, wraps the
    extracted spec in a ``FineTuningOptions`` object and memoizes it in
    ``_hyperparameters``; later calls return the memoized object.

    Keys the evaluator forwards to the job when present:
    ``eval_group_size``, ``sampling_temperature``, ``top_p``,
    ``max_tokens``, ``pass_k_values``, ``success_threshold``.

    Returns:
        FineTuningOptions: Options built from the hub override spec.

    Raises:
        ValueError: If no base model name is known, or the hub returns no
            ``AgentRFTEvaluation`` overrides for that model.
    """
    cached = self._hyperparameters
    if cached is not None:
        return cached

    # Imported lazily, only when a hub lookup is actually needed.
    from ..common import FineTuningOptions
    from ..common_utils.recipe_utils import (
        _extract_eval_override_options,
        _get_evaluation_override_params,
    )

    model_name = self._base_model_name_cache or self._base_model_name
    if not model_name:
        raise ValueError(
            "[PySDK Error] Cannot load MTRL hyperparameters: base "
            "model name not available. Ensure `model` resolves to a "
            "JumpStart / hub-backed base model."
        )

    # A SageMaker session wraps a boto session; anything without a
    # ``boto_session`` attribute is passed through as-is.
    sm_session = self.sagemaker_session
    boto_session = getattr(sm_session, "boto_session", sm_session)

    overrides = _get_evaluation_override_params(
        hub_content_name=model_name,
        hub_name="SageMakerPublicHub",
        evaluation_type="AgentRFTEvaluation",
        region=self.region,
        session=boto_session,
    )
    if not overrides:
        raise ValueError(
            f"[PySDK Error] Base model '{model_name}' does not "
            f"expose AgentRFTEvaluation hyperparameter overrides in the "
            f"JumpStart hub."
        )

    self._hyperparameters = FineTuningOptions(
        _extract_eval_override_options(overrides, return_full_spec=True)
    )
    return self._hyperparameters
# --- Helpers ---------------------------------------------------------
def _resolve_agent_arn(self) -> None:
"""Resolve ``agent_config`` to a concrete ARN string + kind.
* String ARN: classify as ``bedrock`` or ``lambda`` by regex.
* ``AgentLambdaAdapter``-like object: call ``.materialize()`` which
returns a Lambda ARN string. Gated — returns a clear error if no
``.materialize()`` is available.
"""
if self.agent_config is None:
self._agent_arn_resolved = None
self._agent_kind = None
return
if isinstance(self.agent_config, str):
arn = self.agent_config
else:
materialize = getattr(self.agent_config, "materialize", None)
if callable(materialize):
arn = materialize()
else:
arn = (
getattr(self.agent_config, "lambda_arn", None)
or getattr(self.agent_config, "arn", None)
)
if not isinstance(arn, str):
raise ValueError(
"[PySDK Error] Could not resolve agent_config to an ARN. "
"Pass a Bedrock AgentCore ARN string, a Lambda ARN "
"string, or an object exposing `.lambda_arn` or `.materialize()`."
)
if _BEDROCK_AGENTCORE_ARN_RE.match(arn):
self._agent_kind = "bedrock"
elif _LAMBDA_ARN_RE.match(arn):
self._agent_kind = "lambda"
else:
raise ValueError(
f"[PySDK Error] Resolved agent ARN '{arn}' is neither a "
f"Bedrock AgentCore ARN nor a Lambda ARN."
)
self._agent_arn_resolved = arn
def _select_mtrl_template(self) -> str:
    """Return the pipeline template matching the current evaluation shape.

    * No fine-tuned model package cached: the base-model-only template.
    * Fine-tuned package and ``evaluate_base_model`` set: the comparison template.
    * Fine-tuned package only: the fine-tuned-only template.
    """
    if not self._source_model_package_arn_cache:
        return MTRL_TEMPLATE_BASE_MODEL_ONLY
    return MTRL_TEMPLATE if self.evaluate_base_model else MTRL_TEMPLATE_FINE_TUNED_ONLY
def _build_template_context(
self,
aws_context: Dict[str, str],
artifacts: Dict[str, str],
model_package_group_arn: Optional[str],
) -> Dict[str, Any]:
"""Assemble the rendering context expected by the MTRL templates."""
import json as _json
hparams: Dict[str, Any] = {}
try:
hparams = self.hyperparameters.to_dict() if self._base_model_name_cache else {}
except Exception as e: # hub fetch can fail in offline envs
_logger.info(f"Skipping hub-sourced hyperparameters: {e}")
def _str_or_none(v):
return str(v) if v is not None else None
action_arn_prefix = (
f"arn:aws:sagemaker:{aws_context['region']}:{aws_context['account_id']}:action"
)
networking = getattr(self, "networking", None)
vpc_security_group_ids: List[str] = []
vpc_subnets: List[str] = []
if networking is not None:
vpc_security_group_ids = list(getattr(networking, "security_group_ids", []) or [])
vpc_subnets = list(getattr(networking, "subnets", []) or [])
base_model_arn = (
self._base_model_arn_cache
or self._base_model_arn
or artifacts.get("base_model_arn")
)
# --- Build JobConfigDocument as a dict, then json.dumps() it ----
# The Pipelines service expects JobConfigDocument as a JSON string
# (double-encoded), not a nested object.
def _build_job_config_doc(include_mpc: bool, mlflow_run_name: str) -> str:
agent_arn = self._agent_arn_resolved
# AgentConfig
if self._agent_kind == "bedrock":
agent_cfg = {"BedrockAgentCoreConfig": {"AgentRuntimeArn": agent_arn}}
if self.agent_qualifier:
agent_cfg["BedrockAgentCoreConfig"]["Qualifier"] = self.agent_qualifier
else:
agent_cfg = {"CustomAgentLambdaConfig": {"LambdaArn": agent_arn}}
# InputDataConfig
ds = self.dataset
if ds.startswith("arn:") and "hub-content" in ds and "/DataSet/" in ds:
data_source = {"DatasetSource": {"DatasetArn": ds}}
else:
data_source = {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": ds}}
input_data = [{"ChannelName": "evaluation", "DataSource": data_source}]
# OutputDataConfig
output_data: Dict[str, Any] = {"S3OutputPath": self.s3_output_path}
if getattr(self, "kms_key_id", None):
output_data["KmsKeyArn"] = self.kms_key_id
if self.mlflow_resource_arn:
mlflow_cfg: Dict[str, Any] = {"MlflowResourceArn": self.mlflow_resource_arn}
exp_name = (
getattr(self, "mlflow_experiment_name", None)
or f"mtrl-eval-{self._base_model_name_cache or 'default'}"
)
mlflow_cfg["MlflowExperimentName"] = exp_name
mlflow_cfg["MlflowRunName"] = mlflow_run_name
output_data["MlflowConfig"] = mlflow_cfg
eval_cfg: Dict[str, Any] = {
"BaseModelArn": base_model_arn,
"AcceptEula": True,
}
hp: Dict[str, str] = {}
for k in ("eval_group_size", "sampling_temperature", "top_p",
"max_tokens", "pass_k_values", "success_threshold"):
v = hparams.get(k)
if v is not None:
hp[k] = str(v)
if hp:
eval_cfg["HyperParameters"] = hp
doc: Dict[str, Any] = {
"AgentConfig": agent_cfg,
"InputDataConfig": input_data,
"OutputDataConfig": output_data,
"EvaluationConfig": eval_cfg,
}
# ModelPackageConfig (fine-tuned only)
if include_mpc and self._source_model_package_arn_cache:
mpc: Dict[str, str] = {
"InputModelPackageArn": self._source_model_package_arn_cache,
}
doc["ModelPackageConfig"] = mpc
# StoppingCondition
doc["StoppingCondition"] = {
"MaxRuntimeInSeconds": self.stopping_condition,
}
return _json.dumps(doc)
# Build both variants (base-only and fine-tuned).
job_config_doc_str = _build_job_config_doc(include_mpc=False, mlflow_run_name="base-model-eval")
job_config_doc_ft_str = _build_job_config_doc(include_mpc=True, mlflow_run_name="fine-tuned-model-eval")
return {
"pipeline_name": aws_context.get("pipeline_name")
or artifacts.get("pipeline_name")
or f"SagemakerEvaluation-MTRLEvaluation",
"role_arn": aws_context["role_arn"],
"base_model_arn": base_model_arn,
"agent_arn": self._agent_arn_resolved,
"agent_qualifier": self.agent_qualifier,
"dataset_uri": self.dataset,
"s3_output_path": self.s3_output_path,
"mlflow_resource_arn": self.mlflow_resource_arn,
"mlflow_experiment_name": getattr(self, "mlflow_experiment_name", None)
or aws_context.get("pipeline_name"),
"eval_group_size": _str_or_none(hparams.get("eval_group_size")),
"sampling_temperature": _str_or_none(hparams.get("sampling_temperature")),
"top_p": _str_or_none(hparams.get("top_p")),
"max_tokens": _str_or_none(hparams.get("max_tokens")),
"pass_k_values": _str_or_none(hparams.get("pass_k_values")),
"success_threshold": _str_or_none(hparams.get("success_threshold")),
"stopping_condition": self.stopping_condition,
"model_package_group_arn": model_package_group_arn,
"source_model_package_arn": self._source_model_package_arn_cache,
"action_arn_prefix": action_arn_prefix,
"dataset_artifact_arn": None,
"kms_key_arn": getattr(self, "kms_key_id", None),
"vpc_config": bool(networking),
"vpc_security_group_ids": vpc_security_group_ids,
"vpc_subnets": vpc_subnets,
"tags": self.tags,
# Pre-stringified JobConfigDocument for the templates.
"job_config_document_str": job_config_doc_str,
"job_config_document_ft_str": job_config_doc_ft_str,
}
# --- Public entry points --------------------------------------------
[docs]
@_telemetry_emitter(
    feature=Feature.MODEL_CUSTOMIZATION,
    func_name="MultiTurnRLEvaluator.evaluate",
)
def evaluate(self) -> 'MTRLEvaluationExecution':
    """Render the MTRL pipeline and start a non-blocking execution.

    Steps: resolve trainer-sourced defaults, resolve the agent ARN, gather the
    AWS context and model artifacts, build the template context, render the
    selected template, and start the pipeline execution.

    As a debugging aid the rendered pipeline definition is also written to
    ``mtrl_eval_pipeline_input.json`` in the current working directory. That
    write is best-effort: a failure to write the file is logged and does not
    stop the evaluation.

    Returns:
        MTRLEvaluationExecution: The started pipeline execution.
        Call ``.wait()`` to block until completion and ``.show_results()``
        to render the aggregate report.

    Raises:
        ValueError: If no agent ARN can be resolved, or the rendered
            pipeline definition is not valid JSON.

    Example:
        .. code:: python

            execution = evaluator.evaluate()
            execution.wait()
            execution.show_results()
    """
    import json as _json_mod

    # 1. Trainer-sourced resolution (no-op if model is not a trainer).
    self._resolve_trainer_defaults()

    # 2. Resolve agent ARN.
    self._resolve_agent_arn()
    if not self._agent_arn_resolved:
        raise ValueError(
            "[PySDK Error] 'agent_config' resolved to None. A valid agent "
            "ARN is required for evaluation. Provide either:\n"
            " - A Bedrock AgentCore ARN: arn:aws:bedrock-agentcore:<region>:<account>:runtime/<id>\n"
            " - A Lambda ARN: arn:aws:lambda:<region>:<account>:function:<name>"
        )

    # 3. AWS context + model artifacts (reuses BaseEvaluator plumbing).
    aws_context = self._get_aws_execution_context()
    artifacts = self._resolve_model_artifacts(aws_context["region"])
    # Anything the trainer did not supply falls back to BaseEvaluator state.
    if not self._base_model_arn_cache:
        self._base_model_arn_cache = self._base_model_arn
    if not self._base_model_name_cache:
        self._base_model_name_cache = self._base_model_name
    if not self._source_model_package_arn_cache:
        self._source_model_package_arn_cache = self._source_model_package_arn
    model_package_group_arn = self._get_model_package_group_arn()

    # 4. Template context.
    template_context = self._build_template_context(
        aws_context=aws_context,
        artifacts=artifacts,
        model_package_group_arn=model_package_group_arn,
    )

    # 5. Template selection + render.
    template_str = self._select_mtrl_template()
    pipeline_definition = self._render_pipeline_definition(template_str, template_context)

    # Dump the pipeline definition to a local JSON file for debugging.
    # Parsing stays outside the try so malformed JSON still fails loudly here;
    # only the file write is best-effort, because the working directory may
    # be read-only and a debug aid must not abort the evaluation.
    _debug_path = "mtrl_eval_pipeline_input.json"
    parsed_definition = _json_mod.loads(pipeline_definition)
    try:
        with open(_debug_path, "w", encoding="utf-8") as _f:
            _json_mod.dump(parsed_definition, _f, indent=2)
    except OSError as e:
        _logger.warning(
            "Could not write pipeline definition to %s (continuing): %s",
            _debug_path,
            e,
        )
    else:
        _logger.info(f"Pipeline definition written to {_debug_path}")

    # 6. Start execution via custom boto3 path. The MTRL pipeline uses the
    # "Job" step type which requires the beta endpoint for CreatePipeline.
    # We still tag the pipeline for discoverability via get_all().
    name = self.base_eval_name or f"mtrl-eval-{(self._base_model_name_cache or 'model')}"
    return self._start_mtrl_execution(
        pipeline_definition=pipeline_definition,
        name=name,
        role_arn=aws_context["role_arn"],
        region=aws_context["region"],
    )
def _get_mlflow_presigned_url(self, region: str, sm_client=None) -> Optional[str]:
"""Generate a presigned MLflow tracking server URL for the user.
Uses the provided sm_client to call
create_presigned_mlflow_app_url. Falls back to a console deep-link if
the presigned URL call fails.
"""
if not self.mlflow_resource_arn:
return None
eval_experiment_name = (
getattr(self, "mlflow_experiment_name", None)
or f"mtrl-eval-{self._base_model_name_cache or 'default'}"
)
base_url = None
# Try presigned URL via the provided client first (respects beta endpoint).
if sm_client is not None:
try:
response = sm_client.create_presigned_mlflow_app_url(
Arn=self.mlflow_resource_arn
)
base_url = response.get("AuthorizedUrl")
except Exception as e:
_logger.debug(f"Presigned MLflow URL via sm_client failed: {e}")
# Fallback: get presigned URL via default SageMakerClient
if not base_url:
try:
from sagemaker.core.utils.utils import SageMakerClient
client = SageMakerClient().sagemaker_client
response = client.create_presigned_mlflow_app_url(
Arn=self.mlflow_resource_arn
)
base_url = response.get("AuthorizedUrl")
except Exception as e:
_logger.debug(f"Presigned MLflow URL via SageMakerClient failed: {e}")
if not base_url:
return None
# Build deep link with experiment name in the URL
# We can't resolve experiment name → ID without an authenticated MLflow session,
# so we use the experiment name directly in the search filter deep link
from sagemaker.train.common_utils.mlflow_url_utils import _build_mlflow_deep_link_by_name
return _build_mlflow_deep_link_by_name(base_url, eval_experiment_name)
def _start_mtrl_execution(self, pipeline_definition, name, role_arn, region):
    """Create or update the MTRL pipeline, start an execution, and report it.

    The pipeline is looked up by the MTRL name prefix. The first match is
    updated in place; if there is none (or the lookup fails) a new pipeline
    is created and tagged so executions are discoverable via ``get_all()``.
    A summary of the job is printed to stdout.

    Args:
        pipeline_definition: Rendered pipeline definition (JSON string).
        name: Base name for the execution; a Unix timestamp is appended for
            the execution display name.
        role_arn: IAM role ARN the pipeline runs with.
        region: AWS region for the SageMaker client.

    Returns:
        MTRLEvaluationExecution: Handle for the started execution, carrying
        the MLflow URL, resource ARN and experiment name.
    """
    import time
    import uuid
    import boto3
    from sagemaker.core.resources import PipelineExecution
    from .execution import MTRLEvaluationExecution, PipelineExecutionStatus
    from .constants import _get_pipeline_name_prefix, _TAG_SAGEMAKER_MODEL_EVALUATION

    # NOTE(review): comments elsewhere mention a beta endpoint for the "Job"
    # step type, but this client uses the default endpoint — confirm.
    sm_client = boto3.client("sagemaker", region_name=region)
    pipeline_prefix = _get_pipeline_name_prefix(EvalType.MTRL)
    pipeline_name = pipeline_prefix

    # Search for existing MTRL pipeline; the first match wins.
    existing_pipeline_name = None
    try:
        resp = sm_client.list_pipelines(PipelineNamePrefix=pipeline_prefix)
        summaries = resp.get("PipelineSummaries", [])
        if summaries:
            existing_pipeline_name = summaries[0]["PipelineName"]
    except Exception as e:
        # Best-effort lookup: fall through to creation, but leave a trace
        # instead of hiding the failure.
        _logger.debug(
            "Could not list pipelines with prefix %s; will create one: %s",
            pipeline_prefix,
            e,
        )

    if existing_pipeline_name:
        pipeline_name = existing_pipeline_name
        sm_client.update_pipeline(
            PipelineName=pipeline_name,
            PipelineDefinition=pipeline_definition,
            RoleArn=role_arn,
        )
        _logger.info(f"Updated existing pipeline: {pipeline_name}")
    else:
        sm_client.create_pipeline(
            PipelineName=pipeline_name,
            PipelineDefinition=pipeline_definition,
            RoleArn=role_arn,
            PipelineDisplayName=pipeline_name,
            PipelineDescription="MTRL evaluation pipeline",
            ClientRequestToken=str(uuid.uuid4()),
            Tags=[{"Key": _TAG_SAGEMAKER_MODEL_EVALUATION, "Value": "true"}],
        )
        _logger.info(f"Created pipeline: {pipeline_name}")

    # Start execution
    resp = sm_client.start_pipeline_execution(
        PipelineName=pipeline_name,
        PipelineExecutionDisplayName=f"{name}-{int(time.time())}",
        ClientRequestToken=str(uuid.uuid4()),
    )
    exec_arn = resp["PipelineExecutionArn"]
    _logger.info(f"Started MTRL pipeline execution: {exec_arn}")

    # Build execution object using the shared subclass
    execution = MTRLEvaluationExecution(
        name=name,
        arn=exec_arn,
        eval_type=EvalType.MTRL,
        s3_output_path=self.s3_output_path,
        status=PipelineExecutionStatus(overall_status="Executing"),
    )

    # Store the pipeline execution reference for wait/refresh
    try:
        pe = PipelineExecution.get(pipeline_execution_arn=exec_arn, region=region)
        execution._pipeline_execution = pe
    except Exception as e:
        _logger.debug(f"Could not fetch PipelineExecution for wait/refresh: {e}")

    # Print job summary and MLflow URL for the user.
    mlflow_url = self._get_mlflow_presigned_url(region, sm_client=sm_client)

    # `_select_mtrl_template` returns the template text itself, so the mode is
    # found by comparing against the template constants, not by searching the
    # text for a constant's name.
    selected_template = self._select_mtrl_template()
    if selected_template == MTRL_TEMPLATE_BASE_MODEL_ONLY:
        eval_mode = "Base model only"
    elif selected_template == MTRL_TEMPLATE_FINE_TUNED_ONLY:
        eval_mode = "Fine-tuned model only"
    else:
        eval_mode = "Base + Fine-tuned comparison"

    print(f"\n{'─' * 60}")
    print(f" MTRL Evaluation Job")
    print(f"{'─' * 60}")
    print(f" Model : {self._base_model_name_cache or self.model}")
    print(f" Eval mode : {eval_mode}")
    print(f" Dataset : {self.dataset}")
    print(f" Agent : {self._agent_arn_resolved or self.agent_config}")
    print(f" Output : {self.s3_output_path}")
    print(f" Region : {region}")
    if mlflow_url:
        print(f" MLflow URL : {mlflow_url}")
    else:
        print(f" MLflow ARN : {self.mlflow_resource_arn}")
    print(f"{'─' * 60}")
    print(f" Pipeline execution started: {exec_arn}\n")

    # Store MLflow URL and config on execution for later access.
    execution.mlflow_url = mlflow_url
    execution.mlflow_resource_arn = self.mlflow_resource_arn
    execution.mlflow_experiment_name = (
        getattr(self, "mlflow_experiment_name", None)
        or f"mtrl-eval-{self._base_model_name_cache or 'default'}"
    )
    return execution
[docs]
@classmethod
@_telemetry_emitter(
    feature=Feature.MODEL_CUSTOMIZATION,
    func_name="MultiTurnRLEvaluator.get_all",
)
def get_all(cls, session=None, region=None):
    """Iterate over MTRL evaluation executions.

    Delegates to ``EvaluationPipelineExecution.get_all`` with
    ``eval_type=EvalType.MTRL``.

    Args:
        session: Optional boto3 session, forwarded as-is.
        region: Optional AWS region, forwarded as-is.

    Yields:
        EvaluationPipelineExecution: MTRL evaluation execution instances.
    """
    from .execution import EvaluationPipelineExecution

    mtrl_executions = EvaluationPipelineExecution.get_all(
        eval_type=EvalType.MTRL,
        session=session,
        region=region,
    )
    yield from mtrl_executions
[docs]
@staticmethod
@_telemetry_emitter(
    feature=Feature.MODEL_CUSTOMIZATION,
    func_name="MultiTurnRLEvaluator.list_supported_models",
)
def list_supported_models(session=None) -> list:
    """Return the hub models that carry an MTRL recipe.

    Delegates to ``_list_hub_models_by_recipe`` with
    ``recipe_type="FineTuning"`` and ``technique="MTRL"``.

    Args:
        session: Optional boto3 session, forwarded as-is.

    Returns:
        list: Whatever ``_list_hub_models_by_recipe`` returns — presumably
        hub content model names; verify against that helper.
    """
    from sagemaker.train.common_utils.recipe_utils import _list_hub_models_by_recipe

    supported = _list_hub_models_by_recipe(
        recipe_type="FineTuning",
        technique="MTRL",
        session=session,
    )
    return supported
[docs]
@staticmethod
@_telemetry_emitter(
    feature=Feature.MODEL_CUSTOMIZATION,
    func_name="MultiTurnRLEvaluator.list_bedrock_agentcore_runtimes",
)
def list_bedrock_agentcore_runtimes(session=None) -> list:
    """List Bedrock AgentCore runtimes, following pagination to the end.

    Args:
        session: Optional boto3 session; a default ``boto3.Session()`` is
            created when omitted.

    Returns:
        list: One dict per runtime with keys ``name``, ``runtime_id``,
        ``arn`` and ``status``. ``arn`` is required in each response entry;
        the other three default to ``""`` when missing.
    """
    import boto3

    boto_session = session or boto3.Session()
    client = boto_session.client("bedrock-agentcore-control")

    runtimes: list = []
    request: dict = {}  # empty for the first page, then carries nextToken
    while True:
        page = client.list_agent_runtimes(**request)
        runtimes.extend(
            {
                "name": rt.get("agentRuntimeName", ""),
                "runtime_id": rt.get("agentRuntimeId", ""),
                "arn": rt["agentRuntimeArn"],
                "status": rt.get("status", ""),
            }
            for rt in page.get("agentRuntimes", [])
        )
        token = page.get("nextToken")
        if not token:
            return runtimes
        request = {"nextToken": token}