Source code for sagemaker.train.agent_rft_job

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""AgentRFTJob — wrapper around sagemaker-core Job for AgentRFT job category."""
from __future__ import annotations

import json
import logging
from typing import Optional

from sagemaker.core.resources import Job
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature

# Module-scoped logger; used for warnings such as missing MLflow configuration.
logger: logging.Logger = logging.getLogger(__name__)

# Job category string passed to Job.get / Job.get_all so only AgentRFT jobs match.
JOB_CATEGORY: str = "AgentRFT"
# Job statuses treated as final (the name implies terminal states; not
# referenced elsewhere in this view — confirm usage in the rest of the package).
TERMINAL_STATUSES: tuple[str, ...] = (
    "Completed",
    "Failed",
    "Stopped",
)


class AgentRFTJob:
    """Wrapper around sagemaker-core Job for AgentRFT job category.

    Delegates lifecycle methods to the underlying Job and adds typed
    convenience properties by parsing the JobConfigDocument JSON string.

    Args:
        job: The sagemaker-core Job instance to wrap.
    """

    JOB_CATEGORY = JOB_CATEGORY

    def __init__(self, job: Job):
        self._job = job
        # Parsed JobConfigDocument; ``None`` means "not parsed yet". Cleared by
        # refresh() and wait() so the parsed properties re-read the document.
        self._cached_config: dict | None = None
        # Forwarded to the wait helper as its ``description`` argument.
        self.description: str | None = None
        # Not read by this class; presumably set by callers — verify before removing.
        self.sagemaker_session = None

    @classmethod
    def from_job(cls, job: Job) -> AgentRFTJob:
        """Create an AgentRFTJob from a sagemaker-core Job instance."""
        return cls(job)

    # --- Delegated properties ---

    @property
    def job_name(self) -> str:
        return self._job.job_name

    @property
    def job_arn(self) -> str:
        return self._job.job_arn

    @property
    def job_status(self) -> str:
        return self._job.job_status

    @property
    def secondary_status(self) -> str:
        return self._job.secondary_status

    @property
    def secondary_status_transitions(self) -> list:
        return self._job.secondary_status_transitions

    @property
    def failure_reason(self) -> str | None:
        return self._job.failure_reason

    @property
    def creation_time(self):
        return self._job.creation_time

    @property
    def last_modified_time(self):
        return self._job.last_modified_time

    @property
    def end_time(self):
        return self._job.end_time

    # --- Delegated lifecycle methods ---

    def refresh(self):
        """Refresh job state from DescribeJob API and drop the parsed config cache."""
        self._job.refresh()
        self._cached_config = None

    def wait(self, poll: int = 5, timeout: Optional[int] = 3000, max_log_lines: int = 20):
        """Wait for job to reach terminal status.

        Args:
            poll: Seconds between polls.
            timeout: Maximum seconds to wait.
            max_log_lines: Maximum number of log lines to display. Defaults to 20.
        """
        from sagemaker.train.common_utils.job_wait import wait as _job_wait

        try:
            _job_wait(
                self._job,
                poll=poll,
                timeout=timeout,
                description=self.description,
                max_log_lines=max_log_lines,
            )
        finally:
            # The wait helper may have refreshed the wrapped Job while polling,
            # so any previously parsed JobConfigDocument may be stale. Drop it
            # even if waiting raised (e.g. timeout), mirroring refresh().
            self._cached_config = None

    def stop(self):
        """Stop the job via StopJob API."""
        self._job.stop()

    def delete(self):
        """Delete the job via DeleteJob API."""
        self._job.delete()

    def wait_for_delete(self):
        """Wait for job deletion to complete."""
        self._job.wait_for_delete()

    # --- Parsed properties from JobConfigDocument ---

    def _parse_config_document(self) -> dict:
        """Parse the JobConfigDocument JSON string into a dict.

        The result is cached until refresh() or wait() clears it. A missing or
        empty document yields ``{}``. A document whose top level is not a JSON
        object is also treated as ``{}`` (with a warning) so the property
        lookups below cannot fail on it.

        Raises:
            json.JSONDecodeError: If the document is not valid JSON.
        """
        if self._cached_config is None:
            doc = self._job.job_config_document
            parsed = json.loads(doc) if doc else {}
            if not isinstance(parsed, dict):
                logger.warning(
                    "Ignoring JobConfigDocument for job %s: expected a JSON object, got %s.",
                    self._job.job_name,
                    type(parsed).__name__,
                )
                parsed = {}
            self._cached_config = parsed
        return self._cached_config

    def _service_output(self) -> dict:
        """Return the ServiceOutput section, or ``{}`` if it is absent or JSON null."""
        # ``or {}`` (rather than a .get default) also covers an explicit null value.
        return self._parse_config_document().get("ServiceOutput") or {}

    @property
    def output_model_package_arn(self) -> str | None:
        """ARN of the output model package from ServiceOutput, or None."""
        return self._service_output().get("OutputModelPackageArn")

    @property
    def mlflow_details(self) -> dict | None:
        """MLflow experiment/run details from ServiceOutput.

        Returns dict with keys: ExperimentName, RunName, ExperimentId, RunId.
        """
        return self._service_output().get("MlflowDetails")

    def get_mlflow_url(self) -> str | None:
        """Generate a fresh presigned MLflow URL for this job's experiment/run.

        In Jupyter notebooks, also renders a clickable link.

        Returns:
            Presigned URL string, or None if MLflow is not configured.
        """
        import html

        from sagemaker.train.common_utils.job_wait import (
            _get_mlflow_arn_from_config,
            _get_mlflow_output_details,
            _get_mlflow_experiment_name,
            _get_mlflow_presigned_url,
            _is_jupyter_environment,
        )

        config = self._parse_config_document()
        mlflow_arn = _get_mlflow_arn_from_config(config)
        if not mlflow_arn:
            return None
        exp_id, run_id = _get_mlflow_output_details(config)
        exp_name = _get_mlflow_experiment_name(config)
        url = _get_mlflow_presigned_url(mlflow_arn, exp_name, experiment_id=exp_id, run_id=run_id)
        if url and _is_jupyter_environment():
            from IPython.display import display as ipy_display, HTML

            # Escape before embedding in an HTML attribute: presigned URLs contain
            # '&' query separators, and quotes would break out of href="...".
            safe_url = html.escape(url, quote=True)
            ipy_display(HTML(
                f'🔗 <a href="{safe_url}" target="_blank">Open MLflow Experiment</a>'
            ))
        return url

    @property
    def s3_output_path(self) -> str | None:
        """S3 output path from OutputDataConfig."""
        output_cfg = self._parse_config_document().get("OutputDataConfig") or {}
        return output_cfg.get("S3OutputPath")

    @property
    def billable_token_usage(self) -> dict | None:
        """Billable token usage from ServiceOutput.

        Returns dict with keys: TrainTokenCount, PrefillTokenCount, SampleTokenCount.
        """
        return self._service_output().get("BillableTokenUsage")

    @property
    def progress_info(self) -> dict | None:
        """Training progress from ServiceOutput.

        Supports two formats:
        - Epoch-based: dict with MaxEpoch, StepsPerEpoch, CurrentEpoch, CurrentStep.
        - Step-only: dict with MaxSteps, CurrentStep.

        Returns None if not available, or if neither the epoch fields
        (MaxEpoch and StepsPerEpoch) nor MaxSteps are set.
        """
        info = self._service_output().get("ProgressInfo")
        if not info:
            return None
        has_epoch = info.get("MaxEpoch") and info.get("StepsPerEpoch")
        has_steps = info.get("MaxSteps")
        if not has_epoch and not has_steps:
            return None
        return info

    @property
    def training_config(self) -> dict | None:
        """Full TrainingConfig section from JobConfigDocument."""
        return self._parse_config_document().get("TrainingConfig")

    @property
    def agent_config(self) -> dict | None:
        """Full AgentConfig section from JobConfigDocument."""
        return self._parse_config_document().get("AgentConfig")

    # --- Training metrics ---

    @_telemetry_emitter(
        feature=Feature.MODEL_CUSTOMIZATION, func_name="AgentRFTJob.get_training_metrics"
    )
    def get_training_metrics(self) -> list[dict]:
        """Fetch per-step MTRL training metrics from MLflow.

        Retrieves ``rollout/reward/mean``, ``rollout/turns/mean``,
        ``training/total_tokens``, and ``training/num_trajectories`` for each
        training step and prints a summary table.

        Returns:
            List of dicts, one per step, with keys ``step``,
            ``rollout/reward/mean``, ``rollout/turns/mean``,
            ``training/total_tokens``, and ``training/num_trajectories``.
            Returns an empty list (and logs a warning) when no MLflow tracking
            server ARN or experiment name can be determined.
        """
        from sagemaker.train.common_utils.job_wait import (
            _get_mlflow_arn_from_config,
            _get_mlflow_experiment_name,
            _get_mlflow_output_details,
            _get_mlflow_run_name,
            _get_step_metrics,
            _setup_mlflow_metrics_util,
            MTRL_METRIC_KEYS,
        )

        config = self._parse_config_document()
        mlflow_arn = _get_mlflow_arn_from_config(config)
        exp_name = _get_mlflow_experiment_name(config)
        # ServiceOutput.MlflowDetails is the fallback source for experiment/run
        # names; it may be absent or null, hence ``or {}``.
        svc_details = self.mlflow_details or {}
        if not exp_name:
            exp_name = svc_details.get("ExperimentName")
        if not mlflow_arn or not exp_name:
            logger.warning("MLflow not configured for this job.")
            return []

        util = _setup_mlflow_metrics_util(mlflow_arn, exp_name)
        run_name = _get_mlflow_run_name(config) or svc_details.get("RunName")
        _, run_id = _get_mlflow_output_details(config)
        rows = _get_step_metrics(util, run_name, run_id, MTRL_METRIC_KEYS, mlflow_arn=mlflow_arn)
        if rows:
            self._print_metrics_table(rows)
        return rows

    @staticmethod
    def _print_metrics_table(rows: list[dict]) -> None:
        """Print a formatted metrics table.

        Columns are derived from the keys of the first row (excluding ``step``).
        Missing values print as an em dash; non-integral floats (including
        NaN/inf) print with four decimals; everything else prints as an int.
        """
        if not rows:
            return

        metric_keys = [k for k in rows[0] if k != "step"]

        # Build column headers from the last two path segments of each metric name,
        # e.g. "rollout/reward/mean" -> "Reward/Mean".
        def _col_name(k: str) -> str:
            label = "/".join(k.split("/")[-2:])
            return label.replace("_", " ").title()

        col_names = [_col_name(k) for k in metric_keys]
        # ``default=`` keeps this working when rows carry only "step"; the old
        # ``max(14, *())`` form collapsed to ``max(14)`` and raised TypeError.
        col_width = max(14, max((len(c) for c in col_names), default=0))

        header = f"{'Step':>6}" + "".join(f"  {c:>{col_width}}" for c in col_names)
        sep = "-" * len(header)

        print(f"\n{sep}\n  Training Metrics\n{sep}")
        print(header)
        print(sep)
        for r in rows:
            line = f"{r['step']:>6}"
            for k in metric_keys:
                v = r.get(k)
                if v is None:
                    line += f"  {'—':>{col_width}}"
                elif isinstance(v, float) and not v.is_integer():
                    # is_integer() is False for NaN/inf, so those print via the
                    # float format instead of crashing int() in the branch below.
                    line += f"  {v:>{col_width}.4f}"
                else:
                    line += f"  {int(v):>{col_width}}"
            print(line)
        print(sep)

    # --- Class methods ---

    @classmethod
    def get(cls, job_name: str, session=None) -> AgentRFTJob:
        """Attach to an existing AgentRFT job by name.

        Args:
            job_name: The name of the job.
            session: Optional boto3 session.

        Returns:
            AgentRFTJob wrapping the existing job.
        """
        job = Job.get(job_name=job_name, job_category=cls.JOB_CATEGORY, session=session)
        return cls.from_job(job)

    @classmethod
    def get_all(cls, session=None, **kwargs):
        """List all AgentRFT jobs.

        Delegates to Job.get_all with job_category pre-filled. Additional
        keyword arguments (e.g. creation_time_after, name_contains, sort_by,
        sort_order, status_equals) are forwarded.

        Args:
            session: Optional boto3 session.
            **kwargs: Additional filter arguments forwarded to Job.get_all.

        Yields:
            AgentRFTJob instances.
        """
        for job in Job.get_all(job_category=cls.JOB_CATEGORY, session=session, **kwargs):
            yield cls.from_job(job)