import copy
import os
import yaml
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Union
import json
import logging
import re
import subprocess
import tarfile
import tempfile
from urllib.parse import urlparse
import yaml
import boto3
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.training.configs import Tag, Networking, InputData, Channel, OutputDataConfig
from sagemaker.core.shapes import shapes
from sagemaker.core.resources import TrainingJob
from sagemaker.train.common_utils.recipe_utils import _is_nova_model, resolve_recipe, get_resolved_recipe_from_context, NoRecipeError
from sagemaker.core.s3.utils import resolve_s3_uri_placeholders
from sagemaker.train.recipe_resolver import flatten_resolved_recipe
from sagemaker.train.common_utils.finetune_utils import (
get_training_image,
get_hyperpod_training_image,
get_hyperpod_recipe_path,
get_recipe_s3_uri,
_validate_hyperparameter_values,
_get_smhp_replicas_enum,
)
from sagemaker.train.common_utils.mlflow_config_utils import resolve_mlflow_tracking_fields
from sagemaker.train.common_utils.validator import validate_hyperpod_compute
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.utils import _get_unique_name
class BaseTrainer(ABC):
    """Abstract base class for all SageMaker training workflows.

    This class provides the common interface and shared functionality for all trainer implementations
    including SFT, DPO, RLVR, and RLAIF trainers. It defines the standard parameters and abstract
    methods that concrete trainer classes must implement.

    Parameters:
        sagemaker_session (Optional[Session]):
            The SageMaker session for managing API calls and resources.
            If not specified, a default session will be created.
        role (Optional[str]):
            The IAM role ARN for the training job execution.
            If not specified, the default SageMaker execution role will be used.
        base_job_name (Optional[str]):
            The base name for training jobs. A unique suffix will be appended.
            If not specified, a default name will be generated based on the trainer type.
        tags (Optional[List[Tag]]):
            List of tags to apply to the training job for resource management and billing.
        hyperparameters (Optional[Dict[str, Any]]):
            Dictionary of hyperparameters for the training job.
            Trainer-specific defaults will be applied if not specified.
        output_data_config (Optional[shapes.OutputDataConfig]):
            Configuration for training job outputs including S3 paths and encryption.
            If not specified, default output configuration will be used.
        input_data_config (Optional[List[Union[Channel, InputData]]]):
            List of input data channels for the training job.
            Can include training and validation datasets.
        environment (Optional[Dict[str, str]]):
            Environment variables to set in the training container.
        training_image (Optional[str]):
            Custom training container image URI. If not provided, the image is
            auto-resolved from the model's recipe metadata in SageMaker Hub.
        base_model_name (Optional[str]):
            Stored on the instance as ``base_model_name``. It is not read by the
            methods of this class shown here (NOTE(review): confirm intended use
            in subclasses).
        disable_output_compression (Optional[bool]):
            When truthy, the serverful SMTJ path builds its output data config with
            ``compression_type="NONE"``. Defaults to ``False``.
    """
    # Class-level attributes with default values
    sagemaker_session: Optional[Session] = None
    role: Optional[str] = None
    base_job_name: Optional[str] = None
    tags: Optional[List[Tag]] = None
    hyperparameters: Optional[Dict[str, Any]] = None
    output_data_config: Optional[shapes.OutputDataConfig] = None
    input_data_config: Optional[List[Union[Channel, InputData]]] = None
    environment: Optional[Dict[str, str]] = None
    training_image: Optional[str] = None
    # NOTE(review): the training paths in this class assign
    # ``self._latest_training_job`` (leading underscore); this public
    # attribute is only declared here. Confirm which name callers rely on.
    latest_training_job: Optional[TrainingJob] = None
def __init__(
self,
sagemaker_session: Optional[Session] = None,
role: Optional[str] = None,
base_job_name: Optional[str] = None,
tags: Optional[List[Tag]] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
output_data_config: Optional[shapes.OutputDataConfig] = None,
input_data_config: Optional[List[Union[Channel, InputData]]] = None,
environment: Optional[Dict[str, str]] = None,
training_image: Optional[str] = None,
base_model_name: Optional[str] = None,
disable_output_compression: Optional[bool] = False,
):
self.sagemaker_session = sagemaker_session
self.role = role
self.base_job_name = base_job_name
self.tags = tags
self.hyperparameters = hyperparameters or {}
self.output_data_config = output_data_config
self.input_data_config = input_data_config
self.environment = environment or {}
self.training_image = training_image
self.base_model_name = base_model_name
self.disable_output_compression = disable_output_compression
self._checkpoint_s3_uri = None
def _is_nova_model_for_telemetry(self) -> bool:
"""Check if the model is a Nova model for telemetry tracking."""
model_name = getattr(self, "_model_name", None)
return _is_nova_model(model_name) if model_name else False
def get_resolved_recipe(self) -> Dict[str, Any]:
    """Build and return the fully resolved recipe configuration.

    The result is the merge of base defaults, the user recipe and any
    overrides, as produced by ``get_resolved_recipe_from_context``. It can
    be called before or after ``train()``.

    If neither ``recipe`` nor ``overrides`` were given at construction time
    but hyperparameters were set directly (e.g.
    ``trainer.hyperparameters.x = val``), the current hyperparameters are
    still handed to the resolver so it can treat them as implicit overrides.

    The resolved dict is patched via ``_patch_resolved_recipe``, stored on
    ``self._resolved_recipe_cache``, and a deep copy is returned so callers
    cannot mutate the cached object.

    Returns:
        dict: Deep copy of the resolved recipe configuration.

    Raises:
        Whatever ``get_resolved_recipe_from_context`` raises. Other methods
        of this class catch ``NoRecipeError`` around this call.
    """
    # The Hub template is fetched first so the YAML structure is preserved.
    template = self._fetch_full_recipe_template()

    resolver_kwargs = {
        "recipe_path": getattr(self, '_recipe_path', None),
        "overrides": getattr(self, '_overrides', None),
        "hyperparameters": getattr(self, 'hyperparameters', None),
        "resolved_cache": getattr(self, '_resolved_recipe_cache', None),
        "template_section": "training_config",
        "protected_keys": {"model_type", "model_name_or_path", "dataset_catalog"},
        "full_recipe_template": template,
        "compute": getattr(self, 'compute', None),
    }
    recipe = get_resolved_recipe_from_context(**resolver_kwargs)

    # Align the preview with what the job will actually be submitted with.
    self._patch_resolved_recipe(recipe)
    self._resolved_recipe_cache = recipe
    return copy.deepcopy(recipe)
