Source code for sagemaker.train.rft.decorators
"""Decorators for RFT integration.
Provides sagemaker_rft_handler decorator for AgentCore Runtime entrypoints.
"""
from __future__ import annotations
import asyncio
import logging
from functools import wraps
from typing import Any, Callable
from sagemaker.train.rft.context import set_rollout_context, clear_rollout_context
from sagemaker.train.rft.feedback import RolloutFeedbackClient, _is_trajectory_already_processed
# Module-level logger (named after this module via ``__name__``); used by the
# rollout handler and result-reporting code in this file.
logger = logging.getLogger(__name__)
[docs]
def sagemaker_rft_handler(func: Callable) -> Callable:
    """Decorator for AgentCore Runtime entrypoints to handle the RFT rollout lifecycle.

    For every invocation the wrapper:

    1. Sets the rollout context from ``payload["metadata"]`` and
       ``payload["inferenceParams"]`` (falling back to ``payload["inference_params"]``)
       for header injection, and builds a ``RolloutFeedbackClient`` for the metadata.
    2. Calls the wrapped entrypoint with ``payload``.
    3. On success, reports the result via ``_handle_result``:

       - a list ``"reward"`` completes the rollout, then updates the reward;
       - a scalar ``"reward"`` is passed to ``report_complete``;
       - no ``"reward"`` just completes the rollout;
       - ``{"status": "error"}`` is reported as an error.

       Failures while reporting are logged and never fail the invocation.
    4. On an exception, calls ``report_error`` and re-raises. The exception is
       when the error indicates the trajectory was already processed; then
       ``{"status": "skipped", "error": <message>}`` is returned instead.
    5. Clears the rollout context when done. This includes the case where the
       feedback client itself cannot be constructed.

    Works with both sync and async functions.

    Args:
        func: The entrypoint to wrap; it is called with a single ``payload`` dict.

    Returns:
        A wrapper that is a coroutine function if ``func`` is one, and a plain
        function otherwise.

    Example::

        from sagemaker.train.rft.decorators import sagemaker_rft_handler
        from sagemaker.train.rft.adapters.strands import wrap_model
        from strands import Agent
        from strands.models.openai import OpenAIModel

        model = wrap_model(OpenAIModel(client_args={...}, model_id="my-model"))
        agent = Agent(model=model, tools=[...])

        @app.entrypoint  # ``app`` is your AgentCore Runtime app (not shown)
        @sagemaker_rft_handler
        async def invoke_agent(payload):
            result = await agent.invoke_async(payload["instance"]["prompt"])
            return {"reward": compute_reward(result)}
    """
    if asyncio.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(payload: dict) -> Any:
            # Runs when the coroutine is awaited, as in the sync path.
            feedback = _begin_rollout(payload)
            try:
                result = await func(payload)
            except Exception as exc:
                skipped = _handle_rollout_exception(feedback, exc)
                if skipped is not None:
                    return skipped
                raise
            else:
                _report_rollout_result(feedback, result)
                return result
            finally:
                clear_rollout_context()

        return async_wrapper

    @wraps(func)
    def sync_wrapper(payload: dict) -> Any:
        feedback = _begin_rollout(payload)
        try:
            result = func(payload)
        except Exception as exc:
            skipped = _handle_rollout_exception(feedback, exc)
            if skipped is not None:
                return skipped
            raise
        else:
            _report_rollout_result(feedback, result)
            return result
        finally:
            clear_rollout_context()

    return sync_wrapper


def _begin_rollout(payload: dict) -> RolloutFeedbackClient:
    """Set the rollout context for ``payload`` and return its feedback client.

    If the feedback client cannot be constructed, the context that was just set
    is cleared before the exception propagates. Without this, the caller's
    ``finally`` is never reached and the context would outlive this invocation.
    """
    metadata = payload.get("metadata") or {}
    inference_params = payload.get("inferenceParams") or payload.get("inference_params")
    set_rollout_context(metadata, inference_params)
    try:
        return RolloutFeedbackClient(metadata)
    except BaseException:
        # Cleanup-and-reraise: undo set_rollout_context, keep the original error.
        clear_rollout_context()
        raise


def _handle_rollout_exception(
    feedback: RolloutFeedbackClient, exc: Exception
) -> dict[str, Any] | None:
    """Handle an exception raised by the wrapped entrypoint.

    Returns:
        A ``{"status": "skipped", "error": ...}`` dict when the error says the
        trajectory was already processed; the caller returns it instead of
        raising. Otherwise the error is reported (best effort) and ``None`` is
        returned, signalling the caller to re-raise.
    """
    error_str = str(exc)
    if _is_trajectory_already_processed(error_str):
        logger.warning("Trajectory already processed, skipping: %s", error_str)
        return {"status": "skipped", "error": error_str}
    logger.error("RFT rollout failed: %s", exc)
    try:
        feedback.report_error(error_str)
    except Exception:
        # Reporting is best effort; the original exception is still re-raised.
        logger.exception("Failed to report rollout error")
    return None


def _report_rollout_result(feedback: RolloutFeedbackClient, result: Any) -> None:
    """Report a successful result; reporting failures are logged, not raised."""
    try:
        _handle_result(feedback, result)
    except Exception:
        logger.exception("Failed to report rollout result (non-fatal)")
def _handle_result(feedback: RolloutFeedbackClient, result: Any) -> None:
"""Handle rollout result: report success or error based on result status."""
if not isinstance(result, dict):
return
# If the agent reported an error, mark trajectory as failed.
# This catches cases where the agent returns {"status": "error", "reward": 0.0}
# instead of raising — e.g. streaming errors caught by the agent.
status = result.get("status", "")
if status == "skipped":
logger.info("Trajectory already processed, skipping feedback reporting")
return
if status == "error":
error_msg = result.get("error", "unknown error")
logger.warning("Agent returned error status: %s", error_msg)
feedback.report_error(error_msg)
return
reward = result.get("reward")
if reward is not None:
if isinstance(reward, list):
feedback.complete_rollout()
feedback.update_reward(reward)
else:
feedback.report_complete(reward)
else:
feedback.complete_rollout()