sagemaker.train.rft
SageMaker RFT SDK: integration library for the multi-turn RL training platform.
Strands + AgentCore (simplest):
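from bedrock_agentcore.runtime import BedrockAgentCoreApp
from strands import Agent
from strands.models.openai import OpenAIModel
from sagemaker.train.rft import sagemaker_rft_handler
from sagemaker.train.rft.adapters.strands import wrap_model
app = BedrockAgentCoreApp()
model = wrap_model(OpenAIModel(...))
agent = Agent(model=model)
@app.entrypoint
@sagemaker_rft_handler
async def invoke_agent(payload):
    result = await agent.invoke_async(payload["instance"])
    # The handler reports a reward only if the returned dict has a "reward" key
    return {"reward": compute_reward(result)}  # compute_reward: your own reward function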
Strands Standalone:
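from fastapi import FastAPI  # FastAPI assumed here
from strands import Agent
from sagemaker.train.rft import RolloutFeedbackClient, RolloutRequest, clear_rollout_context, set_rollout_context
from sagemaker.train.rft.adapters.strands import wrap_model
app = FastAPI()
agent = Agent(model=wrap_model(model))  # model: your Strands model, e.g. OpenAIModel(...)
@app.post("/rollout")
def rollout(request: RolloutRequest):
    params = request.inference_params.model_dump() if request.inference_params else None
    set_rollout_context(request.metadata.model_dump(), params)
    try:
        result = agent(request.instance)
    finally:
        clear_rollout_context()
    reward = compute_reward(result)  # compute_reward: your own reward function
    RolloutFeedbackClient(request.metadata).report_complete(reward)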
Custom Integration:
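from fastapi import FastAPI  # FastAPI assumed here
from openai import OpenAI
from sagemaker.train.rft import RolloutFeedbackClient, RolloutRequest, make_inference_headers
app = FastAPI()
@app.post("/rollout")
def handle(request: RolloutRequest):
    # Tag every inference call with this rollout's job ARN and trajectory ID
    headers = make_inference_headers(request.metadata)
    client = OpenAI(base_url=endpoint, default_headers=headers)  # endpoint: your inference base URL
    result = my_agent.run(request.instance, client)
    reward = compute_reward(result)  # compute_reward: your own reward function
    RolloutFeedbackClient(request.metadata).report_complete(reward)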
- class sagemaker.train.rft.InferenceParams(*, temperature: float | None = None, max_tokens: int | None = None, top_p: float | None = None)
Bases: BaseModel
Inference parameters for rollout sampling.
All fields are optional; any field that is not provided falls back to the model's default.
- max_tokens: int | None
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- temperature: float | None
- top_p: float | None
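The "optional fields fall back to model defaults" rule can be illustrated with a stdlib-only sketch. It does not import the SDK: `SamplingParams` is a hypothetical dataclass mirroring the three documented fields, and both `resolve_sampling` and the default values in `MODEL_DEFAULTS` are illustrative assumptions, not SDK behavior.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class SamplingParams:
    """Hypothetical stand-in mirroring InferenceParams' documented fields."""
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None


def resolve_sampling(params: Optional[SamplingParams], defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Fill every unset (None) field from the serving model's defaults."""
    resolved = dict(defaults)  # copy so the defaults mapping is never mutated
    if params is None:  # the whole inference_params object may be absent
        return resolved
    for f in fields(params):
        value = getattr(params, f.name)
        if value is not None:  # only explicitly provided values override
            resolved[f.name] = value
    return resolved


# Assumed model defaults, for illustration only.
MODEL_DEFAULTS = {"temperature": 1.0, "max_tokens": 4096, "top_p": 1.0}

overridden = resolve_sampling(SamplingParams(temperature=0.7), MODEL_DEFAULTS)
```

Treating `None` as "not provided" matches the field types above (`float | None`, `int | None`), which is why the sketch never overrides a default with `None`.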
- class sagemaker.train.rft.RolloutFeedbackClient(metadata: dict[str, Any] | RolloutMetadata)
Bases: object
Client for reporting rollout completion to the RFT Runtime Service.
Calls the runtime service’s /complete-rollout and /update-reward APIs using bearer token auth.
Example:
feedback = RolloutFeedbackClient(metadata)
feedback.report_complete(reward=0.95)
- complete_rollout(status: str = 'ready') → None
Report trajectory completion to the runtime service.
- Parameters:
status – Target status: “ready” for success, “failed” for errors.
- report_complete(reward: float | List[float]) → None
Complete the trajectory and report reward(s).
Convenience method that calls complete_rollout() and then update_reward().
- Parameters:
reward – The computed reward(s) for this rollout.
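The documented ordering of report_complete (complete_rollout() first, then update_reward()) can be mirrored against a stub transport so it runs offline. This is a sketch, not the SDK client: `SketchFeedbackClient`, the `transport` callable, and the request-body shapes are hypothetical, and the `Authorization: Bearer <token>` layout is an assumption based only on the statement that bearer token auth is used.

```python
from typing import Any, Callable, Dict, List, Union

# transport(path, headers, body) -> None; a stub stands in for real HTTP here.
Transport = Callable[[str, Dict[str, str], Dict[str, Any]], None]


class SketchFeedbackClient:
    """Hypothetical mirror of RolloutFeedbackClient's call ordering (not the SDK)."""

    def __init__(self, metadata: Dict[str, Any], token: str, transport: Transport):
        self.metadata = metadata
        self.headers = {"Authorization": f"Bearer {token}"}  # assumed header layout
        self.transport = transport

    def complete_rollout(self, status: str = "ready") -> None:
        # "ready" for success, "failed" for errors, as documented above.
        body = {"trajectory_id": self.metadata["trajectory_id"], "status": status}
        self.transport("/complete-rollout", self.headers, body)

    def update_reward(self, reward: Union[float, List[float]]) -> None:
        # Accept a single reward or a list of rewards, as report_complete does.
        rewards = reward if isinstance(reward, list) else [reward]
        body = {"trajectory_id": self.metadata["trajectory_id"], "rewards": rewards}
        self.transport("/update-reward", self.headers, body)

    def report_complete(self, reward: Union[float, List[float]]) -> None:
        # Documented ordering: mark the trajectory complete, then attach the reward(s).
        self.complete_rollout()
        self.update_reward(reward)


calls: List[str] = []
client = SketchFeedbackClient(
    {"trajectory_id": "traj-001"},
    token="example-token",
    transport=lambda path, headers, body: calls.append(path),
)
client.report_complete(reward=0.95)
```

Injecting the transport keeps the ordering logic testable without network access; a real client would replace the stub with authenticated HTTP calls to the runtime service endpoint.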
- class sagemaker.train.rft.RolloutMetadata(*, job_arn: str, trajectory_id: str, endpoint: str, region: str = 'us-west-2')
Bases: BaseModel
Metadata sent by the trainer with each rollout request.
Pass this entire object (or its dict form) to RolloutFeedbackClient and make_inference_headers.
- endpoint: str
- job_arn: str
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- region: str
- trajectory_id: str
- class sagemaker.train.rft.RolloutRequest(*, instance: Dict[str, Any], metadata: RolloutMetadata, inference_params: InferenceParams | None = None, model_name: str | None = None, model_endpoint: str | None = None)
Bases: BaseModel
Request format sent by the trainer to your /rollout endpoint.
This is the enforced contract: your server must accept this exact format.
- inference_params: InferenceParams | None
- instance: Dict[str, Any]
- metadata: RolloutMetadata
- model_config: ClassVar[ConfigDict] = {}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- model_endpoint: str | None
- model_name: str | None
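For orientation, here is what a request body matching this contract could look like, checked with the standard library only. Field names come from RolloutRequest, RolloutMetadata, and InferenceParams above; every value (the ARN, trajectory ID, endpoint URL, and the `prompt` key inside `instance`) is a made-up placeholder, and `parse_rollout_request` is a hypothetical helper that checks only required keys and documented defaults rather than reproducing full pydantic validation.

```python
import json
from typing import Any, Dict

# Placeholder values only; the shape follows the documented field names.
EXAMPLE_BODY = """
{
  "instance": {"prompt": "What is 2 + 2?"},
  "metadata": {
    "job_arn": "arn:aws:sagemaker:us-west-2:111122223333:training-job/example",
    "trajectory_id": "traj-0001",
    "endpoint": "https://runtime.example.invalid"
  },
  "inference_params": {"temperature": 0.7, "max_tokens": 512}
}
"""

REQUIRED_METADATA = ("job_arn", "trajectory_id", "endpoint")


def parse_rollout_request(raw: str) -> Dict[str, Any]:
    """Check the required parts of the contract and apply documented defaults."""
    body = json.loads(raw)
    if not isinstance(body.get("instance"), dict):
        raise ValueError("'instance' must be a JSON object")
    metadata = body.get("metadata")
    if not isinstance(metadata, dict):
        raise ValueError("'metadata' must be a JSON object")
    missing = [k for k in REQUIRED_METADATA if k not in metadata]
    if missing:
        raise ValueError(f"metadata missing required fields: {missing}")
    metadata.setdefault("region", "us-west-2")  # default from RolloutMetadata's signature
    # Optional top-level fields default to None, as in RolloutRequest's signature.
    for key in ("inference_params", "model_name", "model_endpoint"):
        body.setdefault(key, None)
    return body


request = parse_rollout_request(EXAMPLE_BODY)
```

In a real server the SDK's own RolloutRequest model would do this validation; the sketch only makes the required versus optional fields concrete.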
- sagemaker.train.rft.clear_rollout_context() → None
Clear rollout metadata and inference params from context.
- sagemaker.train.rft.get_inference_headers() → dict[str, str]
Get headers from the current rollout context.
For use with set_rollout_context() when you need to retrieve headers deep in the call stack without passing them explicitly.
- Returns:
Headers dict, or an empty dict (with a warning) if no rollout context is set.
- sagemaker.train.rft.get_inference_params() → dict[str, Any] | None
Retrieve inference parameters from context.
- Returns:
The inference_params dict if set, None otherwise.
Example:
from sagemaker.train.rft.context import get_inference_params

params = get_inference_params()
if params:
    temperature = params.get("temperature", 1.0)
    max_tokens = params.get("max_tokens", 4096)
- sagemaker.train.rft.make_inference_headers(metadata: dict[str, Any] | RolloutMetadata) → dict[str, str]
Create headers dict for inference calls.
Produces the individual headers required by the RFT Runtime Service:
  - X-Amzn-SageMaker-Job-Arn: job ARN that identifies the training session
  - X-Amzn-SageMaker-Trajectory-Id: unique trajectory identifier
- Parameters:
metadata – The metadata from the rollout request (dict or RolloutMetadata).
- Returns:
Headers dict to add to inference calls.
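A stdlib-only sketch of the same idea: build the two documented headers from either the dict form or an attribute-style object, then attach them to an outgoing request. `sketch_make_inference_headers` and `_field` are hypothetical, `SimpleNamespace` stands in for a RolloutMetadata instance, and the metadata values and inference URL are placeholders; the request is built but never sent.

```python
import urllib.request
from types import SimpleNamespace
from typing import Any, Dict


def _field(metadata: Any, name: str) -> str:
    # Accept the dict form or an attribute-style object such as RolloutMetadata.
    if isinstance(metadata, dict):
        return metadata[name]
    return getattr(metadata, name)


def sketch_make_inference_headers(metadata: Any) -> Dict[str, str]:
    """Hypothetical mirror of make_inference_headers using the documented header names."""
    return {
        "X-Amzn-SageMaker-Job-Arn": _field(metadata, "job_arn"),
        "X-Amzn-SageMaker-Trajectory-Id": _field(metadata, "trajectory_id"),
    }


# Placeholder metadata values, not real identifiers.
meta_dict = {"job_arn": "arn:example:job", "trajectory_id": "traj-42", "endpoint": "https://runtime.example.invalid"}
meta_obj = SimpleNamespace(**meta_dict)
headers = sketch_make_inference_headers(meta_dict)

# Attach the headers to an OpenAI-style chat request (built, never sent; URL is a placeholder).
req = urllib.request.Request(
    "https://inference.example.invalid/v1/chat/completions",
    data=b'{"model": "my-model", "messages": []}',
    headers={"Content-Type": "application/json", **headers},
    method="POST",
)
# urllib normalizes header names with str.capitalize(); HTTP header names are case-insensitive.
```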
- sagemaker.train.rft.sagemaker_rft_handler(func: Callable) → Callable
Decorator for AgentCore Runtime entrypoints that handles the RFT rollout lifecycle.
Automatically:
1. Sets the rollout context (metadata + inference_params) for header injection.
2. On success, calls CompleteTrajectory + UpdateReward if “reward” is in the result.
3. On error, calls report_error.
4. Clears the context when done.
Works with both sync and async functions.
Example:
from bedrock_agentcore.runtime import BedrockAgentCoreApp
from sagemaker.train.rft import sagemaker_rft_handler
from sagemaker.train.rft.adapters.strands import wrap_model
from strands import Agent
from strands.models.openai import OpenAIModel

app = BedrockAgentCoreApp()
model = wrap_model(OpenAIModel(client_args={...}, model_id="my-model"))
agent = Agent(model=model, tools=[...])

@app.entrypoint
@sagemaker_rft_handler
async def invoke_agent(payload):
    result = await agent.invoke_async(payload["instance"]["prompt"])
    # compute_reward is your own scoring function
    return {"reward": compute_reward(result)}
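The four lifecycle steps can be mirrored in a short stdlib-only sketch that works for both sync and async entrypoints. `sketch_rft_handler`, the `_report` stub, and the payload shape (`payload["metadata"]["trajectory_id"]`) are assumptions for illustration; in particular, re-raising the exception after reporting it is a guess, since the source says only that report_error is called.

```python
import asyncio
import functools
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Optional

# Stub "service" log standing in for the real runtime calls (hypothetical).
events: List[str] = []


def _report(kind: str, detail: Optional[Any] = None) -> None:
    events.append(kind if detail is None else f"{kind}:{detail}")


@contextmanager
def _lifecycle(payload: Dict[str, Any]) -> Iterator[None]:
    _report("set_context", payload["metadata"]["trajectory_id"])  # step 1
    try:
        yield
    except Exception as exc:
        _report("error", type(exc).__name__)  # step 3 (re-raising is an assumption)
        raise
    finally:
        _report("clear_context")  # step 4


def _on_success(result: Any) -> Any:
    # Step 2: report completion and reward only when the result carries a "reward" key.
    if isinstance(result, dict) and "reward" in result:
        _report("complete", "ready")
        _report("reward", result["reward"])
    return result


def sketch_rft_handler(func: Callable) -> Callable:
    """Hypothetical mirror of the documented lifecycle for sync and async entrypoints."""
    if inspect.iscoroutinefunction(func):
        @functools.wraps(func)
        async def async_wrapper(payload: Dict[str, Any]) -> Any:
            with _lifecycle(payload):
                return _on_success(await func(payload))
        return async_wrapper

    @functools.wraps(func)
    def sync_wrapper(payload: Dict[str, Any]) -> Any:
        with _lifecycle(payload):
            return _on_success(func(payload))
    return sync_wrapper


@sketch_rft_handler
async def invoke_agent(payload: Dict[str, Any]) -> Dict[str, Any]:
    return {"reward": 0.9}  # a real handler would score the agent's output here


@sketch_rft_handler
def failing_agent(payload: Dict[str, Any]) -> Dict[str, Any]:
    raise RuntimeError("agent crashed")


if __name__ == "__main__":
    asyncio.run(invoke_agent({"metadata": {"trajectory_id": "t-1"}}))
    print(events)
```

Sharing one context manager between the two wrappers keeps the sync and async paths from drifting apart; only the `await` differs.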
- sagemaker.train.rft.set_rollout_context(metadata: dict[str, Any], inference_params: dict[str, Any] | None = None) → None
Store rollout metadata and inference params in context.
Call this at the start of a rollout handler. Values are available via get_rollout_context() and get_inference_params() anywhere in the same thread/async context.
- Parameters:
metadata – Rollout metadata dict from the rollout request.
inference_params – Optional dict with sampling parameters (temperature, max_tokens, top_p).
Example:
from sagemaker.train.rft import set_rollout_context, clear_rollout_context

@app.post("/rollout")
def handle_rollout(request):
    set_rollout_context(
        metadata=request.metadata,
        inference_params=request.inference_params,
    )
    try:
        result = my_agent.run(request.instance)
    finally:
        clear_rollout_context()
    return result
Modules
- Framework-specific adapters for automatic header injection.
- Rollout context management using contextvars.
- Decorators for RFT integration.
- Rollout feedback client for reporting completion and rewards to the RFT Runtime Service.
- Header utilities for inference calls.
- Contract models for the rollout server API.