def _fetch_full_recipe_template(self) -> Optional[Dict[str, Any]]:
"""Fetch the full recipe template from Hub to preserve YAML structure.
Returns None if the template can't be fetched (fallback to synthetic template).
"""
logger = logging.getLogger(__name__)
frt = getattr(self.hyperparameters, '_full_recipe_template', None) if hasattr(self, 'hyperparameters') else None
if isinstance(frt, dict):
return frt
if not hasattr(self, '_model_name') or not hasattr(self, '_customization_technique'):
return None
try:
from sagemaker.core.training.configs import HyperPodCompute
from sagemaker.train.common_utils.finetune_utils import (
_get_recipe_entry_and_override_spec,
_extract_recipe_from_helm_template,
)
is_hyperpod = isinstance(getattr(self, 'compute', None), HyperPodCompute)
sagemaker_session = TrainDefaults.get_sagemaker_session(sagemaker_session=self.sagemaker_session)
platform = "hyperpod" if is_hyperpod else "smtj"
recipe_entry, _ = _get_recipe_entry_and_override_spec(
model_name=self._model_name,
customization_technique=self._customization_technique,
training_type=self.training_type,
sagemaker_session=sagemaker_session,
platform=platform,
)
s3_client = sagemaker_session.boto_session.client("s3")
if is_hyperpod:
hp_uri = recipe_entry["HpEksPayloadTemplateS3Uri"]
bucket, key = hp_uri.replace("s3://", "").split("/", 1)
raw = s3_client.get_object(Bucket=bucket, Key=key)["Body"].read().decode("utf-8")
return yaml.safe_load(_extract_recipe_from_helm_template(raw))
else:
smtj_uri = resolve_s3_uri_placeholders(recipe_entry["SmtjRecipeTemplateS3Uri"], sagemaker_session)
uri_path = smtj_uri.replace("s3://", "")
if uri_path.startswith("arn:"):
match = re.match(r'(arn:aws:s3:[^:]*:[^:]*:accesspoint/[^/]+)/(.*)', uri_path)
bucket, key = (match.group(1), match.group(2)) if match else uri_path.split("/", 1)
else:
bucket, key = uri_path.split("/", 1)
tmp = tempfile.NamedTemporaryFile(suffix=".yaml", delete=False)
s3_client.download_file(bucket, key, tmp.name)
with open(tmp.name, "r") as f:
return yaml.safe_load(f)
except Exception as e:
logger.debug(f"Could not fetch full recipe template: {e}")
return None
def _patch_resolved_recipe(self, resolved: Dict[str, Any]) -> None:
    """Overwrite selected keys of ``resolved`` in place with trainer values.

    Keys considered: ``name`` (a unique name derived from ``base_job_name``),
    ``output_s3_path`` (from ``s3_output_path``), ``data_s3_path`` (from
    ``training_dataset``) and anything returned by
    ``_get_extra_smtj_hyperparameters()``. A key is only written when
    ``_build_key_path_map`` locates it inside ``resolved``.
    """
    from sagemaker.train.recipe_resolver import _set_nested_value, _build_key_path_map

    overrides = {}

    # base_job_name -> name
    if self.base_job_name:
        overrides["name"] = _get_unique_name(self.base_job_name)

    # Trainer-level S3 locations, when the subclass defines them.
    output_path = getattr(self, 's3_output_path', None)
    if output_path:
        overrides["output_s3_path"] = output_path
    dataset = getattr(self, 'training_dataset', None)
    if dataset:
        overrides["data_s3_path"] = dataset

    # Subclass-specific hyperparameters (e.g. reward_lambda_arn for RLVR)
    overrides.update(self._get_extra_smtj_hyperparameters())

    if not overrides:
        return

    # Resolve each key to its dotted location, then write the ones that exist.
    locations = _build_key_path_map(resolved, set(overrides.keys()))
    for field, new_value in overrides.items():
        dotpath = locations.get(field)
        if not dotpath:
            continue
        _set_nested_value(resolved, dotpath, new_value)
