# Source code for sagemaker.train.rft.adapters.strands
"""Strands framework adapter for automatic header and inference param injection.

Provides ``wrap_model()``, which wraps a Strands model to automatically inject
RFT headers and inference parameters into requests using the rollout context.

The wrapper intercepts ``stream()`` and injects headers via
``client_args["default_headers"]`` because Strands ``OpenAIModel`` creates
a new OpenAI client per request from ``client_args``.
"""
from __future__ import annotations
import logging
from typing import Any
from sagemaker.train.rft.headers import get_inference_headers
from sagemaker.train.rft.context import get_inference_params
logger = logging.getLogger(__name__)
def wrap_model(model: Any) -> Any:
    """Return ``model`` wrapped so the rollout context is injected into its requests.

    The returned proxy forwards attribute access to ``model``. Every time its
    ``stream()`` method is called, it:

    1. merges the ``X-RFT-Metadata`` header (containing job_id, experiment_id,
       rollout_id) into ``client_args["default_headers"]``, and
    2. applies the context's inference parameters (temperature, max_tokens,
       top_p) to the model's ``params`` config.

    Args:
        model: A Strands model instance (e.g., ``OpenAIModel``).

    Returns:
        A proxy around ``model`` that transparently injects RFT headers.

    Example::

        from strands.models.openai import OpenAIModel
        from strands import Agent
        from sagemaker.train.rft.adapters.strands import wrap_model

        model = OpenAIModel(
            client_args={"api_key": "...", "base_url": "..."},
            model_id="my-model",
        )
        wrapped = wrap_model(model)
        agent = Agent(model=wrapped, tools=[...])
        result = agent("Solve this task")
    """
    proxy = _RFTModelWrapper(model)
    return proxy
class _RFTModelWrapper:
    """Transparent proxy that injects RFT headers into Strands model calls.

    Attribute reads the wrapper does not define are forwarded to the inner
    model. Attribute writes are forwarded too, except for the wrapper's own
    bookkeeping attributes. This lets the wrapper stand in for the original
    model. ``stream()`` is intercepted to inject the RFT headers and inference
    parameters from the current rollout context before delegating.
    """

    # Attribute names stored on the wrapper itself rather than forwarded to
    # the inner model.
    _OWN_ATTRS = frozenset({"_inner", "_rft_injected_header_keys"})

    # Header names this wrapper put into client_args["default_headers"] on its
    # most recent injection. They are removed on the next call so stale
    # rollout metadata is never sent.
    _rft_injected_header_keys: tuple = ()

    # (snake_case, camelCase) spellings accepted for each inference parameter.
    # The snake_case value wins when both are present and not None.
    _PARAM_ALIASES = (
        ("temperature", "temperature"),
        ("max_tokens", "maxTokens"),
        ("top_p", "topP"),
    )

    def __init__(self, inner_model: Any):
        object.__setattr__(self, "_inner", inner_model)

    def __getattr__(self, name: str) -> Any:
        """Forward lookups that normal attribute resolution could not satisfy.

        Raises:
            AttributeError: If ``name`` is a wrapper-owned attribute that is
                not set. This happens, for example, on an instance created
                via ``__new__`` by ``copy``/``pickle`` before its state is
                restored. Raising here avoids infinite recursion through
                ``self._inner``.
        """
        if name in self._OWN_ATTRS:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._inner, name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Wrapper bookkeeping stays on the proxy; everything else goes to the
        # inner model so callers configuring the "model" affect the real one.
        if name in self._OWN_ATTRS:
            object.__setattr__(self, name, value)
        else:
            setattr(self._inner, name, value)

    def stream(self, *args: Any, **kwargs: Any) -> Any:
        """Intercept stream() to inject RFT headers and inference params.

        The OpenAI client supports ``default_headers`` in its constructor,
        which are sent with every request. The RFT headers are injected there
        because Strands OpenAIModel creates a new client per request from
        ``client_args``. Inference params are applied via ``update_config``.
        All arguments are passed through unchanged to the inner ``stream()``.
        """
        self._inject_rft_headers()
        self._inject_inference_params()
        return self._inner.stream(*args, **kwargs)

    def _inject_rft_headers(self) -> None:
        """Sync the current RFT headers into the inner model's client_args."""
        rft_headers = get_inference_headers() or {}
        previous_keys = self._rft_injected_header_keys
        if not rft_headers and not previous_keys:
            return
        client_args = getattr(self._inner, "client_args", None)
        if client_args is None:
            return
        # Build a fresh dict instead of updating in place: the caller's
        # default_headers mapping may be shared with other clients.
        headers = dict(client_args.get("default_headers") or {})
        # Drop headers injected for an earlier call so they are not sent with
        # requests made outside (or in a different) rollout context.
        for key in previous_keys:
            headers.pop(key, None)
        headers.update(rft_headers)
        client_args["default_headers"] = headers
        object.__setattr__(self, "_rft_injected_header_keys", tuple(rft_headers))
        if rft_headers:
            logger.debug("Injected RFT headers: %s", list(rft_headers.keys()))

    def _inject_inference_params(self) -> None:
        """Merge context inference params into the inner model's ``params`` config."""
        inference_params = get_inference_params()
        if not inference_params:
            return
        params_update = {}
        for snake, camel in self._PARAM_ALIASES:
            value = inference_params.get(snake)
            if value is None:
                value = inference_params.get(camel)
            if value is not None:
                params_update[snake] = value
        if not params_update:
            return
        logger.debug("Updating model params with inference params: %s", params_update)
        # NOTE(review): merged params are written back through update_config,
        # so they remain in the inner model's config after this call. Confirm
        # that persisting them across rollouts is intended.
        merged_params = dict(self._inner.config.get("params", {}) or {})
        merged_params.update(params_update)
        self._inner.update_config(params=merged_params)

    def update_config(self, **model_config: Any) -> None:
        """Delegate ``update_config`` to the inner model."""
        return self._inner.update_config(**model_config)

    def get_config(self) -> Any:
        """Return the inner model's ``get_config()`` result."""
        return self._inner.get_config()

    def structured_output(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate ``structured_output`` to the inner model unchanged."""
        return self._inner.structured_output(*args, **kwargs)