# Source code for sagemaker.train.custom_agent_lambda
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""CustomAgentLambda — Lambda-based agent environment for Agentic RFT."""
from __future__ import annotations
import io
import re
import zipfile
from pathlib import Path
from typing import Optional
import boto3
# Matches an S3 URI of the form ``s3://<bucket>`` or ``s3://<bucket>/<key>``.
# The bucket segment must be non-empty and must not contain "/".
S3_URI_PATTERN = re.compile(r"^s3://[^/]+(/.*)?$")
class CustomAgentLambda:
    """Lambda-based agent environment for Agentic RFT.

    Creates and wraps Lambda functions that serve as agent environments or
    bridges between SageMaker and custom agent environments (e.g., LangSmith,
    EKS, Fargate).

    Args:
        lambda_arn: ARN of the Lambda function.
    """

    def __init__(self, lambda_arn: str):
        self.lambda_arn = lambda_arn

    def __repr__(self):
        return f"CustomAgentLambda(lambda_arn={self.lambda_arn!r})"

    @staticmethod
    def _load_code_text(source: str) -> str:
        """Return the code text for a non-S3 ``source``.

        If ``source`` names an existing filesystem path, that file is read as
        UTF-8 and its contents are returned. Otherwise ``source`` is treated
        as inline code and returned unchanged.

        Args:
            source: Local file path or inline code string.

        Returns:
            The code text to package.
        """
        try:
            is_local_path = Path(source).exists()
        except (OSError, ValueError):
            # Inline code is usually not a valid path. Depending on the Python
            # version, probing it can raise (e.g. "File name too long")
            # instead of returning False, so treat that as "not a path".
            is_local_path = False

        if not is_local_path:
            return source

        # Read as UTF-8 so the result does not depend on the machine's locale;
        # this matches how the S3 branch of ``create`` decodes code objects.
        with open(source, "r", encoding="utf-8") as code_file:
            return code_file.read()

    @classmethod
    def create(
        cls,
        source: str,
        function_name: Optional[str] = None,
        role: Optional[str] = None,
        runtime: str = "python3.12",
        handler: str = "lambda_function.handler",
        timeout: int = 900,
        memory_size: int = 256,
        environment: Optional[dict] = None,
        sagemaker_session=None,
    ) -> "CustomAgentLambda":
        """Create a new Lambda function and return a CustomAgentLambda.

        The ``source`` parameter accepts three formats:

        - **S3 URI** (``s3://bucket/key.zip``): deploys from an S3 artifact.
          If the key does not end in ``.zip``, the object is downloaded,
          decoded as UTF-8, and packaged as a zip.
        - **Local file path**: reads the file, packages it as a zip, and uploads.
        - **Inline code string**: packages the raw code as a zip and uploads.

        Detection order: S3 URI → existing local path → inline code.

        Args:
            source: S3 URI, local file path, or inline Python code string.
            function_name: Lambda function name. If not provided, a unique name
                is generated automatically.
            role: IAM role ARN for the Lambda execution role. If not provided,
                it is resolved through ``TrainDefaults``.
            runtime: Lambda runtime (default: ``"python3.12"``).
            handler: Lambda handler (default: ``"lambda_function.handler"``).
            timeout: Lambda timeout in seconds (default: 900).
            memory_size: Lambda memory in MB (default: 256).
            environment: Dict of environment variables for the Lambda.
            sagemaker_session: Optional SageMaker session for role resolution.

        Returns:
            CustomAgentLambda wrapping the created Lambda ARN.

        Raises:
            ValueError: If ``source`` is empty.
        """
        if not source or not source.strip():
            raise ValueError("'source' must be provided.")

        if not function_name:
            from sagemaker.train.utils import _get_unique_name

            function_name = _get_unique_name("SageMaker-agent-adapter", max_length=64)

        if not role:
            from sagemaker.train.defaults import TrainDefaults

            sagemaker_session = TrainDefaults.get_sagemaker_session(
                sagemaker_session=sagemaker_session
            )
            role = TrainDefaults.get_role(role=role, sagemaker_session=sagemaker_session)

        lambda_client = boto3.client("lambda")

        if S3_URI_PATTERN.match(source):
            bucket, key = _parse_s3_uri(source)
            if key.endswith(".zip"):
                # Already a deployment package: let Lambda pull it from S3.
                code_param = {"S3Bucket": bucket, "S3Key": key}
            else:
                # A plain code object: download it and package it ourselves.
                s3_client = boto3.client("s3")
                response = s3_client.get_object(Bucket=bucket, Key=key)
                code_content = response["Body"].read().decode("utf-8")
                code_param = {"ZipFile": _zip_code(code_content)}
        else:
            code_param = {"ZipFile": _zip_code(cls._load_code_text(source))}

        response = lambda_client.create_function(
            FunctionName=function_name,
            Runtime=runtime,
            Role=role,
            Handler=handler,
            Code=code_param,
            Timeout=timeout,
            MemorySize=memory_size,
            Environment={"Variables": environment} if environment else {},
        )
        return cls(lambda_arn=response["FunctionArn"])

    @classmethod
    def get(cls, lambda_arn: str) -> "CustomAgentLambda":
        """Wrap an existing Lambda ARN.

        Validates the Lambda exists by calling GetFunction.

        Args:
            lambda_arn: ARN of an existing Lambda function.

        Returns:
            CustomAgentLambda wrapping the Lambda ARN.

        Raises:
            botocore.exceptions.ClientError: If the Lambda does not exist.
        """
        lambda_client = boto3.client("lambda")
        # Called only for its side effect: it raises if the function is missing.
        lambda_client.get_function(FunctionName=lambda_arn)
        return cls(lambda_arn=lambda_arn)
def _parse_s3_uri(uri: str) -> tuple[str, str]:
    """Split an S3 URI into its ``(bucket, key)`` components.

    The leading ``s3://`` is dropped by length, so the caller is expected to
    pass a string that already starts with that scheme. Everything up to the
    first ``/`` is the bucket; the rest is the key, or an empty string when
    there is no ``/``.
    """
    scheme_length = len("s3://")
    pieces = uri[scheme_length:].split("/", 1)
    bucket_name = pieces[0]
    object_key = pieces[1] if len(pieces) == 2 else ""
    return bucket_name, object_key
def _zip_code(code_content: str) -> bytes:
    """Package code content into an in-memory zip archive.

    The archive contains a single entry, ``lambda_function.py``, holding
    ``code_content`` (encoded as UTF-8 by ``zipfile``).

    Args:
        code_content: Source code to store in the archive.

    Returns:
        The raw bytes of the zip archive.
    """
    # Build the entry explicitly instead of passing a bare name to writestr():
    # a bare name records mode 0600, while AWS Lambda documents 644 for
    # non-executable files in a deployment package.
    entry = zipfile.ZipInfo("lambda_function.py")
    entry.create_system = 3  # Unix, so the mode bits below are honoured
    entry.external_attr = 0o644 << 16  # rw-r--r--
    # ZipInfo's default timestamp is fixed, so identical code produces
    # byte-identical archives.

    archive = io.BytesIO()
    with zipfile.ZipFile(archive, "w") as zf:
        zf.writestr(entry, code_content)
    return archive.getvalue()