# Source code for sagemaker.train.rft.feedback

"""Rollout feedback client for reporting completion and rewards to the RFT Runtime Service."""

from __future__ import annotations

import json
import logging
import numbers
import os
from typing import Any, List, Optional, Union

import requests as req_lib

from sagemaker.core.token_generator import generate_token
from sagemaker.train.rft.models import RolloutMetadata

logger = logging.getLogger(__name__)

_TRAJECTORY_ALREADY_PROCESSED_MARKERS = (
    "not in valid status",
    "Cannot transition trajectory",
)


def _is_trajectory_already_processed(error: str) -> bool:
    """Return True if the error indicates the trajectory was already completed."""
    return any(marker in error for marker in _TRAJECTORY_ALREADY_PROCESSED_MARKERS)


_DEFAULT_ENDPOINT = os.environ.get(
    "RFT_RUNTIME_ENDPOINT",
    "https://job-runtime.sagemaker.us-east-1.api.aws"
)


def _build_endpoint(region: str, stage: str = "") -> str:
    """Build the RFT Runtime Service endpoint URL for a given region and stage."""
    if stage and stage != "prod":
        prefix = f"job-runtime.{stage}."
    else:
        prefix = "job-runtime."
    return f"https://{prefix}sagemaker.{region}.api.aws"


