sagemaker.train.rlaif_trainer
Classes
RLAIFTrainer: Class that performs Reinforcement Learning from AI Feedback (RLAIF) fine-tuning on foundation models using AWS SageMaker.
- class sagemaker.train.rlaif_trainer.RLAIFTrainer(model: str | ModelPackage, training_type: TrainingType | str = TrainingType.LORA, model_package_group: str | ModelPackageGroup | None = None, reward_model_id: str | None = None, reward_prompt: str | Evaluator | None = None, mlflow_resource_arn: str | MlflowTrackingServer | None = None, mlflow_experiment_name: str | None = None, mlflow_run_name: str | None = None, training_dataset: str | DataSet | None = None, validation_dataset: str | DataSet | None = None, s3_output_path: str | None = None, kms_key_id: str | None = None, networking: VpcConfig | None = None, accept_eula: bool = False, stopping_condition: StoppingCondition | None = None, **kwargs)
Bases: BaseTrainer
Class that performs Reinforcement Learning from AI Feedback (RLAIF) fine-tuning on foundation models using AWS SageMaker.
Example:
```python
from sagemaker.train import RLAIFTrainer
from sagemaker.train.common import TrainingType

trainer = RLAIFTrainer(
    model="meta-llama/Llama-2-7b-hf",
    training_type=TrainingType.LORA,
    model_package_group="my-model-group",
    reward_model_id="reward-model-id",
    reward_prompt="Rate the helpfulness of this response on a scale of 1-10",
    training_dataset="s3://bucket/rlaif_data.jsonl",
)
trainer.train()

# Complete workflow: create -> wait -> get model package ARN
trainer = RLAIFTrainer(
    model="meta-llama/Llama-2-7b-hf",
    model_package_group="my-rlaif-models",
    reward_model_id="reward-model-id",
    reward_prompt="Rate the helpfulness of this response on a scale of 1-10",
)

# Create training job (non-blocking)
training_job = trainer.train(
    training_dataset="s3://bucket/rlaif_data.jsonl",
    wait=False,
)

# Wait for completion
training_job.wait()

# Refresh job status
training_job.refresh()

# Get the fine-tuned model package ARN
model_package_arn = training_job.output_model_package_arn
```
- Parameters:
model (Union[str, ModelPackage]) – The foundation model to fine-tune. Can be a model name string, model package ARN, or ModelPackage object.
training_type (Union[TrainingType, str]) – The fine-tuning approach. Valid values are TrainingType.LORA (default) and TrainingType.FULL.
model_package_group (Optional[Union[str, ModelPackageGroup]]) – The model package group for storing the fine-tuned model. Can be a group name, ARN, or ModelPackageGroup object. Required when model is not a ModelPackage.
reward_model_id (Optional[str]) – The Amazon Bedrock model identifier used to generate AI feedback. Although it defaults to None, it is required for RLAIF training because the reward model provides the reward signal.
reward_prompt (Optional[Union[str, Evaluator]]) – The reward prompt or evaluator used to generate AI feedback. Can be a prompt string or an Evaluator object. For built-in metric prompts, see https://docs.aws.amazon.com/bedrock/latest/userguide/model-evaluation-metrics.html
mlflow_resource_arn (Optional[Union[str, MlflowTrackingServer]]) – The MLflow tracking server ARN for experiment tracking. If not specified, the default MLflow experience is used.
mlflow_experiment_name (Optional[str]) – The MLflow experiment name for organizing runs.
mlflow_run_name (Optional[str]) – The MLflow run name for this training job.
training_dataset (Optional[Union[str, DataSet]]) – The training dataset. Can be an S3 URI, a dataset ARN, or a DataSet object.
validation_dataset (Optional[Union[str, DataSet]]) – The validation dataset. Can be an S3 URI, a dataset ARN, or a DataSet object.
s3_output_path (Optional[str]) – The S3 path for training job outputs. If not specified, defaults to s3://sagemaker-<region>-<account>/output.
kms_key_id (Optional[str]) – The KMS key ID for encrypting training job outputs.
networking (Optional[VpcConfig]) – The VPC configuration for the training job.
stopping_condition (Optional[StoppingCondition]) – The stopping condition that overrides the training job's maximum runtime. If not specified, the SageMaker service default is used (24 hours for serverless training).
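Two of the parameter rules above (model_package_group is needed unless model is already a ModelPackage, and reward_model_id is required even though it defaults to None) plus the documented default output location can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the trainer's actual validation code: ModelPackageStub, check_required_args, and default_output_path are hypothetical names, and the stub stands in for the real ModelPackage class so the snippet runs without the SageMaker SDK or AWS credentials.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ModelPackageStub:
    """Hypothetical stand-in for sagemaker's ModelPackage (illustration only)."""
    arn: str


def check_required_args(
    model: Union[str, ModelPackageStub],
    model_package_group: Optional[str],
    reward_model_id: Optional[str],
) -> None:
    """Hypothetical validator encoding two rules from the parameter docs."""
    # model_package_group may be omitted only when model is already a ModelPackage.
    if not isinstance(model, ModelPackageStub) and model_package_group is None:
        raise ValueError("model_package_group is required when model is not a ModelPackage")
    # reward_model_id defaults to None in the signature, but RLAIF needs a reward model.
    if reward_model_id is None:
        raise ValueError("reward_model_id is required for RLAIF training")


def default_output_path(region: str, account: str) -> str:
    """Build the documented default s3_output_path: s3://sagemaker-<region>-<account>/output."""
    return f"s3://sagemaker-{region}-{account}/output"


if __name__ == "__main__":
    # Accepted: a model name together with a group and a reward model ID.
    check_required_args("meta-llama/Llama-2-7b-hf", "my-model-group", "reward-model-id")
    # Accepted: a ModelPackage does not need a separate group.
    check_required_args(ModelPackageStub("arn:example"), None, "reward-model-id")
    print(default_output_path("us-east-1", "123456789012"))
```

The checks raise ValueError so that a misconfigured trainer fails before any job is submitted; whether the real class validates in __init__ or in train() is not stated in this reference.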
- train(training_dataset: str | DataSet | None = None, validation_dataset: str | DataSet | None = None, wait: bool = True, wait_timeout: int | None = None, poll: int = 5)
Execute the RLAIF training job.
- Parameters:
training_dataset (Optional[Union[str, DataSet]]) – The training dataset for this job. Overrides the dataset specified in __init__. Can be an S3 URI, dataset ARN, or DataSet object.
validation_dataset (Optional[Union[str, DataSet]]) – The validation dataset for this job. Overrides the dataset specified in __init__. Can be an S3 URI, dataset ARN, or DataSet object.
wait (bool) – Whether to wait for the training job to complete. Defaults to True.
wait_timeout (Optional[int]) – Maximum time in seconds to wait for the training job to complete. Only used when wait=True. If None, the wait utility's default timeout is used.
poll (int) – Polling interval in seconds for checking training job status. Defaults to 5.
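How the train() arguments interact can be sketched without the SDK. The snippet assumes two things this reference does not state outright: that passing None for a dataset falls back to the value given to __init__ (consistent with the first example, which calls train() with no arguments), and that a job is finished once it reaches one of SageMaker's terminal training job statuses (Completed, Failed, or Stopped). resolve_dataset and wait_for_job are hypothetical helpers, not SDK functions, and unlike the real wait utility, which applies its own default timeout, this sketch waits indefinitely when wait_timeout is None.

```python
import time
from typing import Callable, Optional

# Assumed terminal states, taken from SageMaker's TrainingJobStatus values.
TERMINAL_STATES = {"Completed", "Failed", "Stopped"}


def resolve_dataset(call_value: Optional[object], init_value: Optional[object]) -> Optional[object]:
    """Hypothetical: a dataset passed to train() overrides the one given to __init__."""
    return call_value if call_value is not None else init_value


def wait_for_job(
    get_status: Callable[[], str],
    poll: float = 5,
    wait_timeout: Optional[float] = None,
) -> str:
    """Hypothetical polling loop showing how poll and wait_timeout interact."""
    start = time.monotonic()
    while True:
        status = get_status()
        if status in TERMINAL_STATES:
            return status
        # In this sketch, None means "no limit"; the real utility has its own default.
        if wait_timeout is not None and time.monotonic() - start >= wait_timeout:
            raise TimeoutError(f"job still {status!r} after {wait_timeout}s")
        time.sleep(poll)  # wait poll seconds before the next status check


if __name__ == "__main__":
    # No dataset passed to train(): the __init__ value is used.
    print(resolve_dataset(None, "s3://bucket/rlaif_data.jsonl"))
    # Simulated job that finishes on the third status check.
    statuses = iter(["InProgress", "InProgress", "Completed"])
    print(wait_for_job(lambda: next(statuses), poll=0))
```

Setting poll=0 makes the demonstration run instantly; against a real job, a nonzero interval such as the documented default of 5 seconds avoids hammering the status API.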
- Returns:
The SageMaker training job object.
- Return type: