Source code for sagemaker.train.rft.context
"""Rollout context management using contextvars.
Provides thread-safe storage for rollout metadata and inference parameters
that can be set at the top level and accessed deep in the call stack.
"""
from __future__ import annotations
from contextvars import ContextVar
from typing import Any
# Metadata of the rollout currently being handled in this context; holds
# None until set_rollout_context() stores a dict, and again after clearing.
_rollout_metadata: ContextVar[dict[str, Any] | None] = ContextVar("rollout_metadata", default=None)

# Sampling parameters for the current rollout (for example temperature,
# max_tokens, top_p); None when none were supplied or after clearing.
_inference_params: ContextVar[dict[str, Any] | None] = ContextVar("inference_params", default=None)
[docs]
def set_rollout_context(
    metadata: dict[str, Any],
    inference_params: dict[str, Any] | None = None,
) -> None:
    """Bind the current rollout's metadata and sampling parameters to this context.

    Invoke once at the top of a rollout handler. Afterwards, code running in
    the same thread/async context can read the values back through
    get_rollout_context() and get_inference_params().

    Both values are overwritten on every call, so omitting ``inference_params``
    resets it to None. The dicts are stored as-is (not copied), so later
    mutations of them by the caller are visible to readers.

    Args:
        metadata: Rollout metadata dict taken from the rollout request.
        inference_params: Optional dict of sampling parameters
            (temperature, max_tokens, top_p). Defaults to None.

    Example::

        from sagemaker.train.rft import set_rollout_context, clear_rollout_context

        @app.post("/rollout")
        def handle_rollout(request):
            set_rollout_context(
                metadata=request.metadata,
                inference_params=request.inference_params,
            )
            try:
                result = my_agent.run(request.instance)
            finally:
                clear_rollout_context()
            return result
    """
    # Pair each context variable with the value it should hold; metadata is
    # stored first, then the inference params.
    assignments = (
        (_rollout_metadata, metadata),
        (_inference_params, inference_params),
    )
    for var, value in assignments:
        var.set(value)
[docs]
def get_rollout_context() -> dict[str, Any] | None:
    """Return the rollout metadata bound to the current context.

    Returns:
        The metadata dict stored in this context (normally by
        set_rollout_context()), or None when nothing is set or it was cleared.

    Example::

        from sagemaker.train.rft.context import get_rollout_context

        ctx = get_rollout_context()
        if ctx:
            job_arn = ctx["job_arn"]
            trajectory_id = ctx["trajectory_id"]
    """
    current = _rollout_metadata.get()
    return current
[docs]
def get_inference_params() -> dict[str, Any] | None:
    """Return the sampling parameters bound to the current context.

    Returns:
        The inference_params dict stored in this context (normally by
        set_rollout_context()), or None when nothing is set or it was cleared.

    Example::

        from sagemaker.train.rft.context import get_inference_params

        params = get_inference_params()
        if params:
            temperature = params.get("temperature", 1.0)
            max_tokens = params.get("max_tokens", 4096)
    """
    current = _inference_params.get()
    return current
[docs]
def clear_rollout_context() -> None:
    """Reset the rollout metadata and inference params in this context to None.

    This assigns None outright rather than restoring any earlier value. A
    context that was set before a nested set_rollout_context() call is
    therefore not recovered.
    """
    for var in (_rollout_metadata, _inference_params):
        var.set(None)