class RolloutFeedbackClient:
    """Client for reporting rollout completion to the RFT Runtime Service.

    Calls the runtime service's ``/complete-rollout`` and ``/update-reward``
    APIs using bearer token auth.

    Example::

        feedback = RolloutFeedbackClient(metadata)
        feedback.report_complete(reward=0.95)
    """

    # Per-request HTTP timeout in seconds, passed to ``requests.post``.
    _REQUEST_TIMEOUT_SECONDS = 120

    def __init__(self, metadata: dict[str, Any] | RolloutMetadata) -> None:
        """Resolve region, endpoint, job ARN and trajectory id from *metadata*.

        Args:
            metadata: A ``RolloutMetadata`` (converted via ``model_dump()``)
                or a plain dict. Snake_case and camelCase keys are both
                accepted for the job ARN (``job_arn``/``jobArn``) and the
                trajectory id (``trajectory_id``/``trajectoryId``/``rolloutId``).

        Raises:
            TypeError: If *metadata* is neither a dict nor a RolloutMetadata.
        """
        if isinstance(metadata, RolloutMetadata):
            metadata = metadata.model_dump()
        elif not isinstance(metadata, dict):
            raise TypeError(
                f"metadata must be a dict or RolloutMetadata, got {type(metadata).__name__}."
            )

        # Each setting resolves: explicit metadata -> environment -> fallback.
        self._region = (
            metadata.get("region") or os.environ.get("AWS_REGION") or "us-west-2"
        )
        self._endpoint = (
            metadata.get("endpoint")
            or os.environ.get("RFT_RUNTIME_ENDPOINT")
            or _build_endpoint(self._region, os.environ.get("RFT_STAGE", ""))
        ).rstrip("/")  # API paths start with "/", so avoid a double slash.
        self._job_arn = metadata.get("job_arn") or metadata.get("jobArn") or ""
        self._trajectory_id = (
            metadata.get("trajectory_id")
            or metadata.get("trajectoryId")
            or metadata.get("rolloutId")
            or ""
        )
        self._metadata = metadata

    def complete_rollout(self, status: str = "ready") -> None:
        """Report trajectory completion to the runtime service.

        Does nothing (beyond a warning) when no trajectory id is known. A 404
        or an "already in terminal status" response is logged and ignored.

        Args:
            status: Target status - "ready" for success, "failed" for errors.
        """
        if not self._trajectory_id:
            logger.warning("No trajectory_id in metadata; skipping complete_rollout")
            return
        logger.info(
            "CompleteRollout: trajectory_id=%s status=%s",
            self._trajectory_id,
            status,
        )
        self._post_trajectory_update(
            "CompleteRollout",
            "/complete-rollout",
            {
                "JobArn": self._job_arn,
                "TrajectoryId": self._trajectory_id,
                "Status": status,
            },
        )

    def update_reward(self, reward: Union[float, List[float]]) -> None:
        """Report reward(s) to the runtime service.

        Does nothing (beyond a warning) when no trajectory id is known. A 404
        or an "already in terminal status" response is logged and ignored.

        Args:
            reward: A single real number or an iterable of real numbers
                (per-turn rewards). Each value is converted with ``float()``
                so numpy scalars serialize to JSON.

        Raises:
            TypeError / ValueError: If a reward cannot be converted to float.
        """
        # numbers.Real covers int, float, bool, Fraction and numpy scalars
        # (numpy registers its scalar types with the numbers ABCs).
        if isinstance(reward, numbers.Real):
            rewards = [float(reward)]
        else:
            rewards = [float(value) for value in reward]
        if not self._trajectory_id:
            logger.warning("No trajectory_id in metadata; skipping update_reward")
            return
        logger.info(
            "UpdateReward: trajectory_id=%s rewards=%s",
            self._trajectory_id,
            rewards,
        )
        self._post_trajectory_update(
            "UpdateReward",
            "/update-reward",
            {
                "JobArn": self._job_arn,
                "TrajectoryId": self._trajectory_id,
                "Rewards": rewards,
            },
        )

    def report_complete(self, reward: Union[float, List[float]]) -> None:
        """Complete the trajectory and report reward(s).

        Convenience method that calls complete_rollout() then update_reward().

        Args:
            reward: The computed reward(s) for this rollout.
        """
        self.complete_rollout(status="ready")
        self.update_reward(reward)

    def report_error(self, error: str, reward: Optional[float] = None) -> None:
        """Report a rollout error, marking the trajectory as failed.

        Logs *error*, calls complete_rollout(status="failed"), then reports
        *reward* via update_reward().

        Args:
            error: Error description (only logged locally).
            reward: Optional partial reward (defaults to 0.0).
        """
        logger.error("Rollout error: trajectory_id=%s error=%s", self._trajectory_id, error)
        self.complete_rollout(status="failed")
        self.update_reward(reward if reward is not None else 0.0)

    def _post_trajectory_update(
        self, operation: str, path: str, payload: dict[str, Any]
    ) -> None:
        """POST *payload* to *path*, tolerating "not found" and "already processed".

        Args:
            operation: Name used in log messages (e.g. ``"CompleteRollout"``).
            path: Runtime-service API path, starting with ``/``.
            payload: JSON-serializable request body.

        Raises:
            Exception: Any failure from ``_bearer_post`` other than a 404 or
                an error whose text matches the already-processed markers.
        """
        body = json.dumps(payload)
        try:
            self._bearer_post(path, body)
        except Exception as exc:
            if self._is_not_found(exc):
                logger.warning(
                    "%s 404: trajectory %s not found", operation, self._trajectory_id
                )
            elif _is_trajectory_already_processed(self._error_text(exc)):
                logger.warning(
                    "%s: trajectory %s already in terminal status, skipping",
                    operation,
                    self._trajectory_id,
                )
            else:
                raise

    @staticmethod
    def _is_not_found(exc: BaseException) -> bool:
        """Return True if *exc* represents an HTTP 404 from the service."""
        # requests exceptions carry the HTTP response in ``.response``; other
        # exception types may use that name for something else (e.g. a dict),
        # so only trust an integer ``status_code``.
        status = getattr(getattr(exc, "response", None), "status_code", None)
        if isinstance(status, int):
            return status == 404
        # No HTTP response attached: keep the legacy message substring check.
        return "404" in str(exc)

    @staticmethod
    def _error_text(exc: BaseException) -> str:
        """Return ``str(exc)`` plus the attached HTTP response body, if any.

        The message produced by ``raise_for_status()`` contains only the status
        line and URL; the service's error detail is in the response body.
        """
        message = str(exc)
        body = getattr(getattr(exc, "response", None), "text", None)
        if isinstance(body, str) and body:
            return f"{message}\n{body}"
        return message

    def _bearer_post(self, path: str, body: str) -> None:
        """Send a bearer-token-authenticated POST to the runtime service.

        Args:
            path: API path appended to the configured endpoint.
            body: Pre-serialized JSON request body.

        Raises:
            requests.HTTPError: If the service responds with an error status.
            Exception: Any failure from token generation or the HTTP transport.
        """
        url = f"{self._endpoint}{path}"
        try:
            token = generate_token(region=self._region)
            response = req_lib.post(
                url,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {token}",
                },
                data=body,
                timeout=self._REQUEST_TIMEOUT_SECONDS,
            )
        except Exception as e:
            # Token generation or transport failure: no HTTP response exists.
            logger.warning("Failed %s: %s", path, e)
            raise
        # ``ok`` is False exactly when raise_for_status() would raise, so
        # successful non-200 statuses are not reported as failures, and each
        # failure is logged once.
        if not response.ok:
            logger.warning(
                "Failed %s: status=%s body=%s",
                path,
                response.status_code,
                response.text[:500],
            )
            response.raise_for_status()