# Source code for sagemaker.train.rft.headers
"""Header utilities for inference calls.
Provides functions to create the HTTP headers required by the
AgenticRFTRuntimeService for each inference request.
The runtime service expects individual headers:
- ``X-Amzn-SageMaker-Job-Arn``: job ARN that identifies the training session
- ``X-Amzn-SageMaker-Trajectory-Id``: unique trajectory identifier
"""
from __future__ import annotations
import warnings
from typing import Any
from sagemaker.train.rft.context import get_rollout_context
from sagemaker.train.rft.models import RolloutMetadata
[docs]
def make_inference_headers(metadata: dict[str, Any] | RolloutMetadata) -> dict[str, str]:
    """Build the per-request headers for an inference call.

    Each header is filled from the first truthy value found among its
    candidate metadata keys (snake_case is tried before camelCase):

    - ``X-Amzn-SageMaker-Job-Arn``: ``job_arn``, then ``jobArn``
    - ``X-Amzn-SageMaker-Trajectory-Id``: ``trajectory_id``, then
      ``trajectoryId``, then ``rolloutId``

    A header is left out entirely when none of its candidate keys holds a
    truthy value.

    Args:
        metadata: Rollout request metadata, either as a plain dict or as a
            ``RolloutMetadata`` instance (converted to a dict through its
            ``model_dump()`` method before lookup).

    Returns:
        A dict containing zero, one, or both of the headers listed above.
    """
    fields = metadata.model_dump() if isinstance(metadata, RolloutMetadata) else metadata

    # Header name -> metadata keys to probe, in priority order.
    header_sources = (
        ("X-Amzn-SageMaker-Job-Arn", ("job_arn", "jobArn")),
        (
            "X-Amzn-SageMaker-Trajectory-Id",
            ("trajectory_id", "trajectoryId", "rolloutId"),
        ),
    )

    result: dict[str, str] = {}
    for header_name, candidate_keys in header_sources:
        # Lazily probe the keys and stop at the first truthy value.
        value = next((found for found in map(fields.get, candidate_keys) if found), None)
        if value:
            result[header_name] = value
    return result
[docs]
def get_inference_headers() -> dict[str, str]:
    """Build inference headers from the current rollout context.

    Reads the metadata returned by ``get_rollout_context()`` and passes it to
    ``make_inference_headers()``, so callers deep in the call stack do not
    need the metadata handed to them explicitly.

    Returns:
        The headers dict produced by ``make_inference_headers()``, or an
        empty dict when ``get_rollout_context()`` returns ``None``. In the
        latter case a warning is emitted as well.
    """
    context = get_rollout_context()
    if context is not None:
        return make_inference_headers(context)

    message = (
        "get_inference_headers() called but no rollout context set. "
        "Did you forget to call set_rollout_context()? "
        "Returning empty headers."
    )
    # stacklevel=2 attributes the warning to the caller of this function.
    warnings.warn(message, stacklevel=2)
    return {}