sagemaker.train.rft

SageMaker RFT SDK - Integration library for the multi-turn RL training platform.

Strands + AgentCore (simplest):

from sagemaker.train.rft import sagemaker_rft_handler, RolloutFeedbackClient
from sagemaker.train.rft.adapters.strands import wrap_model
from strands.models.openai import OpenAIModel

model = wrap_model(OpenAIModel(...))

@app.entrypoint
@sagemaker_rft_handler
async def invoke_agent(payload):
    result = await agent.invoke_async(payload["instance"])
    return result

Strands Standalone:

from sagemaker.train.rft import set_rollout_context, RolloutFeedbackClient
from sagemaker.train.rft.adapters.strands import wrap_model

model = wrap_model(model)

@app.post("/rollout")
def rollout(request):
    set_rollout_context(request.metadata, request.inference_params)
    result = agent(request.instance)
    reward = compute_reward(result)  # compute_reward is your own scoring function
    RolloutFeedbackClient(request.metadata).report_complete(reward)

Custom Integration:

from openai import OpenAI
from sagemaker.train.rft import make_inference_headers, RolloutFeedbackClient

@app.post("/rollout")
def handle(request):
    headers = make_inference_headers(request.metadata)
    client = OpenAI(base_url=endpoint, default_headers=headers)
    result = my_agent.run(request.instance, client)
    reward = compute_reward(result)  # compute_reward is your own scoring function
    RolloutFeedbackClient(request.metadata).report_complete(reward)
class sagemaker.train.rft.InferenceParams(*, temperature: float | None = None, max_tokens: int | None = None, top_p: float | None = None)

Bases: BaseModel

Inference parameters for rollout sampling.

All fields are optional; if a field is not provided, the model's default is used.

max_tokens: int | None
model_config: ClassVar[ConfigDict] = {}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

temperature: float | None
top_p: float | None
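Because every field is optional, a rollout handler will typically forward only the values the trainer actually set, so the model's own defaults apply to the rest. The following is a minimal sketch of that filtering. It uses a plain dataclass as a stand-in for InferenceParams, and the `to_request_kwargs` helper is hypothetical, not part of the SDK.

```python
from dataclasses import asdict, dataclass
from typing import Any, Optional


@dataclass
class InferenceParamsSketch:
    """Stand-in mirroring the three documented fields; not the SDK class."""

    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None


def to_request_kwargs(params: InferenceParamsSketch) -> dict[str, Any]:
    """Hypothetical helper: keep only the fields that were explicitly set."""
    return {name: value for name, value in asdict(params).items() if value is not None}


kwargs = to_request_kwargs(InferenceParamsSketch(temperature=0.7))
print(kwargs)  # {'temperature': 0.7}
```

Filtering on `is not None` rather than truthiness keeps legitimate zero values such as `temperature=0.0`.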
class sagemaker.train.rft.RolloutFeedbackClient(metadata: dict[str, Any] | RolloutMetadata)

Bases: object

Client for reporting rollout completion to the RFT Runtime Service.

Calls the runtime service’s /complete-rollout and /update-reward APIs using bearer token auth.

Example:

feedback = RolloutFeedbackClient(metadata)
feedback.report_complete(reward=0.95)
complete_rollout(status: str = 'ready') → None

Report trajectory completion to the runtime service.

Parameters:

status – Target status: “ready” for success, “failed” for errors.

report_complete(reward: float | List[float]) → None

Complete the trajectory and report reward(s).

Convenience method that calls complete_rollout() then update_reward().

Parameters:

reward – The computed reward(s) for this rollout.

report_error(error: str, reward: float | None = None) → None

Report a rollout error, marking the trajectory as failed.

Parameters:
  • error – Error description.

  • reward – Optional partial reward; when omitted (None), 0.0 is reported.

update_reward(reward: float | List[float]) → None

Report reward(s) to the runtime service.

Parameters:

reward – A single float or list of floats for per-turn rewards.
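The methods above suggest a simple pattern for a rollout handler: report success through report_complete() and report failure through report_error(). The sketch below illustrates that control flow with a hypothetical in-memory stand-in (`RecordingFeedbackClient`) that uses the documented method names but records calls instead of contacting the runtime service. The `run_rollout` helper is likewise an assumption made for illustration.

```python
from typing import Callable, List, Optional, Tuple, Union

Reward = Union[float, List[float]]


class RecordingFeedbackClient:
    """Hypothetical stand-in with the documented method names.

    It records calls instead of contacting the RFT Runtime Service.
    """

    def __init__(self) -> None:
        self.calls: List[Tuple[str, object]] = []

    def complete_rollout(self, status: str = "ready") -> None:
        self.calls.append(("complete_rollout", status))

    def update_reward(self, reward: Reward) -> None:
        self.calls.append(("update_reward", reward))

    def report_complete(self, reward: Reward) -> None:
        # Documented order: complete_rollout() first, then update_reward().
        self.complete_rollout()
        self.update_reward(reward)

    def report_error(self, error: str, reward: Optional[float] = None) -> None:
        self.calls.append(("report_error", error))


def run_rollout(feedback: RecordingFeedbackClient, run_agent: Callable[[], float]) -> None:
    """Report either success or failure for one rollout."""
    try:
        reward = run_agent()
    except Exception as exc:
        feedback.report_error(str(exc))
        return
    feedback.report_complete(reward)


def failing_agent() -> float:
    raise RuntimeError("tool timeout")


ok, bad = RecordingFeedbackClient(), RecordingFeedbackClient()
run_rollout(ok, lambda: 0.95)
run_rollout(bad, failing_agent)
print(ok.calls)   # [('complete_rollout', 'ready'), ('update_reward', 0.95)]
print(bad.calls)  # [('report_error', 'tool timeout')]
```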

class sagemaker.train.rft.RolloutMetadata(*, job_arn: str, trajectory_id: str, endpoint: str, region: str = 'us-west-2')

Bases: BaseModel

Metadata sent by the trainer with each rollout request.

Pass this entire object (or its dict form) to RolloutFeedbackClient and make_inference_headers.

endpoint: str
job_arn: str
model_config: ClassVar[ConfigDict] = {}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

region: str
trajectory_id: str
class sagemaker.train.rft.RolloutRequest(*, instance: Dict[str, Any], metadata: RolloutMetadata, inference_params: InferenceParams | None = None, model_name: str | None = None, model_endpoint: str | None = None)

Bases: BaseModel

Request format sent by the trainer to your /rollout endpoint.

This is the enforced contract. Your server must accept this exact format.

inference_params: InferenceParams | None
instance: Dict[str, Any]
metadata: RolloutMetadata
model_config: ClassVar[ConfigDict] = {}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

model_endpoint: str | None
model_name: str | None
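To make the contract concrete, here is a sketch of what a request body could look like, together with a small check for the fields that the signatures above mark as required. The field names come from the documented models. Every value (ARN, trajectory ID, endpoint URL, prompt) is made up for illustration, and `missing_fields` is a hypothetical helper, not part of the SDK.

```python
import json

# Hypothetical payload: field names follow the documented models,
# all values are invented for illustration.
raw = """
{
  "instance": {"prompt": "What is 2 + 2?"},
  "metadata": {
    "job_arn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/example",
    "trajectory_id": "traj-0001",
    "endpoint": "https://runtime.example.com",
    "region": "us-west-2"
  },
  "inference_params": {"temperature": 0.7, "max_tokens": 512}
}
"""

REQUIRED_TOP_LEVEL = {"instance", "metadata"}
REQUIRED_METADATA = {"job_arn", "trajectory_id", "endpoint"}


def missing_fields(payload: dict) -> set:
    """Return the required field names that are absent from the payload."""
    missing = REQUIRED_TOP_LEVEL - set(payload)
    metadata = payload.get("metadata", {})
    missing |= {f"metadata.{name}" for name in REQUIRED_METADATA - set(metadata)}
    return missing


payload = json.loads(raw)
print(missing_fields(payload))  # set()
```

In a real server the RolloutRequest model itself would perform this validation. The manual check only shows which fields have no default.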
sagemaker.train.rft.clear_rollout_context() → None

Clear rollout metadata and inference params from context.

sagemaker.train.rft.get_inference_headers() → dict[str, str]

Get headers from current rollout context.

For use with set_rollout_context() when you need to retrieve headers deep in the call stack without passing them explicitly.

Returns:

Headers dict, or an empty dict (with a warning) if no context is set.

sagemaker.train.rft.get_inference_params() → dict[str, Any] | None

Retrieve inference parameters from context.

Returns:

The inference_params dict if set, None otherwise.

Example:

from sagemaker.train.rft.context import get_inference_params

params = get_inference_params()
if params:
    temperature = params.get("temperature", 1.0)
    max_tokens = params.get("max_tokens", 4096)
sagemaker.train.rft.make_inference_headers(metadata: dict[str, Any] | RolloutMetadata) → dict[str, str]

Create headers dict for inference calls.

Produces the individual headers required by the RFT Runtime Service:
  • X-Amzn-SageMaker-Job-Arn: job ARN that identifies the training session

  • X-Amzn-SageMaker-Trajectory-Id: unique trajectory identifier

Parameters:

metadata – The metadata from the rollout request (dict or RolloutMetadata).

Returns:

Headers dict to add to inference calls.
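As an illustration of the result, the sketch below builds the two documented headers from a metadata dict. It is a hypothetical re-creation, not the SDK function. It assumes that the job_arn and trajectory_id fields feed the headers of the same name, and the metadata values are invented.

```python
from typing import Any, Mapping


def build_headers_sketch(metadata: Mapping[str, Any]) -> dict[str, str]:
    """Hypothetical re-creation of the documented header mapping."""
    return {
        # Assumption: each header is filled from the like-named metadata field.
        "X-Amzn-SageMaker-Job-Arn": str(metadata["job_arn"]),
        "X-Amzn-SageMaker-Trajectory-Id": str(metadata["trajectory_id"]),
    }


headers = build_headers_sketch(
    {
        "job_arn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/example",
        "trajectory_id": "traj-0001",
        "endpoint": "https://runtime.example.com",
    }
)
print(headers["X-Amzn-SageMaker-Trajectory-Id"])  # traj-0001
```

A dict of this shape can be passed as `default_headers` to an OpenAI-compatible client, as the Custom Integration example at the top of this page does.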

sagemaker.train.rft.sagemaker_rft_handler(func: Callable) → Callable

Decorator for AgentCore Runtime entrypoints to handle RFT rollout lifecycle.

Automatically:
  1. Sets rollout context (metadata + inference_params) for header injection.
  2. On success, calls CompleteTrajectory + UpdateReward if “reward” is in the result.
  3. On error, calls report_error.
  4. Clears context when done.

Works with both sync and async functions.

Example:

from sagemaker.train.rft import sagemaker_rft_handler
from sagemaker.train.rft.adapters.strands import wrap_model
from strands import Agent
from strands.models.openai import OpenAIModel

model = wrap_model(OpenAIModel(client_args={...}, model_id="my-model"))
agent = Agent(model=model, tools=[...])

@app.entrypoint
@sagemaker_rft_handler
async def invoke_agent(payload):
    result = await agent.invoke_async(payload["instance"]["prompt"])
    return {"reward": compute_reward(result)}
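The lifecycle described for this decorator can be sketched with the standard library alone. The code below is a hypothetical illustration, not the SDK's implementation. It appends to an `events` list where the real decorator would call into the context and feedback APIs. Re-raising the exception after reporting it is also an assumption of the sketch.

```python
import asyncio
import functools
import inspect
from typing import Any, Callable, List

events: List[str] = []  # stands in for calls into the SDK and runtime service


def lifecycle_sketch(func: Callable) -> Callable:
    """Hypothetical decorator that mirrors the four documented steps."""

    def on_success(result: Any) -> None:
        # Step 2: report only when the handler returned a "reward" entry.
        if isinstance(result, dict) and "reward" in result:
            events.append(f"report_complete({result['reward']})")

    if inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        async def async_wrapper(payload: dict) -> Any:
            events.append("set_context")  # step 1
            try:
                result = await func(payload)
                on_success(result)
                return result
            except Exception as exc:
                events.append(f"report_error({exc})")  # step 3
                raise  # re-raising is an assumption of this sketch
            finally:
                events.append("clear_context")  # step 4

        return async_wrapper

    @functools.wraps(func)
    def sync_wrapper(payload: dict) -> Any:
        events.append("set_context")
        try:
            result = func(payload)
            on_success(result)
            return result
        except Exception as exc:
            events.append(f"report_error({exc})")
            raise
        finally:
            events.append("clear_context")

    return sync_wrapper


@lifecycle_sketch
async def invoke_agent(payload: dict) -> dict:
    return {"reward": 0.5}


asyncio.run(invoke_agent({"instance": {}}))
print(events)  # ['set_context', 'report_complete(0.5)', 'clear_context']
```

Checking `inspect.iscoroutinefunction` once at decoration time is one common way to support both sync and async handlers with a single decorator.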
sagemaker.train.rft.set_rollout_context(metadata: dict[str, Any], inference_params: dict[str, Any] | None = None) → None

Store rollout metadata and inference params in context.

Call this at the start of a rollout handler. Values are available via get_rollout_context() and get_inference_params() anywhere in the same thread/async context.

Parameters:
  • metadata – Rollout metadata dict from the rollout request.

  • inference_params – Optional dict with sampling parameters (temperature, max_tokens, top_p).

Example:

from sagemaker.train.rft import set_rollout_context, clear_rollout_context

@app.post("/rollout")
def handle_rollout(request):
    set_rollout_context(
        metadata=request.metadata,
        inference_params=request.inference_params,
    )
    try:
        result = my_agent.run(request.instance)
    finally:
        clear_rollout_context()
    return result
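The context module is described as being built on contextvars, which is what makes the stored values visible "anywhere in the same thread/async context" without leaking between concurrent rollouts. The sketch below demonstrates that behavior with the standard library only. The names `_rollout_metadata`, `set_ctx`, `get_ctx`, and `clear_ctx` are hypothetical stand-ins, and the SDK's internals may differ.

```python
import asyncio
import contextvars
from typing import Optional

# Hypothetical module-level variable; the SDK's internal names may differ.
_rollout_metadata = contextvars.ContextVar("rollout_metadata", default=None)


def set_ctx(metadata: dict) -> None:
    _rollout_metadata.set(metadata)


def get_ctx() -> Optional[dict]:
    return _rollout_metadata.get()


def clear_ctx() -> None:
    _rollout_metadata.set(None)


def deep_in_call_stack() -> str:
    # No metadata argument is passed down; it is read from the context.
    return get_ctx()["trajectory_id"]


async def handle(trajectory_id: str, delay: float) -> str:
    set_ctx({"trajectory_id": trajectory_id})
    try:
        await asyncio.sleep(delay)  # other rollouts run while this one waits
        return deep_in_call_stack()
    finally:
        clear_ctx()


async def main() -> list:
    # Each coroutine runs as its own task with its own copy of the context.
    return await asyncio.gather(handle("traj-a", 0.02), handle("traj-b", 0.01))


results = asyncio.run(main())
print(results)  # ['traj-a', 'traj-b']
```

Although the second rollout sets its context while the first is still sleeping, each handler reads back its own trajectory ID, because asyncio tasks copy the context when they are created.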

Modules

adapters – Framework-specific adapters for automatic header injection.

context – Rollout context management using contextvars.

decorators – Decorators for RFT integration.

feedback – Rollout feedback client for reporting completion and rewards to the RFT Runtime Service.

headers – Header utilities for inference calls.

models – Contract models for the rollout server API.