def _apply_recipe_to_hyperparameters(self, final_hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
"""Apply resolved recipe values to final_hyperparameters dict.
If recipe/overrides were provided, or if the user set hyperparameters
directly via ``.hyperparameters.*``, merges resolved recipe values into
the hyperparameters dict. All leaf values from the resolved recipe are
applied — including keys not in the Hub spec subset — enabling
power users to override any parameter in the full recipe.
Values are converted to strings (matching the SageMaker API
expectation for hyperparameter values).
Args:
final_hyperparameters: The hyperparameters dict from to_dict().
Returns:
The updated hyperparameters dict with recipe values applied.
"""
if not hasattr(self, 'hyperparameters') or not isinstance(getattr(self.hyperparameters, '_specs', None), dict):
return final_hyperparameters
try:
resolved = self.get_resolved_recipe()
except NoRecipeError:
return final_hyperparameters
flat = flatten_resolved_recipe(resolved)
for k, v in flat.items():
if v is not None:
final_hyperparameters[k] = str(v) if not isinstance(v, str) else v
return final_hyperparameters
def _validate_instance_count(self, instance_count, sagemaker_session):
    """Check ``instance_count`` against the replica values the SMHP recipe allows.

    Returns:
        The allowed values as returned by ``_get_smhp_replicas_enum`` (may be
        falsy, in which case no check is performed).

    Raises:
        ValueError: If allowed values exist and ``instance_count`` is not
            one of them.
    """
    allowed = _get_smhp_replicas_enum(
        model_name=self._model_name,
        customization_technique=self._customization_technique,
        training_type=self.training_type,
        sagemaker_session=sagemaker_session,
    )
    # Nothing to enforce, or the requested count is permitted.
    if not allowed or instance_count in allowed:
        return allowed
    raise ValueError(
        f"Node/Instance count '{instance_count}' is not supported. "
        f"Allowed values: {sorted(allowed)}."
    )
@abstractmethod
def train(self, input_data_config: List[InputData], wait: bool = True, logs: bool = True, wait_timeout: Optional[int] = None):
"""Common training method that calls the specific implementation."""
pass
def _get_extra_smtj_hyperparameters(self) -> Dict[str, Any]:
"""Return extra hyperparameters to inject for SMTJ training.
Subclasses can override this to add trainer-specific hyperparameters
(e.g. RLVRTrainer adds ``reward_lambda_arn``).
Returns:
Dict of additional hyperparameters to merge.
"""
return {}
def _train_serverful_smtj(self, training_dataset=None, validation_dataset=None,
                          wait=True, wait_timeout=None, poll=5):
    """Execute training on serverful SageMaker Training Job (SMTJ) compute.

    Uses ModelTrainer.from_recipe() with the model's recipe template from
    SageMaker Hub, running on user-specified instances.

    This method is shared across SFT, DPO, and RLVR trainers. The only
    trainer-specific variation is the ``customization_technique`` (derived
    from ``self._customization_technique``) and any extra hyperparameters
    from ``_get_extra_smtj_hyperparameters()``.

    Steps, in order: download the recipe template to a local temp file,
    build the override spec (instance count, model weights, dataset mount
    paths, output path, MLflow fields, hyperparameters), render the recipe,
    resolve the training image, assemble input channels, then create and run
    the ``ModelTrainer``.

    Args:
        training_dataset: Training data S3 URI; falls back to
            ``self.training_dataset`` when falsy.
        validation_dataset: Validation data S3 URI; falls back to
            ``self.validation_dataset`` when falsy.
        wait: Forwarded to ``ModelTrainer.train`` as both ``wait`` and
            ``logs``. When truthy, the checkpoint location is additionally
            looked up from the job manifest after training.
        wait_timeout: Accepted for signature parity; not used by this method.
        poll: Accepted for signature parity; not used by this method.

    Returns:
        The training job object taken from the ``ModelTrainer`` (also stored
        on ``self._latest_training_job``).

    Raises:
        ValueError: If the recipe URI is an unparsable access point ARN, the
            instance count is not allowed, or no training image can be
            resolved.
    """
    # Only names that are not already imported at module level are imported here.
    from sagemaker.train.model_trainer import ModelTrainer
    from sagemaker.core.training.configs import TrainingJobCompute
    from sagemaker.core.shapes import S3DataSource
    from sagemaker.train.common_utils.finetune_utils import (
        _render_recipe_placeholders,
        _get_smtj_override_spec,
        _resolve_base_model_weights_s3_uri,
    )

    logger = logging.getLogger(__name__)
    sagemaker_session = TrainDefaults.get_sagemaker_session(
        sagemaker_session=self.sagemaker_session
    )
    role = TrainDefaults.get_role(role=self.role, sagemaker_session=sagemaker_session)
    compute = self.compute
    customization_technique = self._customization_technique
    # Computed once; several branches below are Nova / non-Nova specific.
    is_nova = _is_nova_model(self._model_name)

    # Resolve the recipe S3 URI from hub metadata
    recipe_s3_uri = get_recipe_s3_uri(
        model_name=self._model_name,
        customization_technique=customization_technique,
        training_type=self.training_type,
        sagemaker_session=sagemaker_session,
    )
    logger.info("SMTJ recipe S3 URI: %s", recipe_s3_uri)

    # Download recipe from S3 to a local temp file
    recipe_s3_uri = resolve_s3_uri_placeholders(recipe_s3_uri, sagemaker_session)
    s3_client = sagemaker_session.boto_session.client("s3")
    uri_path = recipe_s3_uri.replace("s3://", "")
    if uri_path.startswith("arn:"):
        # S3 access point ARN: the ARN (which itself contains "/") acts as
        # the bucket. The partition segment is matched generically so that
        # non-"aws" partitions are parsed as well.
        match = re.match(r'(arn:[^:]+:s3:[^:]*:[^:]*:accesspoint/[^/]+)/(.*)', uri_path)
        if not match:
            raise ValueError(f"Cannot parse S3 access point ARN: {uri_path}")
        bucket, key = match.group(1), match.group(2)
    else:
        bucket, key = uri_path.split("/", 1)

    recipe_tmp = tempfile.NamedTemporaryFile(
        prefix="smtj_recipe_", suffix=".yaml", delete=False
    )
    recipe_local_path = recipe_tmp.name
    # Close our handle right away: only the path is needed from here on, and
    # an open handle would otherwise stay open for the rest of the process.
    # The file itself is kept because the path is handed to ModelTrainer.
    recipe_tmp.close()
    s3_client.download_file(bucket, key, recipe_local_path)

    # Resolve datasets up front so their paths can be injected into the recipe
    # before rendering. The recipe maps data.train_files -> {{data_path}} and
    # data.val_files -> {{validation_data_path}}; these are mounted into the
    # container as SageMaker input channels (train/validation), so the recipe
    # must point at the local channel mount paths, not be left empty.
    resolved_training_dataset = training_dataset or self.training_dataset
    resolved_validation_dataset = validation_dataset or self.validation_dataset

    def _channel_mount_path(dataset_uri, channel_name):
        """Map an S3 dataset URI to the container path where its channel mounts."""
        mount_dir = "/opt/ml/input/data/" + channel_name
        basename = dataset_uri.rstrip("/").rsplit("/", 1)[-1]
        # A trailing object key with an extension is mounted as that file; an
        # S3 prefix (directory) is mounted as the channel directory itself.
        # NOTE(review): a prefix whose last segment contains "." is treated
        # as a file by this heuristic.
        if "." in basename:
            return mount_dir + "/" + basename
        return mount_dir

    def _set_spec_default(spec, key, value):
        """Set ``spec[key]["default"]``, creating a string-typed entry if needed."""
        entry = spec.get(key)
        if isinstance(entry, dict):
            entry["default"] = value
        else:
            spec[key] = {"default": value, "type": "string"}

    def _yaml_safe_default(value):
        """Render floats in fixed-point form so YAML keeps them numeric."""
        # Scientific notation like "5e-06" is parsed as a string by YAML,
        # which breaks numeric recipe fields.
        # NOTE(review): 12 decimal places means magnitudes below ~5e-13
        # render as "0.0".
        if isinstance(value, float):
            s = format(value, ".12f").rstrip("0")
            return s + "0" if s.endswith(".") else s
        return value

    override_spec = _get_smtj_override_spec(
        model_name=self._model_name,
        customization_technique=customization_technique,
        training_type=self.training_type,
        sagemaker_session=sagemaker_session,
    )

    # Validate instance count against allowed values from SMHP recipe.
    smhp_replicas_enum = self._validate_instance_count(compute.instance_count, sagemaker_session)
    if smhp_replicas_enum:
        override_spec.setdefault("replicas", {})["enum"] = smhp_replicas_enum
        if hasattr(self, 'hyperparameters') and hasattr(self.hyperparameters, '_specs'):
            self.hyperparameters._specs.setdefault("replicas", {})["enum"] = smhp_replicas_enum
            if not hasattr(self.hyperparameters, 'replicas'):
                object.__setattr__(self.hyperparameters, 'replicas', compute.instance_count)

    # For OSS/LLMFT models the recipe's model_name_or_path feeds straight into
    # AutoModelForCausalLM.from_pretrained(), so it must point at HF-format weights
    # on the local filesystem. When the Hub override spec leaves it empty, deliver
    # the SageMaker-prepared base weights via a dedicated "model" input channel and
    # point model_name_or_path at that channel's local mount.
    # Scoped to non-Nova: Nova recipes resolve model_name_or_path through
    # _get_args_from_nova_recipe (into the base_model hyperparameter), so this
    # OSS-specific workaround must never touch the Nova flow.
    base_model_weights_uri = None if is_nova else getattr(self, 'model_source', None)
    if not is_nova:
        model_name_or_path_spec = override_spec.get("model_name_or_path")
        if model_name_or_path_spec is not None:
            if isinstance(model_name_or_path_spec, dict):
                current_default = model_name_or_path_spec.get("default", "")
            else:
                current_default = model_name_or_path_spec
            if not current_default and not base_model_weights_uri:
                base_model_weights_uri = _resolve_base_model_weights_s3_uri(
                    model_name=self._model_name,
                    sagemaker_session=sagemaker_session,
                )
        if base_model_weights_uri:
            _set_spec_default(
                override_spec, "model_name_or_path",
                "/opt/ml/input/data/model",
            )

    # Inject the resolved dataset channel paths so the rendered recipe's
    # train_files / val_files are non-empty (the container aborts otherwise).
    if resolved_training_dataset:
        _set_spec_default(
            override_spec, "data_path",
            _channel_mount_path(resolved_training_dataset, "train"),
        )
    if resolved_validation_dataset:
        _set_spec_default(
            override_spec, "validation_data_path",
            _channel_mount_path(resolved_validation_dataset, "validation"),
        )

    # Point the recipe's output/training dir at the local SageMaker model dir so the
    # trained model is written there and SageMaker uploads it to s3_output_path as
    # model.tar.gz. Without this, the recipe's {{output_path}} renders empty and the
    # container writes the model to a local cwd that never gets uploaded (job succeeds
    # but no artifact lands in S3). Scoped to non-Nova: Nova uses a managed escrow
    # output mechanism.
    if not is_nova and "output_path" in override_spec:
        _set_spec_default(override_spec, "output_path", "/opt/ml/model")

    # MLflow configuration: inject tracking URI, experiment name, and run name
    # into the recipe override spec so they render into {{mlflow_*}} placeholders.
    # The shared resolve helper defaults empty names to base_job_name when a
    # tracking URI is set (prevents OSS container recipe validation failures).
    job_base_name = self.base_job_name or f"{self._model_name}-{customization_technique}"
    mlflow_tracking_uri, mlflow_experiment_name, mlflow_run_name = (
        resolve_mlflow_tracking_fields(
            mlflow_tracking_uri=getattr(self, 'mlflow_resource_arn', None),
            mlflow_experiment_name=getattr(self, 'mlflow_experiment_name', None),
            mlflow_run_name=getattr(self, 'mlflow_run_name', None),
            base_job_name=job_base_name,
        )
    )
    if mlflow_tracking_uri:
        _set_spec_default(override_spec, "mlflow_tracking_uri", mlflow_tracking_uri)
        _set_spec_default(override_spec, "mlflow_experiment_name", mlflow_experiment_name)
        _set_spec_default(override_spec, "mlflow_run_name", mlflow_run_name)

    # Inject user-set hyperparameters into the recipe before rendering.
    # For LLMFT/SMTJ the recipe YAML is the source of truth: ModelTrainer.from_recipe
    # ignores the hyperparameters dict for non-Nova recipes, so values the user set on
    # self.hyperparameters (e.g. global_batch_size, learning_rate, max_epochs) must be
    # rendered into the recipe's {{placeholders}} or they are silently dropped in favor
    # of the Hub spec defaults.
    for hp_key in (getattr(self.hyperparameters, "_user_set", None) or []):
        if hp_key not in override_spec:
            continue
        hp_value = getattr(self.hyperparameters, hp_key, None)
        if hp_value is not None:
            _set_spec_default(override_spec, hp_key, _yaml_safe_default(hp_value))

    # Build hyperparameters early to inject into recipe template before runtime.
    final_hyperparameters = self.hyperparameters.to_dict()
    _validate_hyperparameter_values(final_hyperparameters)

    # Allow subclasses to inject extra hyperparameters
    extra_hp = self._get_extra_smtj_hyperparameters()
    if extra_hp:
        final_hyperparameters.update(extra_hp)

    # Merge user-provided recipe/overrides into hyperparameters
    final_hyperparameters = self._apply_recipe_to_hyperparameters(final_hyperparameters)

    # Inject all final hyperparameters into the override spec.
    # NOTE(review): this overwrites the defaults set from ``_user_set`` above
    # for any key present in both; confirm that precedence is intended.
    for hp_key, hp_value in final_hyperparameters.items():
        if hp_value is not None and hp_value != "":
            _set_spec_default(override_spec, hp_key, hp_value)

    with open(recipe_local_path, "r") as f:
        recipe_content = f.read()
    recipe_content = _render_recipe_placeholders(recipe_content, override_spec)

    # Inject model_source into the recipe as model_name_or_path for iterative
    # training (resuming from a previously trained checkpoint).
    # Only applies to Nova models — OSS models handle this via the input channel.
    if is_nova and getattr(self, 'model_source', None):
        recipe_dict = yaml.safe_load(recipe_content)
        # safe_load returns None for an empty document and may return a
        # non-mapping; only a mapping with a mapping "run" section is patched.
        run_section = recipe_dict.get("run") if isinstance(recipe_dict, dict) else None
        if isinstance(run_section, dict):
            run_section["model_name_or_path"] = self.model_source
            recipe_content = yaml.dump(recipe_dict, default_flow_style=False, sort_keys=False)
            logger.info("Overriding model_name_or_path with checkpoint: %s", self.model_source)
        else:
            logger.warning(
                "model checkpoint path was provided but the expected recipe path for "
                "'model_name_or_path' was not found. The checkpoint path will not be applied."
            )

    with open(recipe_local_path, "w") as f:
        f.write(recipe_content)
    logger.info("Recipe downloaded and rendered to: %s", recipe_local_path)

    # Resolve training image
    training_image = self.training_image
    if not training_image:
        training_image = get_training_image(
            model_name=self._model_name,
            customization_technique=customization_technique,
            training_type=self.training_type,
            sagemaker_session=sagemaker_session,
        )
    if not training_image:
        raise ValueError(
            "training_image is required for SMTJ compute but could not be resolved "
            "from model metadata. Pass it explicitly via the trainer's "
            "training_image parameter."
        )

    # Build compute config for ModelTrainer
    trainer_compute = TrainingJobCompute(
        instance_type=compute.instance_type,
        instance_count=compute.instance_count,
        volume_size_in_gb=compute.volume_size_in_gb,
        keep_alive_period_in_seconds=compute.keep_alive_period_in_seconds,
    )

    # Build input data config (datasets were resolved earlier for recipe injection).
    # Use "Converse" S3DataType for Nova SFT and DPO datasets
    use_converse = is_nova and customization_technique not in ("RLVR", "RLAIF")
    s3_data_type = "Converse" if use_converse else "S3Prefix"

    input_data_list = []
    if resolved_training_dataset:
        input_data_list.append(
            InputData(
                channel_name="train",
                data_source=S3DataSource(
                    s3_uri=resolved_training_dataset,
                    s3_data_type=s3_data_type,
                    s3_data_distribution_type="FullyReplicated",
                ),
            )
        )
    if resolved_validation_dataset:
        input_data_list.append(
            InputData(
                channel_name="validation",
                data_source=S3DataSource(
                    s3_uri=resolved_validation_dataset,
                    s3_data_type=s3_data_type,
                    s3_data_distribution_type="FullyReplicated",
                ),
            )
        )
    # For OSS/LLMFT models, deliver the SageMaker-prepared base model weights as a
    # "model" channel (mounted at /opt/ml/input/data/model). The recipe's
    # model_name_or_path was pointed at that mount above.
    if base_model_weights_uri:
        input_data_list.append(
            InputData(
                channel_name="model",
                data_source=S3DataSource(
                    s3_uri=base_model_weights_uri,
                    s3_data_type="S3Prefix",
                    s3_data_distribution_type="FullyReplicated",
                ),
            )
        )

    # Build networking config
    networking = None
    if self.networking:
        networking = Networking(
            security_group_ids=getattr(self.networking, 'security_group_ids', None),
            subnets=getattr(self.networking, 'subnets', None),
        )

    # Build output data config from s3_output_path if provided
    output_data_config = None
    if self.s3_output_path:
        output_config_kwargs = {"s3_output_path": self.s3_output_path}
        if getattr(self, "disable_output_compression", False):
            output_config_kwargs["compression_type"] = "NONE"
        output_data_config = OutputDataConfig(**output_config_kwargs)

    # Create ModelTrainer from recipe
    model_trainer = ModelTrainer.from_recipe(
        training_recipe=recipe_local_path,
        compute=trainer_compute,
        networking=networking,
        stopping_condition=self.stopping_condition,
        training_image=training_image,
        input_data_config=input_data_list if input_data_list else None,
        output_data_config=output_data_config,
        hyperparameters=final_hyperparameters,
        environment=self.environment or None,
        sagemaker_session=sagemaker_session,
        role=role,
        base_job_name=job_base_name,
    )

    # Execute training (log streaming follows the wait flag).
    model_trainer.train(
        wait=wait,
        logs=wait,
    )

    # Store latest training job reference
    self._latest_training_job = model_trainer._latest_training_job

    if wait:
        job_name = None
        if hasattr(self._latest_training_job, 'training_job_name'):
            job_name = self._latest_training_job.training_job_name
        elif hasattr(self._latest_training_job, 'name'):
            job_name = self._latest_training_job.name
        if job_name:
            # Best-effort: a failed manifest lookup must not fail a finished job.
            try:
                checkpoint_path = self._resolve_checkpoint_from_manifest(
                    job_name=job_name,
                    output_s3_path=self.s3_output_path,
                    sagemaker_session=sagemaker_session,
                )
                if checkpoint_path:
                    self._latest_training_job.model_artifacts = shapes.ModelArtifacts(
                        s3_model_artifacts=checkpoint_path
                    )
                    logger.info(
                        "Resolved checkpoint for %s: %s", job_name, checkpoint_path
                    )
            except Exception as e:
                logger.warning(
                    "Could not resolve checkpoint from manifest for %s: %s",
                    job_name,
                    e,
                )
    return self._latest_training_job
@staticmethod
def _resolve_checkpoint_from_manifest(
job_name: str,
output_s3_path: Optional[str],
sagemaker_session=None,
) -> Optional[str]:
"""Resolve the model checkpoint S3 path from a training job's manifest.
Supports both platforms:
- **SMHP (HyperPod)**: reads ``{output_s3_path}/{job_name}/manifest.json``
directly from S3.
- **SMTJ (Serverful)**: downloads
``{output_s3_path}/{job_name}/output/output.tar.gz``, extracts
``manifest.json`` from the archive.
The manifest contains a ``checkpoint_s3_bucket`` field pointing to the
final checkpoint location on S3 (e.g. in the customer-escrow bucket).
Args:
job_name: The training job name.
output_s3_path: The S3 output path configured for the training job.
sagemaker_session: SageMaker session (used for region/boto client).
Returns:
The S3 URI of the checkpoint, or None if unavailable.
"""
if not output_s3_path:
return None
parsed = urlparse(output_s3_path)
bucket = parsed.netloc
base_key = parsed.path.lstrip("/").rstrip("/")
region = None
if sagemaker_session and hasattr(sagemaker_session, 'boto_session'):
region = sagemaker_session.boto_session.region_name
s3_client = boto3.client("s3", region_name=region) if region else boto3.client("s3")
manifest = None
# Try SMHP format first: manifest.json directly in S3
manifest_key = f"{base_key}/{job_name}/manifest.json"
try:
response = s3_client.get_object(Bucket=bucket, Key=manifest_key)
manifest = json.loads(response["Body"].read())
except Exception:
pass
# Try SMTJ format: manifest.json inside output.tar.gz
if manifest is None:
tar_key = f"{base_key}/{job_name}/output/output.tar.gz"
try:
with tempfile.NamedTemporaryFile() as tmp_file:
s3_client.download_file(bucket, tar_key, tmp_file.name)
with tarfile.open(tmp_file.name, "r:gz") as tar:
manifest_file = tar.extractfile("manifest.json")
if manifest_file is not None:
manifest = json.loads(manifest_file.read())
except Exception:
pass
if manifest is None:
return None
checkpoint_path = manifest.get("checkpoint_s3_bucket")
if not checkpoint_path or not checkpoint_path.strip():
return None
# The manifest may store a relative path (SMHP convention). If it
# doesn't start with s3://, it's relative and we cannot resolve it
# without knowing the escrow bucket. Return as-is if it's absolute.
if not checkpoint_path.startswith("s3://"):
return None
checkpoint_path = checkpoint_path.strip()
return checkpoint_path
def _train_hyperpod(self, training_dataset=None, validation_dataset=None,
                    wait=True, wait_timeout=None, poll=5):
    """Execute training on a SageMaker HyperPod cluster.

    Uses the HyperPod CLI to connect to the cluster and submit a training job
    using a recipe-based approach. Shared across trainers that support HyperPod
    (SFT, DPO, RLVR).

    Args:
        training_dataset: Training data S3 URI; falls back to
            ``self.training_dataset`` when falsy.
        validation_dataset: Validation data S3 URI; falls back to
            ``self.validation_dataset`` when falsy.
        wait: When truthy, a single attempt is made after submission to read
            the checkpoint location from the job manifest. This method does
            not itself block until the job finishes.
        wait_timeout: Accepted for signature parity; not used by this method.
        poll: Accepted for signature parity; not used by this method.

    Returns:
        The job name parsed from the ``hyperpod start-job`` output. The
        corresponding ``TrainingJob`` is stored on ``self._latest_training_job``.

    Raises:
        ValueError: If ``cluster_name`` is missing, the node count is not
            allowed, no training image can be resolved, or the CLI output
            contains no job name.
        RuntimeError: If the ``hyperpod`` CLI cannot be found.
        subprocess.CalledProcessError: If a ``hyperpod`` command exits non-zero.
    """
    logger = logging.getLogger(__name__)
    sagemaker_session = TrainDefaults.get_sagemaker_session(
        sagemaker_session=self.sagemaker_session
    )
    compute = self.compute
    if not compute.cluster_name:
        raise ValueError(
            "cluster_name is required in HyperPodCompute for HyperPod training."
        )

    # HyperPod submits via the HyperPod CLI running as the *caller's* identity,
    # so there is no execution role to resolve here; this verifies the caller's
    # cluster-connect permissions (warn, non-blocking).
    TrainDefaults.verify_hyperpod_caller_permissions(
        sagemaker_session=sagemaker_session,
        cluster_name=compute.cluster_name,
    )

    # Validate HyperPod cluster capacity before proceeding
    is_nova = _is_nova_model(self._model_name)
    validate_hyperpod_compute(
        compute=compute,
        sagemaker_session=sagemaker_session,
        is_nova=is_nova,
    )
    namespace = compute.namespace or "kubeflow"

    # Connect to the HyperPod cluster
    try:
        subprocess.run(
            [
                "hyperpod", "connect-cluster",
                "--cluster-name", compute.cluster_name,
                "--namespace", namespace,
            ],
            capture_output=True, text=True, check=True,
        )
    except FileNotFoundError as err:
        raise RuntimeError(
            "The 'hyperpod' CLI is not installed or not on PATH. "
            "Install it with: pip install hyperpod"
        ) from err
    except subprocess.CalledProcessError as e:
        # Output is captured, so without this the CLI's error text would be
        # invisible; mirrors the handling of the start-job call below.
        logger.error("Failed to connect to HyperPod cluster: %s", e.stderr)
        raise

    # Resolve training image: explicit > SMTJ image rewritten for HyperPod >
    # HyperPod-specific lookup.
    training_image = self.training_image
    if not training_image:
        smtj_image = get_training_image(
            model_name=self._model_name,
            customization_technique=self._customization_technique,
            training_type=self.training_type,
            sagemaker_session=sagemaker_session,
        )
        if smtj_image:
            training_image = smtj_image.replace("SM-TJ-", "SM-HP-")
        else:
            training_image = get_hyperpod_training_image(
                model_name=self._model_name,
                customization_technique=self._customization_technique,
                training_type=self.training_type,
                sagemaker_session=sagemaker_session,
            )
    if not training_image:
        raise ValueError(
            "training_image is required for HyperPod compute but could not be resolved "
            f"from model metadata for model '{self._model_name}' with customization "
            f"technique '{self._customization_technique}'. Pass it explicitly via the "
            "trainer's training_image parameter."
        )

    job_base_name = self.base_job_name or f"{self._model_name}-{self._customization_technique}"

    # Validate node_count against allowed values from SMHP recipe
    self._validate_instance_count(compute.node_count, sagemaker_session)

    # Resolve and validate the recipe (3-level merge: base → user recipe → overrides)
    try:
        resolved = self.get_resolved_recipe()
        additional_overrides = flatten_resolved_recipe(resolved)
    except NoRecipeError:
        additional_overrides = {}

    # Add HyperPod-specific fields not in the recipe
    additional_overrides["name"] = _get_unique_name(job_base_name)
    if compute.node_count:
        additional_overrides["replicas"] = compute.node_count

    # Data paths
    resolved_training_dataset = training_dataset or self.training_dataset
    resolved_validation_dataset = validation_dataset or self.validation_dataset
    if resolved_training_dataset:
        additional_overrides["data_s3_path"] = resolved_training_dataset
    if resolved_validation_dataset:
        additional_overrides["validation_data_s3_path"] = resolved_validation_dataset

    # Output path
    if self.s3_output_path:
        additional_overrides["output_s3_path"] = self.s3_output_path

    # MLflow configuration
    mlflow_uri, mlflow_exp, mlflow_run = resolve_mlflow_tracking_fields(
        mlflow_tracking_uri=getattr(self, 'mlflow_resource_arn', None),
        mlflow_experiment_name=getattr(self, 'mlflow_experiment_name', None),
        mlflow_run_name=getattr(self, 'mlflow_run_name', None),
        base_job_name=job_base_name,
    )
    if mlflow_uri:
        additional_overrides["mlflow_tracking_uri"] = mlflow_uri
        additional_overrides["mlflow_experiment_name"] = mlflow_exp
        additional_overrides["mlflow_run_name"] = mlflow_run

    # Render recipe with all overrides baked in and write to CLI directory
    recipe_cli_path = get_hyperpod_recipe_path(
        model_name=self._model_name,
        customization_technique=self._customization_technique,
        training_type=self.training_type,
        sagemaker_session=sagemaker_session,
        job_name=job_base_name,
        additional_overrides=additional_overrides,
    )
    logger.info("HyperPod recipe resolved: %s", recipe_cli_path)

    # Only instance_type, container, and model_name_or_path remain as override parameters
    override_parameters = {}
    if compute.instance_type:
        override_parameters["instance_type"] = compute.instance_type
    if training_image:
        override_parameters["container"] = training_image
    if getattr(self, 'model_source', None):
        override_parameters["recipes.run.model_name_or_path"] = self.model_source

    # Submit job (argument list, no shell).
    start_job_cmd = [
        "hyperpod", "start-job",
        "--namespace", namespace,
        "--recipe", recipe_cli_path,
    ]
    if override_parameters:
        start_job_cmd.extend(["--override-parameters", json.dumps(override_parameters)])
    logger.info("Submitting HyperPod job: %s", ' '.join(start_job_cmd))
    try:
        start_result = subprocess.run(
            start_job_cmd, capture_output=True, text=True, check=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error("Failed to start HyperPod job: %s", e.stderr)
        raise

    # Extract job name from output
    matched = re.search(r"NAME: (\S+)", start_result.stdout)
    if not matched:
        raise ValueError(
            f"Could not find job name in HyperPod CLI output: {start_result.stdout}"
        )
    job_name = matched.group(1)
    logger.info("HyperPod job submitted: %s", job_name)

    training_job = TrainingJob(training_job_name=job_name)
    if wait:
        # Best-effort manifest lookup; failure only produces a warning.
        try:
            checkpoint_path = self._resolve_checkpoint_from_manifest(
                job_name=job_name,
                output_s3_path=self.s3_output_path,
                sagemaker_session=sagemaker_session,
            )
            if checkpoint_path:
                training_job.model_artifacts = shapes.ModelArtifacts(
                    s3_model_artifacts=checkpoint_path
                )
        except Exception as e:
            logger.warning(
                "Could not resolve checkpoint from manifest for %s: %s", job_name, e
            )
    self._latest_training_job = training_job
    return job_name