# Source code for sagemaker.train.rft.models
"""Contract models for the rollout server API.
These models define the enforced contract between the platform trainer
and customer rollout servers.
Customer server requirements:
    - ``POST /rollout``: accept a ``RolloutRequest`` JSON body.
    - ``GET /health``: return ``{"status": "healthy"}`` when ready.
"""
from __future__ import annotations

from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict, Field
[docs]
class InferenceParams(BaseModel):
    """Optional sampling settings that accompany a rollout request.

    Each field defaults to ``None``. A ``None`` value means the setting was
    not supplied, in which case the model's own default is used.
    """

    # Sampling temperature; None -> model default.
    temperature: Optional[float] = Field(
        None,
        description="Sampling temperature",
    )
    # Cap on generated tokens; None -> model default.
    max_tokens: Optional[int] = Field(
        None,
        description="Maximum tokens to generate",
    )
    # Nucleus-sampling probability mass; None -> model default.
    top_p: Optional[float] = Field(
        None,
        description="Top-p (nucleus) sampling",
    )
[docs]
class RolloutRequest(BaseModel):
    """Request format sent by the trainer to your ``/rollout`` endpoint.

    This is the enforced contract: your server must accept this exact format.

    Attributes:
        instance: Problem instance taken from the customer's data file.
        metadata: Platform-provided rollout context.
        inference_params: Optional inference parameters (temperature,
            max_tokens, top_p); ``None`` when none are sent.
        model_name: Optional model name override from the trainer.
        model_endpoint: Optional model endpoint override from the trainer.
    """

    # ``model_name`` and ``model_endpoint`` begin with pydantic's protected
    # ``model_`` prefix. Pydantic 2.0-2.9 emits a UserWarning for each such
    # field when the class is defined. Clearing the protected namespaces keeps
    # the public wire names unchanged and silences the warning.
    model_config = ConfigDict(protected_namespaces=())

    instance: Dict[str, Any] = Field(
        description="Problem instance from customer's data file"
    )
    # NOTE(review): RolloutMetadata is not defined above this class. With
    # postponed annotations, pydantic resolves it lazily, so it must exist in
    # this module before the first validation. Presumably it is defined
    # further down the file; confirm.
    metadata: RolloutMetadata = Field(description="Platform-provided rollout context")
    inference_params: Optional[InferenceParams] = Field(
        default=None,
        description="Optional inference parameters (temperature, max_tokens, top_p)",
    )
    model_name: Optional[str] = Field(
        default=None, description="Optional model name override from trainer"
    )
    model_endpoint: Optional[str] = Field(
        default=None, description="Optional model endpoint override from trainer"
    )