# Source code for sagemaker.serve.utils.model_package_utils
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utilities for Restricted Model Package support."""
from __future__ import absolute_import
from typing import Optional
[docs]
def is_restricted_model_package(model_package) -> bool:
    """Detect if a model package is a Restricted Model Package.

    A package is reported as restricted when its ``managed_storage_type``
    attribute compares equal to the string ``"Restricted"`` (the comparison
    is exact and case-sensitive).

    Args:
        model_package: A ModelPackage resource object. ``None`` or any other
            falsy value is accepted and yields ``False``.

    Returns:
        True if the model package is restricted, False otherwise (including
        when the object has no ``managed_storage_type`` attribute).
    """
    if not model_package:
        return False

    # getattr with a default so objects lacking the attribute are treated as
    # "not restricted" instead of raising AttributeError.
    managed_storage_type = getattr(model_package, "managed_storage_type", None)
    return managed_storage_type == "Restricted"
[docs]
def get_s3_uri_from_inference_spec(inference_specification) -> Optional[str]:
    """Extract s3_uri from the first container's model_data_source.

    Follows the attribute path
    ``containers[0].model_data_source.s3_data_source.s3_uri`` and stops as
    soon as a link in that chain is missing or falsy. Only the first
    container is inspected.

    Args:
        inference_specification: The inference_specification from a
            ModelPackage. ``None`` or any other falsy value is accepted.

    Returns:
        The s3_uri string, or None if not available (no specification, no
        containers, or any of ``model_data_source``, ``s3_data_source`` or
        ``s3_uri`` absent).
    """
    if not inference_specification:
        return None

    containers = getattr(inference_specification, "containers", None)
    if not containers:
        return None

    # Descend from the first container through the nested data-source
    # objects; a missing or falsy intermediate means there is no URI.
    node = containers[0]
    for attr_name in ("model_data_source", "s3_data_source"):
        node = getattr(node, attr_name, None)
        if not node:
            return None

    # The final attribute is returned as-is (None when it is absent).
    return getattr(node, "s3_uri", None)