Source code for sagemaker.core.token_generator

"""AWS SageMaker Token Generator.

A lightweight module for generating short-term bearer tokens for AWS SageMaker
API authentication. Provides the ``generate_token`` helper and the lower-level
``SageMakerTokenGenerator`` class.

Example::

    >>> from sagemaker.core.token_generator import generate_token
    >>> token = generate_token(region="us-east-1")
"""

from __future__ import annotations

import os
from datetime import timedelta

from botocore.credentials import CredentialProvider
from botocore.session import Session

from sagemaker.core.token_generator.token_generator import (
    TOKEN_DURATION,
    SageMakerTokenGenerator,
    _generate_token,
)

__all__ = ["SageMakerTokenGenerator", "generate_token"]


def generate_token(
    region: str | None = None,
    aws_credentials_provider: CredentialProvider | None = None,
    expiry: timedelta = timedelta(hours=12),
) -> str:
    """Generate a short-lived AWS SageMaker bearer token.

    Args:
        region (str): AWS region. Falls back to the ``AWS_REGION`` environment
            variable when not provided (or empty).
        aws_credentials_provider (CredentialProvider): Optional credential
            provider; its ``load()`` result is used when given. When omitted,
            credentials are taken from ``Session().get_credentials()``.
        expiry (timedelta): Token lifetime. Must be between 1 second and
            12 hours inclusive. Defaults to 12 hours. Fractional seconds are
            truncated before use, so a sub-second lifetime is rejected.

    Returns:
        str: A bearer token string.

    Raises:
        ValueError: If *region* is missing or *expiry* is out of range.
        RuntimeError: If no valid AWS credentials are found.
    """
    region = region or os.environ.get("AWS_REGION")
    if not region:
        raise ValueError("Region must be provided or set via the AWS_REGION environment variable.")

    # Validate the value that is actually handed to ``_generate_token``: the
    # lifetime truncated to whole seconds. Checking only ``total_seconds() > 0``
    # would let e.g. 500 ms through and then pass a 0-second lifetime downstream.
    total_seconds = expiry.total_seconds()
    expiry_seconds = int(total_seconds)
    # NOTE(review): TOKEN_DURATION is compared against seconds here and the
    # message below says 12 hours -- presumably it is 12h in seconds; confirm.
    if expiry_seconds < 1 or total_seconds > TOKEN_DURATION:
        raise ValueError(
            "Token expiry must be greater than zero and less than or equal to 12 hours"
        )

    if aws_credentials_provider is not None:
        credentials = aws_credentials_provider.load()
    else:
        credentials = Session().get_credentials()

    # Either source may yield ``None`` when nothing could be resolved.
    if credentials is None:
        raise RuntimeError(
            "No AWS credentials found. Check your environment or credential provider."
        )

    return _generate_token(credentials, region, expiry_seconds)