sagemaker.train.rft.context

Rollout context management using contextvars.

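Provides context-local storage, isolated per thread and per asyncio task, for rollout metadata and inference parameters. Values set once at the top level (for example, in a request handler) can be read deep in the call stack without being passed as arguments.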
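The pattern described above can be illustrated with the standard-library contextvars module alone. This is a minimal sketch of the general technique, not the module's actual implementation: the names _metadata, _inference_params, set_ctx, get_ctx, and handle are hypothetical stand-ins. It shows two concurrent asyncio tasks each setting their own metadata without interfering with one another.

```python
import asyncio
import contextvars

# Hypothetical stand-ins for the module's storage; the real internals may differ.
_metadata = contextvars.ContextVar("rollout_metadata", default=None)
_inference_params = contextvars.ContextVar("inference_params", default=None)


def set_ctx(metadata, inference_params=None):
    _metadata.set(metadata)
    _inference_params.set(inference_params)


def get_ctx():
    return _metadata.get()


def deep_in_call_stack():
    # No arguments are passed down; the value is read from the context.
    ctx = get_ctx()
    return ctx["trajectory_id"] if ctx else None


async def handle(trajectory_id):
    set_ctx({"trajectory_id": trajectory_id})
    await asyncio.sleep(0)  # yield so the two tasks interleave
    return deep_in_call_stack()


async def main():
    # gather() wraps each coroutine in a task, and each task runs in its own
    # copy of the context, so the handlers do not overwrite each other.
    return await asyncio.gather(handle("traj-a"), handle("traj-b"))


results = asyncio.run(main())
print(results)
```

Each task sees only the value it set itself, and the context outside asyncio.run() is left untouched.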

Functions

clear_rollout_context()

Clear rollout metadata and inference params from context.

get_inference_params()

Retrieve inference parameters from context.

get_rollout_context()

Retrieve rollout metadata from context.

set_rollout_context(metadata[, inference_params])

Store rollout metadata and inference params in context.

sagemaker.train.rft.context.clear_rollout_context() → None

Clear rollout metadata and inference params from context.
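Because the context should be cleared even when a handler raises, one option is to pair the set and clear calls in a context manager. The sketch below is an assumption-laden illustration: rollout_scope is a hypothetical helper, and the three functions defined at the top are local stand-ins that mimic the documented signatures so the example runs without the SDK installed.

```python
import contextvars
from contextlib import contextmanager

# Stand-ins that mimic the documented functions (assumed behavior).
# In real code, import these from sagemaker.train.rft.context instead.
_metadata = contextvars.ContextVar("rollout_metadata", default=None)
_params = contextvars.ContextVar("inference_params", default=None)


def set_rollout_context(metadata, inference_params=None):
    _metadata.set(metadata)
    _params.set(inference_params)


def clear_rollout_context():
    _metadata.set(None)
    _params.set(None)


def get_rollout_context():
    return _metadata.get()


@contextmanager
def rollout_scope(metadata, inference_params=None):
    """Hypothetical helper: set the context on entry, always clear it on exit."""
    set_rollout_context(metadata, inference_params)
    try:
        yield
    finally:
        clear_rollout_context()


seen_inside = None
try:
    with rollout_scope({"trajectory_id": "traj-1"}):
        seen_inside = get_rollout_context()
        raise RuntimeError("simulated handler failure")
except RuntimeError:
    pass

# The context is cleared even though the body raised.
seen_after = get_rollout_context()
```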

sagemaker.train.rft.context.get_inference_params() → dict[str, Any] | None

Retrieve inference parameters from context.

Returns:

The inference_params dict if set, None otherwise.

Example:

from sagemaker.train.rft.context import get_inference_params

params = get_inference_params()
if params:
    temperature = params.get("temperature", 1.0)
    max_tokens = params.get("max_tokens", 4096)
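
Since the return value may be None, the lookup-with-defaults pattern above can be factored into a small helper. This is only a sketch: resolve_sampling_params and DEFAULT_SAMPLING are hypothetical names, and the default values simply mirror those in the example. In real code, the argument would be the result of get_inference_params().

```python
from typing import Any, Optional

# Hypothetical fallback values, mirroring the example above.
DEFAULT_SAMPLING = {"temperature": 1.0, "max_tokens": 4096}


def resolve_sampling_params(params: Optional[dict[str, Any]]) -> dict[str, Any]:
    """Overlay context-provided params on the defaults; tolerate None."""
    resolved = dict(DEFAULT_SAMPLING)
    if params:
        # Skip explicit None values so they do not mask a default.
        resolved.update({k: v for k, v in params.items() if v is not None})
    return resolved


# No context set: fall back to the defaults.
no_context = resolve_sampling_params(None)

# Partial override: unspecified keys keep their defaults.
overridden = resolve_sampling_params({"temperature": 0.2, "top_p": 0.9})
```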

sagemaker.train.rft.context.get_rollout_context() → dict[str, Any] | None

Retrieve rollout metadata from context.

Returns:

The metadata dict if set, None otherwise.

Example:

from sagemaker.train.rft.context import get_rollout_context

ctx = get_rollout_context()
if ctx:
    job_arn = ctx["job_arn"]
    trajectory_id = ctx["trajectory_id"]

sagemaker.train.rft.context.set_rollout_context(metadata: dict[str, Any], inference_params: dict[str, Any] | None = None) → None

Store rollout metadata and inference params in context.

Call this at the start of a rollout handler. Values are available via get_rollout_context() and get_inference_params() anywhere in the same thread/async context.

Parameters:
  • metadata – Rollout metadata dict from the rollout request.

  • inference_params – Optional dict with sampling parameters (temperature, max_tokens, top_p).

Example:

from sagemaker.train.rft.context import set_rollout_context, clear_rollout_context

# `app` is assumed to be an existing web application object (e.g. a FastAPI app).
@app.post("/rollout")
def handle_rollout(request):
    set_rollout_context(
        metadata=request.metadata,
        inference_params=request.inference_params,
    )
    try:
        result = my_agent.run(request.instance)
    finally:
        clear_rollout_context()
    return result
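
The phrase "same thread/async context" matters when a handler hands work to a thread pool: a worker thread generally does not see values set in the submitting thread unless the context is carried over explicitly. The sketch below shows one way to do that with the standard library's contextvars.copy_context(). It uses a local ContextVar as a stand-in for the module's storage, and run_in_worker_with_context is a hypothetical helper name.

```python
import contextvars
from concurrent.futures import ThreadPoolExecutor

# Stand-in for the module's internal storage (hypothetical name).
_rollout_metadata = contextvars.ContextVar("rollout_metadata", default=None)


def read_trajectory_id():
    ctx = _rollout_metadata.get()
    return ctx["trajectory_id"] if ctx else None


def run_in_worker_with_context(executor, fn, *args):
    # Snapshot the caller's context and run fn inside that snapshot
    # on the worker thread, so context variables are visible there.
    snapshot = contextvars.copy_context()
    return executor.submit(snapshot.run, fn, *args)


_rollout_metadata.set({"trajectory_id": "traj-1"})

with ThreadPoolExecutor(max_workers=1) as pool:
    propagated = run_in_worker_with_context(pool, read_trajectory_id).result()
```

In asyncio code, asyncio.to_thread() performs the same context copy automatically, whereas loop.run_in_executor() does not.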