Source code for sagemaker.train.recipe_resolver

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Recipe resolution with 3-level override precedence for Nova model training."""
from __future__ import absolute_import

import copy
import logging
import os
import tempfile
from typing import Any, Dict, Optional, Set, Tuple, Union

import yaml
from omegaconf import OmegaConf

from sagemaker.core.training.configs import HyperPodCompute, TrainingJobCompute
from sagemaker.train.sm_recipes.utils import _register_custom_resolvers

logger = logging.getLogger(__name__)


def render_template(
    template: Dict[str, Any],
    override_spec: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, str]]:
    """Render a Hub recipe template by filling ``{{placeholder}}`` values.

    Every string leaf containing both ``{{`` and ``}}`` is treated as a
    placeholder. Its spec key is extracted by removing surrounding whitespace,
    one layer of single and/or double quotes, and the braces. The leaf is then
    replaced by ``override_spec[spec_key]["default"]``, or ``None`` when the
    key or its default is absent from the spec.

    Args:
        template: Hub recipe template dict containing ``'{{key}}'`` placeholders.
        override_spec: Flat dict mapping spec keys to their metadata
            (including 'default', 'type', 'min', 'max', 'enum').

    Returns:
        Tuple of ``(rendered_dict, key_path_map)``. ``key_path_map`` maps flat
        spec keys to their dotpath location in the recipe structure, e.g.
        ``{"learning_rate": "training_config.learning_rate"}``.
        - List elements contribute their index to the dotpath (``"a.0.b"``).
        - Non-string mapping keys are stringified.
        - If a placeholder appears more than once, the last occurrence visited
          wins.
    """
    key_path_map: Dict[str, str] = {}

    def _placeholder_key(text: str) -> str:
        # Strip whitespace both outside and inside the optional quote layer so
        # that " {{key}} " or "' {{ key }} '" resolve to "key" instead of
        # leaving the braces attached to the key.
        token = text.strip()
        token = token.removeprefix("'").removesuffix("'")
        token = token.removeprefix('"').removesuffix('"')
        token = token.strip()
        token = token.removeprefix("{{").removesuffix("}}")
        return token.strip()

    def _walk(obj, path_parts):
        if isinstance(obj, dict):
            # str(k): YAML mappings may use non-string keys (e.g. integers),
            # and the dotpath must still be joinable.
            return {k: _walk(v, path_parts + [str(k)]) for k, v in obj.items()}
        if isinstance(obj, list):
            return [_walk(item, path_parts + [str(i)]) for i, item in enumerate(obj)]
        if isinstance(obj, str) and "{{" in obj and "}}" in obj:
            spec_key = _placeholder_key(obj)
            key_path_map[spec_key] = ".".join(path_parts)
            # "or {}" also tolerates a spec entry that is explicitly None.
            spec_entry = override_spec.get(spec_key) or {}
            return spec_entry.get("default")
        return obj

    rendered = _walk(template, [])
    return rendered, key_path_map
def _load_user_recipe(recipe_path: str) -> Dict[str, Any]: """Load a user recipe from a local path or S3 URI. Args: recipe_path: Local file path or S3 URI to a YAML recipe file. Returns: Parsed recipe as a dict. Raises: ValueError: If the file cannot be loaded or parsed. """ if recipe_path.startswith("s3://"): try: import boto3 parts = recipe_path.replace("s3://", "").split("/", 1) bucket, key = parts[0], parts[1] s3 = boto3.client("s3") tmp = tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) s3.download_file(bucket, key, tmp.name) tmp.close() with open(tmp.name, "r") as f: content = yaml.safe_load(f) os.unlink(tmp.name) if not isinstance(content, dict): raise ValueError( f"Recipe file at {recipe_path} did not parse as a YAML mapping." ) return content except ImportError: raise ValueError( "boto3 is required to load recipes from S3. Install it with: pip install boto3" ) except Exception as e: raise ValueError(f"Could not load recipe from {recipe_path}: {e}") elif recipe_path.startswith("http://") or recipe_path.startswith("https://"): raise ValueError( f"HTTP/HTTPS recipe URLs are not supported for security reasons. " f"Use a local file path or S3 URI instead: {recipe_path}" ) else: if not os.path.isfile(recipe_path): raise ValueError(f"Recipe file not found: {recipe_path}") with open(recipe_path, "r") as f: content = yaml.safe_load(f) if not isinstance(content, dict): raise ValueError( f"Recipe file at {recipe_path} did not parse as a YAML mapping." ) return content def _validate_value( key: str, value: Any, spec: Dict[str, Any], source: str, resolved_recipe: Optional[Dict[str, Any]] = None, dotpath: Optional[str] = None, ) -> None: """Validate a single value against its spec entry. Performs type checking, range validation, enum membership, and required field presence checks. Args: key: The parameter name. value: The value to validate. spec: The spec entry dict with 'type', 'min', 'max', 'enum', 'required' fields. 
source: Human-readable source label (e.g., "overrides dict", "user recipe"). resolved_recipe: Optional resolved recipe dict (used for required check context). dotpath: Optional dotpath of the key in the recipe (used for required check context). Raises: ValueError: If validation fails. """ # --- Required field presence check --- if spec.get("required", False): if not dotpath: raise ValueError( f"'{key}' is required but was not found in the resolved recipe." ) if value is None: raise ValueError( f"'{key}' is required but was not found in the resolved recipe." ) if value is None: return expected_type = spec.get("type") if expected_type: type_map = { "float": (int, float), "integer": (int,), "int": (int,), "string": (str,), "boolean": (bool,), "bool": (bool,), } allowed_types = type_map.get(expected_type.lower()) if allowed_types and not isinstance(value, allowed_types): raise ValueError( f"Invalid type for '{key}': expected {expected_type}, " f"got {type(value).__name__} (value: {value}). Source: {source}." ) min_val = spec.get("min") if min_val is not None and isinstance(value, (int, float)): if value < min_val: raise ValueError( f"Invalid value for '{key}': {value} is below minimum {min_val}. " f"Allowed range: [{min_val}, {spec.get('max', '...')}]. Source: {source}." ) max_val = spec.get("max") if max_val is not None and isinstance(value, (int, float)): if value > max_val: raise ValueError( f"Invalid value for '{key}': {value} is above maximum {max_val}. " f"Allowed range: [{spec.get('min', '...')}, {max_val}]. Source: {source}." ) enum_values = spec.get("enum") if enum_values is not None: if value == "" or value == spec.get("default"): pass elif value not in enum_values: raise ValueError( f"Invalid value for '{key}': {value} is not in allowed values {enum_values}. " f"Source: {source}." ) def _validate_step_constraints( resolved_recipe: Dict[str, Any], key_path_map: Dict[str, str], ) -> None: """Perform cross-field validation on a resolved recipe. 
Validates constraints that involve relationships between multiple recipe parameters (e.g., save_steps must be <= max_steps). Args: resolved_recipe: The fully resolved recipe dict (nested structure). key_path_map: Mapping of flat spec key names to their dotpath location in the resolved recipe. Raises: ValueError: Immediately on the first validation failure. """ # --- Cross-field validation: save_steps must be <= max_steps --- save_steps_path = key_path_map.get("save_steps") max_steps_path = key_path_map.get("max_steps") if save_steps_path and max_steps_path: save_steps = _get_nested_value(resolved_recipe, save_steps_path) max_steps = _get_nested_value(resolved_recipe, max_steps_path) if ( isinstance(save_steps, (int, float)) and isinstance(max_steps, (int, float)) and save_steps > max_steps ): raise ValueError( f"'save_steps' ({save_steps}) must be less than or equal to " f"'max_steps' ({max_steps})." )
def flatten_resolved_recipe(resolved: Dict[str, Any]) -> Dict[str, Any]:
    """Collapse a nested resolved recipe into one flat ``{leaf_key: value}`` map.

    Trainers and evaluators use this to hand resolved recipe values to the
    SageMaker training API as flat hyperparameters. Nested mappings are
    descended depth-first in insertion order. Every value that is neither a
    dict nor a list is emitted under its own leaf key name, so::

        training_config:
          lr_scheduler:
            warmup_steps: 15
            min_lr: 1e-6

    becomes ``{"warmup_steps": 15, "min_lr": 1e-6}``.

    Notes:
        * A list-valued entry inside a mapping is skipped entirely; it is
          neither emitted nor descended into.
        * If ``resolved`` is itself a list, its dict elements are flattened
          and its other elements are ignored.
        * When the same leaf key occurs more than once, the value met last in
          the traversal wins. The key keeps the position of its first
          occurrence.

    Args:
        resolved: The resolved recipe dict (nested by section).

    Returns:
        Flat dict of all scalar leaf key-value pairs across all sections.
    """

    def _leaves(node):
        # Yields (leaf_key, value) pairs in depth-first order.
        if isinstance(node, dict):
            for name, child in node.items():
                if isinstance(child, dict):
                    yield from _leaves(child)
                elif not isinstance(child, list):
                    yield name, child
        elif isinstance(node, list):
            for element in node:
                if isinstance(element, dict):
                    yield from _leaves(element)

    return dict(_leaves(resolved))
def _get_nested_value(d: Dict[str, Any], dotpath: str) -> Any: """Get a value from a nested dict using a dot-separated path.""" parts = dotpath.split(".") current = d for part in parts: if isinstance(current, dict) and part in current: current = current[part] else: return None return current def _set_nested_value(d: Dict[str, Any], dotpath: str, value: Any) -> None: """Set a value in a nested dict using a dot-separated path.""" parts = dotpath.split(".") current = d for part in parts[:-1]: if part not in current or not isinstance(current[part], dict): current[part] = {} current = current[part] current[parts[-1]] = value def _build_key_path_map(recipe_dict: Dict[str, Any], spec_keys: set) -> Dict[str, str]: """Build a key_path_map by finding spec keys in a recipe dict. Walks the recipe dict and maps spec key names to their dotpath locations. Used when a full recipe template is provided instead of a synthetic one. Args: recipe_dict: The full recipe dict (nested). spec_keys: Set of flat spec key names to locate. Returns: Dict mapping spec key names to their dotpath in the recipe. """ key_path_map = {} def _walk(obj, path_parts): if isinstance(obj, dict): for k, v in obj.items(): current_path = path_parts + [k] if k in spec_keys and k not in key_path_map: key_path_map[k] = ".".join(current_path) _walk(v, current_path) elif isinstance(obj, list): for i, item in enumerate(obj): _walk(item, path_parts + [str(i)]) _walk(recipe_dict, []) return key_path_map
class RecipeResolver:
    """Resolves a 3-level recipe configuration for Nova model training.

    Precedence (highest wins):
        1. Programmatic overrides (dict)
        2. User recipe (YAML file)
        3. Base defaults (rendered from Hub template + override-params spec)

    Immutable after construction. The templates, spec, overrides and protected
    keys are copied on construction; ``compute`` is kept by reference and only
    read. ``resolve()`` is idempotent: later calls return a deep copy of the
    cached result.
    """

    def __init__(
        self,
        recipe_template: Dict[str, Any],
        override_spec: Dict[str, Any],
        user_recipe_path: Optional[str] = None,
        overrides: Optional[Dict[str, Any]] = None,
        protected_keys: Optional[Set[str]] = None,
        full_recipe_template: Optional[Dict[str, Any]] = None,
        compute: Optional[Union["TrainingJobCompute", "HyperPodCompute"]] = None,
    ):
        """Initialize the resolver.

        Args:
            recipe_template: Hub recipe template dict (with {{placeholder}} syntax).
            override_spec: Flat dict of parameter specs from Hub (type/min/max/enum/default).
            user_recipe_path: Optional path to user's recipe YAML file (local or S3).
            overrides: Optional programmatic overrides dict (nested structure).
            protected_keys: Optional set of keys that cannot be overridden by
                user recipe or overrides (e.g., 'model_type', 'task').
            full_recipe_template: Optional full recipe dict fetched from Hub
                (SmtjRecipeTemplateS3Uri). When provided, it is used as the
                base layer instead of the synthetic template. This allows
                overriding any key in the full recipe, not just the
                spec-exposed subset.
            compute: Optional compute configuration. Union of
                ``sagemaker.core.training.configs.Compute`` (TrainingJobCompute)
                or ``sagemaker.core.training.configs.HyperPodCompute``.
                None indicates SMTJ Serverless.
        """
        self._recipe_template = copy.deepcopy(recipe_template)
        self._override_spec = copy.deepcopy(override_spec)
        self._user_recipe_path = user_recipe_path
        self._overrides = copy.deepcopy(overrides) if overrides else {}
        # Copy so that later mutation of the caller's set cannot change which
        # keys are protected.
        self._protected_keys = set(protected_keys) if protected_keys else set()
        self._full_recipe_template = (
            copy.deepcopy(full_recipe_template) if full_recipe_template else None
        )
        # Stored by reference: only instance_type / node_count /
        # instance_count are read from it, during validation.
        self._compute = compute
        self._resolved: Optional[Dict[str, Any]] = None

    def resolve(self) -> Dict[str, Any]:
        """Perform template render, 3-level merge, and validation.

        Returns:
            The fully resolved recipe as a plain dict (a deep copy, so callers
            may mutate it freely).

        Raises:
            ValueError: If the user recipe cannot be loaded, an interpolation
                cannot be resolved, or validation fails for any parameter.
        """
        if self._resolved is not None:
            return copy.deepcopy(self._resolved)

        # Phase 1: Determine base layer and key_path_map
        if self._full_recipe_template:
            # Use the full recipe from Hub as the base layer. render_template
            # resolves any {{placeholder}} values to spec defaults and maps
            # those placeholder keys to their dotpaths.
            base_dict, key_path_map = render_template(
                self._full_recipe_template, self._override_spec
            )
            # Keys that appear as plain values (not placeholders) in the full
            # template are located by name, so that validation and
            # protected-key stripping still work for them.
            extra_keys = (
                set(self._override_spec.keys()) | self._protected_keys
            ) - set(key_path_map.keys())
            if extra_keys:
                key_path_map.update(_build_key_path_map(base_dict, extra_keys))
        else:
            # Synthetic template built from spec keys only (legacy path)
            base_dict, key_path_map = render_template(
                self._recipe_template, self._override_spec
            )

        # Phase 2: Load user recipe if provided. The loaded dict is fresh and
        # private to this call, so it can be pruned in place below.
        user_dict_for_merge: Dict[str, Any] = {}
        if self._user_recipe_path:
            user_dict_for_merge = _load_user_recipe(self._user_recipe_path)

        # Phase 3: Strip protected keys and drop unknown keys from copies
        # (never mutate self._overrides).
        overrides_for_merge = copy.deepcopy(self._overrides)
        if overrides_for_merge and key_path_map:
            overrides_for_merge = self._expand_overrides(
                overrides_for_merge, base_dict, key_path_map
            )

        self._strip_protected_keys(user_dict_for_merge, key_path_map)
        self._strip_protected_keys(overrides_for_merge, key_path_map)

        # Overrides and user recipes may only modify parameters that already
        # exist in the base recipe. Any other key is dropped with a warning,
        # so it is never merged in. This is a no-op for an empty dict.
        self._drop_unknown_keys(
            user_dict_for_merge,
            base_dict,
            source=f"user recipe ({self._user_recipe_path})",
        )
        self._drop_unknown_keys(overrides_for_merge, base_dict, source="overrides")

        # Phase 4: OmegaConf 3-way merge (base < user < overrides)
        _register_custom_resolvers()
        base_cfg = OmegaConf.create(base_dict)
        user_cfg = OmegaConf.create(user_dict_for_merge)
        overrides_cfg = OmegaConf.create(overrides_for_merge)
        merged = OmegaConf.merge(base_cfg, user_cfg, overrides_cfg)

        # Phase 5: Resolve interpolations
        try:
            OmegaConf.resolve(merged)
        except Exception as e:
            raise ValueError(
                f"Failed to resolve recipe interpolations: {e}. "
                f"Ensure all referenced keys exist in the base template or user recipe."
            ) from e
        resolved_dict = OmegaConf.to_container(merged, resolve=True)

        # Phase 6: Validate against override spec using key_path_map
        self._validate(resolved_dict, key_path_map, compute=self._compute)

        self._resolved = resolved_dict
        return copy.deepcopy(self._resolved)

    def get_resolved_recipe(self) -> Dict[str, Any]:
        """Return the resolved recipe as a read-only deep copy.

        Callable before or after train()/evaluate(). Triggers resolution on
        first call if not already resolved.

        Returns:
            Deep copy of the fully resolved recipe dict.
        """
        return self.resolve()

    @staticmethod
    def _expand_overrides(
        overrides: Dict[str, Any],
        base_dict: Dict[str, Any],
        key_path_map: Dict[str, str],
    ) -> Dict[str, Any]:
        """Relocate override leaves to their actual position in the base recipe.

        Users may pass overrides such as ``{"lora_plus_lr_ratio": 8}`` or
        ``{"training_config": {"learning_rate": 5e-6}}``, where the given path
        does not match the real recipe path (e.g.
        ``training_config.optim_config.lr``).

        Each leaf ``key`` found at ``current_path`` in ``overrides`` is placed
        at the first matching target:
            1. ``current_path`` itself, if it is already a leaf path in the
               base recipe. An explicit, correct path is never rerouted to a
               same-named field elsewhere.
            2. The base-recipe path of a leaf field named ``key``. If several
               fields share that name, the last one in traversal order wins.
            3. ``key_path_map[key]`` (spec / placeholder key name).
            4. ``current_path`` unchanged. If it is not in the base recipe it
               is later dropped by ``_drop_unknown_keys``.

        Returns:
            A new nested dict; ``overrides`` is not modified.
        """
        leaf_paths = set()
        field_paths: Dict[Any, Any] = {}

        def _index_base(node, prefix=""):
            # Records every leaf path, plus the last path of each leaf name.
            for name, child in node.items():
                path = f"{prefix}.{name}" if prefix else name
                if isinstance(child, dict):
                    _index_base(child, path)
                else:
                    field_paths[name] = path
                    leaf_paths.add(path)

        expanded: Dict[str, Any] = {}

        def _place(node, prefix=""):
            # Recursively places each override leaf at its resolved target.
            for name, value in node.items():
                current_path = f"{prefix}.{name}" if prefix else name
                if isinstance(value, dict):
                    _place(value, current_path)
                elif current_path in leaf_paths:
                    _set_nested_value(expanded, current_path, value)
                elif name in field_paths:
                    _set_nested_value(expanded, field_paths[name], value)
                elif name in key_path_map:
                    _set_nested_value(expanded, key_path_map[name], value)
                else:
                    _set_nested_value(expanded, current_path, value)

        _index_base(base_dict)
        _place(overrides)
        return expanded

    def _drop_unknown_keys(
        self,
        override_dict: Dict[str, Any],
        base_dict: Dict[str, Any],
        source: str,
        _path: str = "",
    ) -> None:
        """Drop keys not present in the base recipe and warn about each one.

        Overrides (from the overrides dict or a user recipe) may only modify
        parameters that already exist in the base recipe. Any key with no
        counterpart in the base recipe is removed from ``override_dict`` in
        place, so it is never merged into the resolved recipe. A warning is
        logged for every dropped key.

        Both dicts are walked in parallel, so the comparison respects
        structure: an override key is "known" only if a key of the same name
        exists at the same location in the base recipe.

        Args:
            override_dict: Override/user-recipe dict to prune (mutated in place).
            base_dict: The base recipe dict to validate keys against.
            source: Human-readable origin label used in the warning message
                (e.g. "overrides" or "user recipe (...)").
            _path: Internal dotpath accumulator used for warning messages.
        """
        if not isinstance(override_dict, dict) or not isinstance(base_dict, dict):
            return
        for key in list(override_dict.keys()):
            dotpath = f"{_path}.{key}" if _path else key
            if key not in base_dict:
                logger.warning(
                    f"Override key '{dotpath}' from {source} does not exist in "
                    f"the recipe and will be dropped."
                )
                del override_dict[key]
                continue
            # Recurse into nested mappings so unknown nested keys are dropped
            # while known sibling keys are preserved.
            if isinstance(override_dict[key], dict) and isinstance(base_dict[key], dict):
                self._drop_unknown_keys(override_dict[key], base_dict[key], source, dotpath)

    def _strip_protected_keys(
        self, d: Dict[str, Any], key_path_map: Dict[str, str]
    ) -> None:
        """Remove protected keys from ``d`` in place and log a warning for each.

        Only protected keys that have a dotpath in ``key_path_map`` can be
        located. Keys whose path is absent from ``d`` are silently skipped.
        """
        for spec_key in self._protected_keys:
            dotpath = key_path_map.get(spec_key)
            if not dotpath:
                continue
            parts = dotpath.split(".")
            current = d
            for part in parts[:-1]:
                if isinstance(current, dict) and part in current:
                    current = current[part]
                else:
                    current = None
                    break
            if isinstance(current, dict) and parts[-1] in current:
                logger.warning(
                    f"Protected key '{spec_key}' (at {dotpath}) cannot be overridden. "
                    f"Ignoring user-provided value."
                )
                del current[parts[-1]]

    def _validate(
        self,
        resolved: Dict[str, Any],
        key_path_map: Dict[str, str],
        compute: Optional[Union["TrainingJobCompute", "HyperPodCompute"]] = None,
    ) -> None:
        """Validate resolved values against the override spec.

        Performs per-value validation (type, range, enum, required) and
        compute-level checks (``instance_type`` and ``replicas`` are validated
        against ``compute``, not the recipe), then runs cross-field step
        validation.

        Args:
            resolved: The fully resolved recipe dict.
            key_path_map: Mapping of spec keys to their dotpath in the recipe.
            compute: Optional compute configuration. Union of Compute
                (TrainingJobCompute) or HyperPodCompute. None indicates SMTJ
                Serverless, for which the compute-level checks are skipped.

        Raises:
            ValueError: On the first failed check.
        """
        is_serverless = compute is None

        for spec_key, spec_entry in self._override_spec.items():
            # --- Instance type: validated from compute, not from recipe ---
            if spec_key == "instance_type":
                if is_serverless:
                    continue
                instance_type = getattr(compute, "instance_type", None)
                allowed_values = spec_entry.get("enum")
                if instance_type is None:
                    raise ValueError(
                        "instance_type must be specified in Compute parameter. "
                        f"Allowed types: {sorted(allowed_values) if allowed_values else 'unknown'}."
                    )
                if allowed_values and instance_type not in allowed_values:
                    raise ValueError(
                        f"Instance type '{instance_type}' is not supported. "
                        f"Allowed types: {sorted(allowed_values)}."
                    )
                continue

            # --- Replicas: validated from compute, not from recipe ---
            if spec_key == "replicas":
                if is_serverless:
                    continue
                # HyperPodCompute exposes node_count; training-job compute
                # exposes instance_count.
                node_count = getattr(compute, "node_count", None) or getattr(
                    compute, "instance_count", None
                )
                allowed_values = spec_entry.get("enum")
                if node_count is None:
                    raise ValueError(
                        "node_count (or instance_count) must be specified in Compute parameter. "
                        f"Allowed values: {sorted(allowed_values) if allowed_values else 'unknown'}."
                    )
                if allowed_values and node_count not in allowed_values:
                    raise ValueError(
                        f"Node/Instance count '{node_count}' is not supported. "
                        f"Allowed values: {sorted(allowed_values)}."
                    )
                continue

            # Skip keys not mapped into the recipe structure
            dotpath = key_path_map.get(spec_key)
            if not dotpath:
                # With a full recipe template, keys not found in the recipe
                # have no corresponding {{placeholder}}. They are therefore
                # not overridable via the recipe and are skipped.
                if self._full_recipe_template:
                    continue
                # Still check required even when not in key_path_map
                # (synthetic template)
                if spec_entry.get("required", False):
                    raise ValueError(
                        f"'{spec_key}' is required but was not found in the "
                        f"resolved recipe."
                    )
                continue

            value = _get_nested_value(resolved, dotpath)

            # Determine source for error messages (best-effort attribution).
            if self._overrides:
                override_value = _get_nested_value(self._overrides, dotpath)
                if override_value is not None:
                    source = "overrides dict"
                elif self._user_recipe_path:
                    source = f"user recipe ({self._user_recipe_path})"
                else:
                    source = "base defaults"
            else:
                source = "resolved recipe"

            _validate_value(spec_key, value, spec_entry, source, resolved, dotpath)

        _validate_step_constraints(resolved, key_path_map)