Source code for sagemaker.train.evaluate.inspect_ai_evaluator

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""InspectAI Evaluator for SageMaker Model Evaluation Module.

This module provides evaluation capabilities using InspectAI as a backend,
enabling a broad set of benchmarks and methodologies via the InspectAI framework.
The evaluator runs InspectAI tasks inside a dedicated container on SageMaker
Training infrastructure.
"""

import logging
import os
import re
import uuid
from typing import Any, Dict, Iterator, List, Optional

import yaml
from pydantic import root_validator, validator
from sagemaker.core.s3.client import S3Uploader
from sagemaker.core.telemetry.constants import Feature
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.utils.utils import Unassigned

from .base_evaluator import BaseEvaluator
from .constants import EvalType, _get_inspect_ai_default_image_uri, _get_nova_inference_image_uri
from .execution import EvaluationPipelineExecution
from .pipeline_templates import INSPECT_AI_TEMPLATE

_logger = logging.getLogger(__name__)

# ECR image URI:
#   <12-digit account>.dkr.ecr.<region>.amazonaws.com[.cn]/<repo>[:tag][@sha256:<digest>]
# The optional ``.cn`` suffix covers China-region registries (the IAM pattern
# below already accepts the ``aws-cn`` partition), and the optional digest
# suffix allows images pinned by content hash.
_ECR_URI_PATTERN = re.compile(
    r"^\d{12}\.dkr\.ecr\.[a-z0-9-]+\.amazonaws\.com(\.cn)?/[a-z0-9._/-]+"
    r"(:[a-zA-Z0-9._-]+)?(@sha256:[a-f0-9]{64})?$"
)

# IAM role ARN in the standard, China, GovCloud, or any ``aws-iso*`` partition.
_IAM_ROLE_ARN_PATTERN = re.compile(
    r"^arn:aws(-cn|-us-gov|-iso(-[a-z])?)?:iam::\d{12}:role/.+$"
)


class InspectAIEvaluator(BaseEvaluator):
    """InspectAI evaluation job.

    Runs InspectAI tasks inside a SageMaker Training container, supporting three
    inference provider modes: Bedrock, existing SageMaker endpoint, or creating a
    new endpoint.

    The evaluator serializes configuration to a YAML file (``inspect_config.yaml``),
    uploads it to S3, and launches a single-step SageMaker Pipeline that runs the
    InspectAI container with that config as input.

    Supports resource chaining: a completed trainer (e.g., ``SFTTrainer``,
    ``DPOTrainer``, ``MultiTurnRLTrainer``) can be passed directly as the ``model``
    parameter. The evaluator will automatically resolve the trainer's output model
    package artifacts and configure endpoint creation for evaluation.

    Attributes:
        benchmarks_path (str): S3 URI pointing to benchmark ``.py`` files with
            ``@task`` decorators. Required.
        tasks (Optional[List[Dict[str, Any]]]): List of task configurations. Each
            dict must have a ``"name"`` key. Optional keys: ``"path"`` (must end
            with .py), ``"limit"`` (int >= 1), ``"epochs"`` (int >= 1),
            ``"task_args"`` (dict). If None or empty, all tasks at
            ``benchmarks_path`` are run.
        output_format (Optional[str]): Output format for results. One of
            ``"eval"``, ``"csv"``, ``"jsonl"``, ``"json"``.
        bedrock_model_id (Optional[str]): Explicit Bedrock model ID for bedrock
            inference mode. Falls back to the model's bedrock_model_id if not set.
        endpoint_name (Optional[str]): Existing SageMaker endpoint name. Mutually
            exclusive with ``model_s3_uri``/``inference_image_uri``.
        model_s3_uri (Optional[str]): S3 URI of model artifacts for creating a new
            endpoint. Must be paired with ``inference_image_uri``.
        inference_image_uri (Optional[str]): ECR image URI for creating a new
            endpoint. Must be paired with ``model_s3_uri``.
        endpoint_instance_type (Optional[str]): Instance type for new endpoint
            (must start with ``ml.``).
        endpoint_instance_count (int): Instance count for new endpoint.
            Defaults to 1.
        endpoint_execution_role_arn (Optional[str]): IAM role ARN for new endpoint.
        context_length (Optional[str]): Context length as string integer.
        max_concurrency (Optional[str]): Max concurrency as string integer.
        cleanup_endpoint (bool): Delete endpoint after evaluation. Defaults to True.
        endpoint_prefix (str): Prefix for auto-created endpoint names.
        endpoint_environment (Optional[Dict[str, str]]): Env vars for the inference
            endpoint container.
        extra_args (Optional[List[str]]): Additional CLI args forwarded to
            ``inspect eval``.
        environment (Optional[Dict[str, str]]): Env vars for the SageMaker Training
            Job container.
        image_uri (Optional[str]): Override for the InspectAI container image URI.
        instance_type (str): Instance type for the orchestrator Training Job
            (CPU-only). Defaults to ``"ml.m5.large"``.
        max_runtime_seconds (int): Max runtime for the Training Job in seconds.
            Defaults to 86400 (24 hours).
        max_connections (int): Max concurrent inference connections used by the
            InspectAI eval runner. Defaults to 16.
        max_retries (int): Max retries per inference request. Defaults to 100.
        timeout (int): Per-request timeout in seconds. Defaults to 600.
        temperature (float): Sampling temperature in [0.0, 2.0]. Defaults to 0.0.
        top_p (float): Nucleus sampling cutoff in [0.0, 1.0]. Defaults to 1.0.
        top_k (int): Top-k sampling cutoff. Use ``-1`` to disable. Defaults to -1.
        max_tokens (int): Max tokens to generate per response. Defaults to 8192.

    Example:
        .. code:: python

            from sagemaker.train.evaluate import InspectAIEvaluator

            evaluator = InspectAIEvaluator(
                model="amazon-nova-lite-v1",
                benchmarks_path="s3://my-bucket/benchmarks/",
                tasks=[{"name": "boolq_pt", "limit": 10}],
                s3_output_path="s3://my-bucket/eval-output/",
            )
            execution = evaluator.evaluate()
            execution.wait()
            execution.show_results()

    Resource chaining with a trainer:
        .. code:: python

            from sagemaker.train import SFTTrainer
            from sagemaker.train.evaluate import InspectAIEvaluator

            # Train a model
            trainer = SFTTrainer(model="llama3-2-1b-instruct", ...)
            trainer.train(training_dataset="s3://bucket/data.jsonl")

            # Evaluate the fine-tuned model directly
            evaluator = InspectAIEvaluator(
                model=trainer,
                benchmarks_path="s3://my-bucket/benchmarks/",
                tasks=[{"name": "boolq_pt", "limit": 10}],
                s3_output_path="s3://my-bucket/eval-output/",
            )
            execution = evaluator.evaluate()
    """

    # InspectAI-specific fields
    benchmarks_path: str
    tasks: Optional[List[Dict[str, Any]]] = None
    output_format: Optional[str] = None
    bedrock_model_id: Optional[str] = None
    endpoint_name: Optional[str] = None
    model_s3_uri: Optional[str] = None
    inference_image_uri: Optional[str] = None
    endpoint_instance_type: Optional[str] = None
    endpoint_instance_count: int = 1
    endpoint_execution_role_arn: Optional[str] = None
    context_length: Optional[str] = None
    max_concurrency: Optional[str] = None
    cleanup_endpoint: bool = True
    endpoint_prefix: str = "inspectai"
    endpoint_environment: Optional[Dict[str, str]] = None
    extra_args: Optional[List[str]] = None
    environment: Optional[Dict[str, str]] = None
    image_uri: Optional[str] = None
    instance_type: str = "ml.m5.large"
    max_runtime_seconds: int = 86400

    # Eval orchestration tunables (forwarded into eval section of inspect_config.yaml)
    max_connections: int = 16
    max_retries: int = 100
    timeout: int = 600

    # Decoding tunables (forwarded into eval.decoding section of inspect_config.yaml)
    temperature: float = 0.0
    top_p: float = 1.0
    top_k: int = -1
    max_tokens: int = 8192

    @validator("environment")
    def _validate_environment(cls, v):
        """Require ``environment`` to be a flat mapping of string keys to string values."""
        if v is None:
            return v
        for key, val in v.items():
            if not isinstance(key, str) or not isinstance(val, str):
                raise ValueError("environment must be a flat Dict[str, str]")
        return v

    @validator("benchmarks_path")
    def _validate_benchmarks_path(cls, v):
        """Require a non-blank ``benchmarks_path`` that starts with ``s3://``."""
        if not v or not v.strip():
            raise ValueError("benchmarks_path is required and cannot be empty")
        if not v.startswith("s3://"):
            raise ValueError(f"benchmarks_path must start with 's3://'. Got: '{v}'")
        return v

    @validator("tasks")
    def _validate_tasks(cls, v):
        """Validate the shape of each task configuration dict.

        Every entry must be a dict with a ``"name"`` key. Optional keys are
        checked when present: ``"path"`` must be a string ending in ``.py``,
        ``"limit"`` and ``"epochs"`` must be integers >= 1 (booleans are
        rejected even though ``bool`` subclasses ``int``), and ``"task_args"``
        must be a dict.

        Raises:
            ValueError: If any entry violates the rules above.
        """
        if v is None:
            return v
        if not isinstance(v, list):
            raise ValueError("tasks must be a list of dicts")
        for i, task in enumerate(v):
            if not isinstance(task, dict):
                raise ValueError(f"tasks[{i}] must be a dict, got {type(task).__name__}")
            if "name" not in task:
                raise ValueError(f"tasks[{i}] must have a 'name' key")
            if "path" in task:
                path = task["path"]
                # Check the type first so a non-string path is reported as a
                # validation error instead of failing on ``.endswith``.
                if not isinstance(path, str) or not path.endswith(".py"):
                    raise ValueError(f"tasks[{i}]['path'] must end with '.py'. Got: '{path}'")
            for key in ("limit", "epochs"):
                if key not in task:
                    continue
                count = task[key]
                if not isinstance(count, int) or isinstance(count, bool) or count < 1:
                    raise ValueError(f"tasks[{i}]['{key}'] must be an integer >= 1")
            if "task_args" in task and not isinstance(task["task_args"], dict):
                raise ValueError(f"tasks[{i}]['task_args'] must be a dict")
        return v

    @validator("output_format")
    def _validate_output_format(cls, v):
        """Restrict ``output_format`` to the supported set of formats."""
        if v is None:
            return v
        allowed = ("eval", "csv", "jsonl", "json")
        if v not in allowed:
            raise ValueError(f"output_format must be one of {allowed}. Got: '{v}'")
        return v

    @validator("model_s3_uri")
    def _validate_model_s3_uri(cls, v):
        """Require ``model_s3_uri``, when given, to start with ``s3://``."""
        if v is not None and not v.startswith("s3://"):
            raise ValueError(f"model_s3_uri must start with 's3://'. Got: '{v}'")
        return v

    @validator("inference_image_uri")
    def _validate_inference_image_uri(cls, v):
        """Require ``inference_image_uri``, when given, to match the ECR URI pattern."""
        if v is not None and not _ECR_URI_PATTERN.match(v):
            raise ValueError(f"inference_image_uri must be a valid ECR URI. Got: '{v}'")
        return v

    @validator("endpoint_instance_type")
    def _validate_endpoint_instance_type(cls, v):
        """Require ``endpoint_instance_type``, when given, to start with ``ml.``."""
        if v is not None and not v.startswith("ml."):
            raise ValueError(f"endpoint_instance_type must start with 'ml.'. Got: '{v}'")
        return v

    @validator("endpoint_execution_role_arn")
    def _validate_endpoint_execution_role_arn(cls, v):
        """Require ``endpoint_execution_role_arn``, when given, to match the IAM role ARN pattern."""
        if v is not None and not _IAM_ROLE_ARN_PATTERN.match(v):
            raise ValueError(
                f"endpoint_execution_role_arn must be a valid IAM role ARN. Got: '{v}'"
            )
        return v

    @root_validator(skip_on_failure=True)
    def _validate_inference_mode_consistency(cls, values):
        """Reject contradictory combinations of inference-mode fields.

        ``endpoint_name`` and ``model_s3_uri`` are mutually exclusive, and
        ``model_s3_uri`` / ``inference_image_uri`` must be supplied together.

        Raises:
            ValueError: If the fields are inconsistent.
        """
        from sagemaker.train.base_trainer import BaseTrainer

        endpoint_name = values.get("endpoint_name")
        model_s3_uri = values.get("model_s3_uri")
        inference_image_uri = values.get("inference_image_uri")
        model = values.get("model")

        # Skip validation when model is a trainer — _resolve_trainer_model
        # will fill in model_s3_uri from the trainer's checkpoint.
        if isinstance(model, BaseTrainer):
            return values

        if endpoint_name and model_s3_uri:
            raise ValueError(
                "endpoint_name and model_s3_uri are mutually exclusive. "
                "Use endpoint_name for an existing endpoint, or model_s3_uri + "
                "inference_image_uri to create a new endpoint."
            )
        if model_s3_uri and not inference_image_uri:
            raise ValueError(
                "inference_image_uri is required when model_s3_uri is provided "
                "(create_endpoint mode)."
            )
        if inference_image_uri and not model_s3_uri:
            raise ValueError(
                "model_s3_uri is required when inference_image_uri is provided "
                "(create_endpoint mode)."
            )
        return values

    @root_validator(skip_on_failure=True)
    def _resolve_trainer_model(cls, values):
        """Auto-resolve model artifacts from a BaseTrainer for endpoint creation.

        When a trainer is passed as ``model`` and no explicit inference mode
        (``endpoint_name``, ``model_s3_uri``, ``bedrock_model_id``) is configured,
        this resolver extracts the model S3 URI and inference image URI from the
        trainer's output model package, enabling automatic ``create_endpoint`` mode.

        If the trainer exposes no output model package ARN, the training job's
        ``model_artifacts.s3_model_artifacts`` path is used as the model S3 URI
        instead. In that case ``model_s3_uri`` is only set when an inference
        image is also available (explicitly provided, or derived for Nova
        models); otherwise the fields are left untouched so that
        ``_infer_scenario`` does not select ``create_endpoint`` without an image.

        This supports resource chaining where a completed trainer can be fed
        directly into the evaluator without manual artifact lookup.
        """
        from sagemaker.train.base_trainer import BaseTrainer

        model = values.get("model")
        if not isinstance(model, BaseTrainer):
            return values

        # Only auto-resolve if no explicit inference mode is configured
        if (
            values.get("endpoint_name")
            or values.get("model_s3_uri")
            or values.get("bedrock_model_id")
        ):
            return values

        # Resolve model package ARN from the trainer. MultiTurnRLTrainer uses
        # _latest_job; standard trainers (SFT, DPO, RLVR, RLAIF) use
        # _latest_training_job. Unassigned sentinels from sagemaker-core are
        # filtered out for both.
        source_mp_arn = None
        for job_attr in ("_latest_job", "_latest_training_job"):
            job = getattr(model, job_attr, None)
            if job is None:
                continue
            arn = getattr(job, "output_model_package_arn", None)
            if arn and not isinstance(arn, Unassigned):
                source_mp_arn = arn
                break

        if not source_mp_arn:
            # Check if trainer has a resolved checkpoint path from model_artifacts
            checkpoint_uri = None
            training_job = getattr(model, "_latest_training_job", None)
            if training_job:
                artifacts = getattr(training_job, "model_artifacts", None)
                if artifacts and not isinstance(artifacts, Unassigned):
                    s3_path = getattr(artifacts, "s3_model_artifacts", None)
                    if s3_path and isinstance(s3_path, str):
                        checkpoint_uri = s3_path

            if not checkpoint_uri:
                _logger.info(
                    "Trainer has no completed training job output; falling back to bedrock mode."
                )
                return values

            # CreateModel requires trailing slash for S3 prefix URIs
            if not checkpoint_uri.endswith("/"):
                checkpoint_uri += "/"

            if values.get("inference_image_uri"):
                # Image was provided explicitly; only the model URI is filled in.
                values["model_s3_uri"] = checkpoint_uri
                _logger.info(
                    "Auto-resolved trainer checkpoint for create_endpoint mode: "
                    "model_s3_uri=%s",
                    checkpoint_uri,
                )
                return values

            # Auto-derive inference image if not explicitly provided
            model_name = getattr(model, "_model_name", None) or ""
            region = None
            session = values.get("sagemaker_session")
            if session and hasattr(session, "boto_session"):
                region = session.boto_session.region_name

            resolved_image = None
            if "nova" in model_name.lower() and region:
                resolved_image = _get_nova_inference_image_uri(region)

            if resolved_image:
                values["model_s3_uri"] = checkpoint_uri
                values["inference_image_uri"] = resolved_image
                _logger.info(
                    "Auto-resolved Nova inference image for trainer checkpoint: "
                    "model_s3_uri=%s, inference_image_uri=%s",
                    checkpoint_uri,
                    resolved_image,
                )
            else:
                # No image could be determined: leave model_s3_uri as it was so
                # the evaluator is never put into create_endpoint mode with a
                # missing inference image.
                _logger.info(
                    "trainer checkpoint detected but no inference_image_uri set. "
                    "For non-Nova models, provide inference_image_uri explicitly."
                )
            return values

        # Resolve model artifacts from the model package. This is best-effort:
        # any failure is logged and the values are returned unchanged.
        try:
            session = values.get("sagemaker_session")
            from sagemaker.core.resources import ModelPackage as _MP

            boto_session = (
                session.boto_session if hasattr(session, "boto_session") else session
            )
            region = boto_session.region_name if boto_session else None
            mp = _MP.get(
                model_package_name=source_mp_arn,
                session=boto_session,
                region=region,
            )

            # Extract model data URL and image URI from inference specification
            if mp.inference_specification and mp.inference_specification.containers:
                container = mp.inference_specification.containers[0]

                # Resolve model S3 URI: try model_data_url first, then model_data_source
                resolved_model_s3 = getattr(container, "model_data_url", None)
                if not resolved_model_s3:
                    model_data_source = getattr(container, "model_data_source", None)
                    if model_data_source:
                        s3_data_source = getattr(model_data_source, "s3_data_source", None)
                        if s3_data_source:
                            resolved_model_s3 = getattr(s3_data_source, "s3_uri", None)

                # Resolve inference image: try explicit image first, then derive
                # from base_model for Nova models using escrow account pattern
                resolved_image = getattr(container, "image", None)
                if not resolved_image:
                    base_model = getattr(container, "base_model", None)
                    if base_model:
                        hub_content_name = getattr(base_model, "hub_content_name", None)
                        if hub_content_name and "nova" in hub_content_name.lower():
                            resolved_image = _get_nova_inference_image_uri(region)

                if resolved_model_s3 and resolved_image:
                    _logger.info(
                        "Auto-resolved trainer model artifacts for create_endpoint mode: "
                        "model_s3_uri=%s, inference_image_uri=%s",
                        resolved_model_s3,
                        resolved_image,
                    )
                    values["model_s3_uri"] = resolved_model_s3
                    values["inference_image_uri"] = resolved_image
                else:
                    _logger.warning(
                        "Trainer output model package does not contain model S3 URI "
                        "or inference image in inference_specification; "
                        "falling back to bedrock mode. "
                        "(resolved_model_s3=%s, resolved_image=%s)",
                        resolved_model_s3,
                        resolved_image,
                    )
            else:
                _logger.warning(
                    "Trainer output model package has no inference_specification; "
                    "falling back to bedrock mode."
                )
        except Exception as e:
            _logger.warning(
                "Failed to resolve trainer model artifacts: %s. "
                "Falling back to bedrock mode.",
                e,
            )

        return values

    @validator("image_uri")
    def _validate_image_uri(cls, v):
        """Require ``image_uri``, when given, to match the ECR URI pattern."""
        if v is not None and not _ECR_URI_PATTERN.match(v):
            raise ValueError(f"image_uri must be a valid ECR URI. Got: '{v}'")
        return v

    @validator("instance_type")
    def _validate_instance_type(cls, v):
        """Require ``instance_type`` to start with ``ml.``."""
        if not v.startswith("ml."):
            raise ValueError(f"instance_type must start with 'ml.'. Got: '{v}'")
        return v

    @validator("max_connections")
    def _validate_max_connections(cls, v):
        """Require ``max_connections`` to be at least 1."""
        if v < 1:
            raise ValueError(f"max_connections must be >= 1. Got: {v}")
        return v

    @validator("max_retries")
    def _validate_max_retries(cls, v):
        """Require ``max_retries`` to be at least 1."""
        if v < 1:
            raise ValueError(f"max_retries must be >= 1. Got: {v}")
        return v

    @validator("max_tokens")
    def _validate_max_tokens(cls, v):
        """Require ``max_tokens`` to be at least 1."""
        if v < 1:
            raise ValueError(f"max_tokens must be >= 1. Got: {v}")
        return v

    @validator("timeout")
    def _validate_timeout(cls, v):
        """Require ``timeout`` to be at least 1 second."""
        if v < 1:
            raise ValueError(f"timeout must be >= 1 (seconds). Got: {v}")
        return v

    @validator("temperature")
    def _validate_temperature(cls, v):
        """Require ``temperature`` to lie in [0.0, 2.0]."""
        if v < 0.0 or v > 2.0:
            raise ValueError(f"temperature must be in [0.0, 2.0]. Got: {v}")
        return v

    @validator("top_p")
    def _validate_top_p(cls, v):
        """Require ``top_p`` to lie in [0.0, 1.0]."""
        if v < 0.0 or v > 1.0:
            raise ValueError(f"top_p must be in [0.0, 1.0]. Got: {v}")
        return v

    @validator("top_k")
    def _validate_top_k(cls, v):
        """Require ``top_k`` to be -1 (disabled) or a positive integer."""
        # -1 disables top-k sampling; otherwise must be a positive int
        if v != -1 and v < 1:
            raise ValueError(f"top_k must be -1 (disabled) or >= 1. Got: {v}")
        return v

    def _infer_scenario(self) -> str:
        """Determine the inference provider mode.

        Returns:
            One of 'bedrock', 'existing_endpoint', 'create_endpoint'.
        """
        if self.endpoint_name:
            return "existing_endpoint"
        if self.model_s3_uri:
            return "create_endpoint"
        return "bedrock"

    def _get_bedrock_model_id(self, region: str) -> str:
        """Resolve the Bedrock model ID for bedrock inference mode.

        Priority: explicit bedrock_model_id > model's bedrock_model_id > model string.

        Args:
            region: AWS region name; its first dash-separated segment is used as
                the prefix when falling back to the raw model string.

        Returns:
            The resolved Bedrock model ID.

        Raises:
            ValueError: If no model ID can be determined.
        """
        if self.bedrock_model_id:
            return self.bedrock_model_id

        # Try to derive from model info (cross-region inference profile format).
        # The lookup is best-effort: on failure we fall through to the model string.
        try:
            model_info = self._get_resolved_model_info()
            if hasattr(model_info, "bedrock_model_id") and model_info.bedrock_model_id:
                return model_info.bedrock_model_id
        except Exception as exc:
            _logger.debug("Could not resolve model info for Bedrock model ID lookup: %s", exc)

        # Fall back to model string if it looks like a model ID
        # Use cross-region inference profile format: <continent_prefix>.<model_id>
        if isinstance(self.model, str) and not self.model.startswith("arn:"):
            region_prefix = region.split("-")[0]
            return f"{region_prefix}.{self.model}"

        raise ValueError(
            "Cannot determine Bedrock model ID. Provide bedrock_model_id explicitly "
            "or use a model that has a bedrock_model_id mapping."
        )

    def _build_inference_provider_config(self, region: str) -> dict:
        """Build the inference_provider section of the YAML config.

        Args:
            region: AWS region name written into the provider config.

        Returns:
            A dict with a single top-level key, ``"bedrock"`` or
            ``"sagemaker_endpoint"``, depending on the inferred scenario.
        """
        scenario = self._infer_scenario()

        if scenario == "bedrock":
            return {
                "bedrock": {
                    "model_id": self._get_bedrock_model_id(region),
                    "region": region,
                }
            }

        if scenario == "existing_endpoint":
            ep = {
                "endpoint_name": self.endpoint_name,
                "region": region,
            }
            if self.context_length:
                ep["context_length"] = self.context_length
            if self.max_concurrency:
                ep["max_concurrency"] = self.max_concurrency
            return {"sagemaker_endpoint": ep}

        # create_endpoint
        ep = {
            "endpoint_name": None,
            "region": region,
            "model_s3_uri": self.model_s3_uri,
            "inference_image_uri": self.inference_image_uri,
            "cleanup_endpoint": self.cleanup_endpoint,
            "instance_count": self.endpoint_instance_count,
            "endpoint_prefix": self.endpoint_prefix,
        }
        # execution_role_arn: use explicit endpoint role, fall back to evaluator role
        execution_role = self.endpoint_execution_role_arn or self.role
        if execution_role:
            ep["execution_role_arn"] = execution_role
        if self.endpoint_instance_type:
            ep["instance_type"] = self.endpoint_instance_type
        if self.context_length:
            ep["context_length"] = self.context_length
        if self.max_concurrency:
            ep["max_concurrency"] = self.max_concurrency
        if self.endpoint_environment:
            ep["environment"] = self.endpoint_environment
        return {"sagemaker_endpoint": ep}

    def _build_yaml_config(self, region: str) -> dict:
        """Build the complete YAML config dict for the InspectAI container.

        Matches the structure expected by the sagemaker-inspect-ai container:
        inference_provider, benchmarks, eval, output.

        Args:
            region: AWS region name passed through to the inference provider config.

        Returns:
            The config dict with keys in the order listed above.
        """
        config = {}

        # Inference provider
        config["inference_provider"] = self._build_inference_provider_config(region)

        # Benchmarks
        benchmarks = {}
        if self.benchmarks_path:
            benchmarks["s3_path"] = self.benchmarks_path
        if self.tasks:
            optional_keys = ("path", "limit", "epochs", "task_args")
            task_entries = []
            for task in self.tasks:
                # "name" first, then only the optional keys the caller supplied.
                entry = {"name": task["name"]}
                entry.update({key: task[key] for key in optional_keys if key in task})
                task_entries.append(entry)
            benchmarks["tasks"] = task_entries
        config["benchmarks"] = benchmarks

        # Eval settings
        eval_config = {
            "max_connections": self.max_connections,
            "max_retries": self.max_retries,
            "timeout": self.timeout,
            "decoding": {
                "temperature": self.temperature,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "max_tokens": self.max_tokens,
            },
        }
        if self.extra_args:
            eval_config["extra_args"] = self.extra_args
        config["eval"] = eval_config

        # Output
        output_path = self.s3_output_path.rstrip("/")
        output = {"s3_path": f"{output_path}/inspectai-results/"}
        if self.output_format:
            output["output_format"] = self.output_format
        config["output"] = output

        return config

    def _upload_yaml_config(self, config: dict, region: str) -> str:
        """Serialize config to YAML and upload to S3.

        The file is written as ``inspect_config.yaml`` under a fresh
        ``inspectai-config/<uuid4>`` prefix below ``s3_output_path``.

        Args:
            config: Config dict to serialize.
            region: Unused by this method; kept for interface compatibility.

        Returns:
            The S3 prefix URI where inspect_config.yaml was uploaded.
        """
        # safe_dump emits only standard YAML tags (no Python-specific tags), so
        # the file stays loadable by a safe YAML loader.
        yaml_content = yaml.safe_dump(config, default_flow_style=False, sort_keys=False)

        s3_base = self.s3_output_path.rstrip("/")
        config_prefix = f"{s3_base}/inspectai-config/{uuid.uuid4()}"
        config_s3_uri = f"{config_prefix}/inspect_config.yaml"

        _logger.info("Uploading InspectAI config to: %s", config_s3_uri)
        S3Uploader.upload_string_as_file_body(
            body=yaml_content,
            desired_s3_uri=config_s3_uri,
            kms_key=self.kms_key_id,
            sagemaker_session=self.sagemaker_session,
        )
        return config_prefix

    def upload_benchmarks(self, local_path: str) -> str:
        """Upload local benchmark files to S3.

        Uploads all files from a local directory to an S3 prefix under the
        configured output path. The uploaded path can be used as ``benchmarks_path``.

        Args:
            local_path: Local directory path containing ``.py`` files with
                ``@task`` decorators.

        Returns:
            S3 URI prefix where benchmarks were uploaded.

        Raises:
            ValueError: If local_path does not exist or is not a directory.
        """
        if not os.path.isdir(local_path):
            raise ValueError(f"local_path must be an existing directory. Got: '{local_path}'")

        s3_base = self.s3_output_path.rstrip("/")
        s3_prefix = f"{s3_base}/benchmarks/{uuid.uuid4()}"

        _logger.info("Uploading benchmarks from '%s' to '%s'", local_path, s3_prefix)
        S3Uploader.upload(
            local_path=local_path,
            desired_s3_uri=s3_prefix,
            kms_key=self.kms_key_id,
            sagemaker_session=self.sagemaker_session,
        )
        _logger.info("Benchmarks uploaded to: %s", s3_prefix)
        return s3_prefix

    @_telemetry_emitter(
        feature=Feature.MODEL_CUSTOMIZATION, func_name="InspectAIEvaluator.evaluate"
    )
    def evaluate(self) -> EvaluationPipelineExecution:
        """Create and start an InspectAI evaluation job.

        Serializes the InspectAI configuration to YAML, uploads it to S3, and
        launches a single-step SageMaker Pipeline with the InspectAI container.

        Returns:
            EvaluationPipelineExecution: The started evaluation execution with
                ``.wait()``, ``.refresh()``, and ``.show_results()`` methods.

        Example:
            .. code:: python

                evaluator = InspectAIEvaluator(
                    model="amazon-nova-lite-v1",
                    benchmarks_path="s3://my-bucket/benchmarks/",
                    tasks=[{"name": "boolq_pt", "limit": 10}],
                    s3_output_path="s3://my-bucket/eval-output/",
                )
                execution = evaluator.evaluate()
                execution.wait()
                execution.show_results()
        """
        # Get AWS execution context
        aws_context = self._get_aws_execution_context()
        region = aws_context["region"]
        role_arn = aws_context["role_arn"]

        # Build and upload YAML config
        yaml_config = self._build_yaml_config(region)
        config_s3_prefix = self._upload_yaml_config(yaml_config, region)

        # Resolve container image URI
        resolved_image_uri = self.image_uri or _get_inspect_ai_default_image_uri(region)

        # Build job name prefix (keep total under 63 chars after pipeline exec ID appended)
        base_name = self.base_eval_name or "inspectai-eval"
        job_name_prefix = base_name[:26]

        # Build template context
        template_context = {
            "job_name_prefix": job_name_prefix,
            "image_uri": resolved_image_uri,
            "role_arn": role_arn,
            "instance_type": self.instance_type,
            "max_runtime_seconds": self.max_runtime_seconds,
            "config_s3_uri": config_s3_prefix,
            "s3_output_path": self.s3_output_path.rstrip("/"),
            "kms_key_id": self.kms_key_id,
            "environment": self.environment,
            "vpc_config": self.networking is not None,
        }
        if self.networking:
            template_context["vpc_security_group_ids"] = self.networking.security_group_ids
            template_context["vpc_subnets"] = self.networking.subnets

        # Render pipeline definition
        pipeline_definition = self._render_pipeline_definition(
            INSPECT_AI_TEMPLATE, template_context
        )

        # Start execution (the untruncated base name is used as the execution name)
        return self._start_execution(
            eval_type=EvalType.INSPECT_AI,
            name=base_name,
            pipeline_definition=pipeline_definition,
            role_arn=role_arn,
            region=region,
        )

    @classmethod
    @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="InspectAIEvaluator.get_all")
    def get_all(
        cls, session: Optional[Any] = None, region: Optional[str] = None
    ) -> Iterator[EvaluationPipelineExecution]:
        """Get all InspectAI evaluation executions.

        Args:
            session (Optional[Any]): Optional boto3 session.
            region (Optional[str]): Optional AWS region.

        Yields:
            EvaluationPipelineExecution: InspectAI evaluation execution instances.
        """
        yield from EvaluationPipelineExecution.get_all(
            eval_type=EvalType.INSPECT_AI,
            session=session,
            region=region,
        )