Interactive exploration of the Proximal Policy Optimization trainer implementation
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import math
import os
import textwrap
import time
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Optional, Union
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
DataCollatorWithPadding,
FeatureExtractionMixin,
GenerationConfig,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
TrainerControl,
is_wandb_available,
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_peft_available, is_rich_available
from ..core import masked_mean, masked_whiten
from ..models import create_reference_model
from ..models.utils import unwrap_model_for_generation
from .ppo_config import PPOConfig
from .utils import (
OnlineTrainerState,
batch_generation,
disable_dropout_in_model,
empty_cache,
exact_div,
first_true_indices,
forward,
generate_model_card,
get_comet_experiment_url,
get_reward,
log_table_to_comet_experiment,
peft_module_casting_to_bf16,
prepare_deepspeed,
print_rich_table,
selective_log_softmax,
truncate_response,
)
if is_peft_available():
from peft import PeftConfig, PeftModel, get_peft_model
if is_wandb_available():
import wandb
INVALID_LOGPROB = 1.0
# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29
# we did this we can do a single `model = accelerator.prepare(model)`
class PolicyAndValueWrapper(nn.Module):
def __init__(self, policy, value_model) -> None:
super().__init__()
self.policy = policy
self.value_model = value_model
self.critic_backbone = getattr(value_model, value_model.base_model_prefix)
def forward(self, **kwargs):
output = self.critic_backbone(**kwargs)
logits = self.value_model.score(output.hidden_states[-1])
return self.policy(**kwargs), logits
class PPOTrainer(Trainer):
_tag_names = ["trl", "ppo"]
def __init__(
self,
args: PPOConfig,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
],
model: nn.Module,
ref_model: Optional[nn.Module],
reward_model: nn.Module,
train_dataset: Dataset,
value_model: nn.Module,
data_collator: Optional[DataCollatorWithPadding] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
# less commonly used
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
callbacks: Optional[list[TrainerCallback]] = None,
peft_config: Optional["PeftConfig"] = None,
) -> None:
if ref_model is model:
raise ValueError(
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
"same as `model`, you must make a copy of it, or `None` if you use peft."
)
self.args = args
self.processing_class = processing_class
self.policy_model = model
# Define the collator if not provided
if data_collator is None:
data_collator = DataCollatorWithPadding(self.processing_class)
# Handle stop token settings: update policy model's generation_config to use provided stop token
if args.stop_token and args.stop_token_id:
raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
elif args.stop_token:
if args.stop_token == "eos":
self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
else:
raise ValueError(
f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
)
else:
self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
# Check that the kl estimator is valid
if self.args.kl_estimator not in {"k1", "k3"}:
raise ValueError(
"kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, "
"appears to be a strictly better estimator). See "
"[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
)
# peft support
if not is_peft_available() and peft_config is not None:
raise ImportError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
# if model is a peft model and we have a peft_confg, we merge and unload it first
if isinstance(self.policy_model, PeftModel):
self.policy_model = self.policy_model.merge_and_unload()
# get peft model with the given config
self.policy_model = get_peft_model(self.policy_model, peft_config)
if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(self.policy_model)
self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
self.model_adapter_name = args.model_adapter_name
self.ref_adapter_name = args.ref_adapter_name
if ref_model:
self.ref_model = ref_model
elif self.is_peft_model:
self.ref_model = None
else:
self.ref_model = create_reference_model(self.policy_model)
self.reward_model = reward_model
self.train_dataset = train_dataset
self.train_dataset_len = len(train_dataset)
self.value_model = value_model
self.data_collator = data_collator
self.eval_dataset = eval_dataset
self.optimizer, self.lr_scheduler = optimizers
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
#########
# calculate various batch sizes
#########
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
self.accelerator = accelerator
args.world_size = accelerator.num_processes
args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
args.batch_size = int(args.local_batch_size * args.world_size)
args.mini_batch_size = exact_div(
args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
)
args.local_mini_batch_size = exact_div(
args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
)
if args.whiten_rewards:
assert args.local_mini_batch_size >= 8, (
f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
)
# `per_rank_rollout_batch_size` is our `args.local_batch_size`
# `per_rank_minibatch_size` is our `args.local_mini_batch_size`
args.num_total_batches = math.ceil(
args.total_episodes / args.batch_size
) # we may train for more than `total_episodes`
time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
if args.num_sample_generations > 0:
self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
self.local_dataloader_batch_size = args.local_batch_size
#########
# setup model, optimizer, and others
#########
for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
if module is not None:
disable_dropout_in_model(module)
self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
self.model.config = self.policy_model.config # needed for pushing to hub
self.create_optimizer_and_scheduler(
num_training_steps=args.num_total_batches
) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
#########
### trainer specifics
#########
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
self.callback_handler = CallbackHandler(
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
self.control = TrainerControl()
self.state = OnlineTrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
],
)
self.current_flos = 0
self.hp_search_backend = None
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
# Create distant repo and output directory if needed
self.hub_model_id = None
if self.args.push_to_hub:
self.init_hf_repo()
if self.args.should_save:
os.makedirs(self.args.output_dir, exist_ok=True)
# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
#########
### setup dataloader
#########
self.dataloader = DataLoader(
self.train_dataset,
batch_size=self.local_dataloader_batch_size,
shuffle=True,
collate_fn=self.data_collator,
drop_last=True, # needed; otherwise the last batch will be of ragged shape
)
# sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
# see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
torch.manual_seed(args.seed)
self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
torch.manual_seed(self.local_seed) # reset the local seed again
self.eval_dataloader = DataLoader(
self.eval_dataset,
batch_size=args.per_device_eval_batch_size,
collate_fn=self.data_collator,
drop_last=True,
) # no need to shuffle eval dataset
self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
if self.is_deepspeed_enabled:
self.reward_model = prepare_deepspeed(
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
)
if self.ref_model is None:
if not self.is_peft_model:
raise ValueError("No reference model and model is not a Peft model.")
else:
self.ref_model = prepare_deepspeed(
self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
)
else:
if self.ref_model is None:
if not self.is_peft_model:
raise ValueError("No reference model and model is not a Peft model.")
else:
self.ref_model = self.ref_model.to(self.accelerator.device)
self.reward_model = self.reward_model.to(self.accelerator.device)
def get_train_dataloader(self) -> DataLoader:
return self.dataloader
def get_eval_dataloader(self) -> DataLoader:
return self.eval_dataloader
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with (
self.accelerator.unwrap_model(self.model.policy).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.policy.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.policy.set_adapter(self.model_adapter_name or "default")
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
backup_model = self.model
self.model = self.model.policy # save only the policy
if self.is_deepspeed_enabled:
backup_deepspeed = self.deepspeed
self.deepspeed = self.model
super().save_model(output_dir, _internal_call)
self.model = backup_model
if self.is_deepspeed_enabled:
self.deepspeed = backup_deepspeed
def train(self):
args = self.args
accelerator = self.accelerator
optimizer = self.optimizer
model = self.model
ref_policy = self.ref_model
reward_model = self.reward_model
processing_class = self.processing_class
dataloader = self.dataloader
device = accelerator.device
def repeat_generator():
while True:
yield from dataloader
iter_dataloader = iter(repeat_generator())
generation_config = GenerationConfig(
max_new_tokens=args.response_length,
temperature=(args.temperature + 1e-7),
top_k=0.0,
top_p=1.0,
do_sample=True,
)
accelerator.print("===training policy===")
start_time = time.time()
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
approxkl_stats = torch.zeros(stats_shape, device=device)
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
pg_loss_stats = torch.zeros(stats_shape, device=device)
vf_loss_stats = torch.zeros(stats_shape, device=device)
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
entropy_stats = torch.zeros(stats_shape, device=device)
ratio_stats = torch.zeros(stats_shape, device=device)
model.train()
# trainer state initialization
self.state.global_step = 0
self.state.episode = 0
self.state.max_steps = args.num_total_batches
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
if args.logging_steps < 1:
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
else:
self.state.logging_steps = args.logging_steps
if args.eval_steps is not None:
if args.eval_steps < 1:
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
else:
self.state.eval_steps = args.eval_steps
if args.save_steps is not None:
if args.save_steps < 1:
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model
self.model_wrapped = self.model
for update in range(1, args.num_total_batches + 1):
self.state.episode += 1 * args.batch_size
data = next(iter_dataloader)
with torch.no_grad():
queries = data["input_ids"].to(device)
context_length = queries.shape[1]
responses = []
postprocessed_responses = []
logprobs = []
ref_logprobs = []
scores = []
sequence_lengths = []
values = []
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
query_responses, logitss = batch_generation(
unwrapped_model.policy,
queries,
args.local_rollout_forward_batch_size,
processing_class.pad_token_id,
generation_config,
)
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
query = queries[i : i + args.local_rollout_forward_batch_size]
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
logits = logitss[i : i + args.local_rollout_forward_batch_size]
logprob = selective_log_softmax(logits, response)
del logits
empty_cache()
if ref_policy is None:
with self.null_ref_context():
ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
else:
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
empty_cache()
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
self.stop_token_id, processing_class.pad_token_id, response
)
# Response Processing 2. run reward model on the truncated responses
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
unwrapped_value_model = accelerator.unwrap_model(model).value_model
full_value, _, _ = get_reward(
unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
)
value = full_value[:, context_length - 1 : -1].squeeze(-1)
_, score, _ = get_reward(
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
)
responses.append(response)
postprocessed_responses.append(postprocessed_response)
logprobs.append(logprob)
ref_logprobs.append(ref_logprob)
sequence_lengths.append(sequence_length)
scores.append(score)
values.append(value)
responses = torch.cat(responses, 0)
postprocessed_responses = torch.cat(postprocessed_responses, 0)
logprobs = torch.cat(logprobs, 0)
ref_logprobs = torch.cat(ref_logprobs, 0)
sequence_lengths = torch.cat(sequence_lengths, 0)
scores = torch.cat(scores, 0)
values = torch.cat(values, 0)
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
empty_cache()
gc.collect()
# Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
# Completions not passing that filter will receive a lower score.
contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
if self.args.missing_eos_penalty is not None:
scores[~contain_eos_token] -= self.args.missing_eos_penalty
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
sequence_lengths_p1 = sequence_lengths + 1
padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
values = torch.masked_fill(values, padding_mask_p1, 0)
# 4. compute rewards
# Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
logr = ref_logprobs - logprobs
kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr # Else statement is k3
non_score_reward = -args.kl_coef * kl
rewards = non_score_reward.clone()
actual_start = torch.arange(rewards.size(0), device=rewards.device)
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
rewards[[actual_start, actual_end]] += scores
# 5. whiten rewards
if args.whiten_rewards:
rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
# 6. compute advantages and returns
lastgaelam = 0
advantages_reversed = []
gen_length = responses.shape[1]
for t in reversed(range(gen_length)):
nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
lastgaelam = delta + args.gamma * args.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], axis=1)
returns = advantages + values
advantages = masked_whiten(advantages, ~padding_mask)
advantages = torch.masked_fill(advantages, padding_mask, 0)
empty_cache()
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
for ppo_epoch_idx in range(args.num_ppo_epochs):
b_inds = np.random.permutation(args.local_batch_size)
minibatch_idx = 0
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
mini_batch_end = mini_batch_start + args.local_mini_batch_size
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
gradient_accumulation_idx = 0
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
with accelerator.accumulate(model):
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
mb_advantage = advantages[micro_batch_inds]
mb_responses = responses[micro_batch_inds]
mb_query_responses = query_responses[micro_batch_inds]
mb_logprobs = logprobs[micro_batch_inds]
mb_return = returns[micro_batch_inds]
mb_values = values[micro_batch_inds]
output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
new_logprobs = selective_log_softmax(logits, mb_responses)
new_logprobs = torch.masked_fill(
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
)
vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
vpredclipped = torch.clamp(
vpred,
mb_values - args.cliprange_value,
mb_values + args.cliprange_value,
)
vf_losses1 = torch.square(vpred - mb_return)
vf_losses2 = torch.square(vpredclipped - mb_return)
vf_loss_max = torch.max(vf_losses1, vf_losses2)
vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
vf_clipfrac = masked_mean(
(vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
)
logprobs_diff = new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
pg_loss_max = torch.max(pg_losses, pg_losses2)
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
loss = pg_loss + args.vf_coef * vf_loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
with torch.no_grad():
pg_clipfrac = masked_mean(
(pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
)
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
approxkl = 0.5 * (logprobs_diff**2).mean()
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
pg_clipfrac
)
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
vf_clipfrac
)
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
gradient_accumulation_idx += 1
minibatch_idx += 1
# del everything and empty cache
# fmt: off
del (
output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
)
# fmt: on
empty_cache()
with torch.no_grad():
mean_kl = kl.sum(1).mean()
mean_entropy = (-logprobs).sum(1).mean()
mean_non_score_reward = non_score_reward.sum(1).mean()
rlhf_reward = mean_non_score_reward + scores.mean()
eps = int(self.state.episode / (time.time() - start_time))
metrics = {}
metrics["eps"] = eps
metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
metrics["objective/non_score_reward"] = (
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
)
metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
metrics["episode"] = self.state.episode
self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
self.state.global_step += 1
self.log(metrics)
self.lr_scheduler.step()
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
empty_cache()
gc.collect()
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
self.generate_completions(sampling=True)
empty_cache()
del (
query_responses,
responses,
postprocessed_responses,
logprobs,
ref_logprobs,
values,
sequence_lengths,
contain_eos_token,
sequence_lengths_p1,
response_idxs,
padding_mask,
padding_mask_p1,
rewards,
actual_start,
actual_end,
advantages,
returns,
)
empty_cache()
# HF trainer specifics
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None, metrics=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def generate_completions(self, sampling: bool = False):
args = self.args
processing_class = self.processing_class
generation_config = GenerationConfig(
max_new_tokens=self.args.response_length,
temperature=(0.01 + 1e-7),
top_k=0.0,
top_p=1.0,
do_sample=True,
)
table = defaultdict(list)
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
for batch in self.eval_dataloader:
query = batch["input_ids"]
with torch.no_grad():
context_length = query.shape[1]
query_response, _ = batch_generation(
unwrapped_model.policy,
query,
query.shape[0],
processing_class.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
postprocessed_response = response
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
self.stop_token_id, processing_class.pad_token_id, response
)
table["query"].extend(
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
)
table["model response"].extend(
gather_object(processing_class.batch_decode(postprocessed_response))
)
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
_, score, _ = get_reward(
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
)
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
if sampling:
break
df = pd.DataFrame(table)
if self.accelerator.is_main_process:
if is_rich_available():
print_rich_table(df.iloc[0 : 0 + 5])
if "wandb" in args.report_to:
import wandb
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})
if "comet_ml" in args.report_to:
log_table_to_comet_experiment(
name="completions.csv",
table=df,
)
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):
if self.args.hub_model_id is None:
model_name = Path(self.args.output_dir).name
else:
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str` or `None`, *optional*, defaults to `None`):
Name of the model.
dataset_name (`str` or `None`, *optional*, defaults to `None`):
Name of the dataset used for training.
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
tags.update(self._tag_names)
citation = textwrap.dedent("""\
@article{mziegler2019fine-tuning,
title = {{Fine-Tuning Language Models from Human Preferences}},
author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
year = 2019,
eprint = {arXiv:1909.08593}
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="PPO",
trainer_citation=citation,
paper_title="Fine-Tuning Language Models from Human Preferences",
paper_id="1909.08593",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))
The trainer's main pieces, explored below:
- `__init__` - Initialization and setup: sets up the trainer with models, datasets, and configuration; handles PEFT integration, batch size calculations, and accelerator setup.
- `train` - Main training loop: the core PPO training algorithm implementation.
- `generate_completions` - Generates sample completions for evaluating and monitoring training progress.
- `null_ref_context` - Context manager for handling the reference model when using PEFT adapters.
Purpose: This class wraps the policy model (the actor, which generates text) and the value model (the critic, which estimates the quality of the generated text) into a single `nn.Module`. In TRL the immediate payoff is practical: bundling them means a single `model = accelerator.prepare(model)` call covers both networks. Its `forward` runs the value model's backbone to produce per-token value estimates and runs the policy separately, returning both outputs. When the actor and critic share the same base architecture, the same idea can be pushed further to share the computationally expensive transformer backbone between them, which significantly speeds up training and reduces memory usage; that variant is sketched below.
class PolicyAndValueWrapper(nn.Module):
def __init__(self, policy, value_model):
super().__init__()
self.policy = policy
self.value_model = value_model
# The critic backbone is implicitly shared if policy and value models are the same base.
# This design assumes a shared backbone for efficiency.
def forward(self, **kwargs):
# Request hidden states from the model's forward pass.
kwargs['output_hidden_states'] = True
# Run the policy model ONCE to get both logits and hidden states.
# This is efficient because the expensive backbone computation is not repeated.
policy_output = self.policy(**kwargs)
# The `hidden_states` is a tuple of all layer outputs. We take the last one.
last_hidden_state = policy_output.hidden_states[-1]
# Pass the shared hidden states to the value model's scoring head.
value_estimates = self.value_model.score(last_hidden_state)
return policy_output, value_estimates
1. `base_model_prefix`: Every Hugging Face model class records the attribute name under which its transformer backbone lives. For `GPT2LMHeadModel`, it's "transformer". For `BertForSequenceClassification`, it's "bert". For `T5ForConditionalGeneration`, it's "transformer" as well.
2. `getattr(object, 'attribute_name')`: `getattr(x, 'y')` is the same as writing `x.y`. Instead of hardcoding `value_model.transformer`, the wrapper uses the `base_model_prefix` to dynamically fetch the correct backbone, making it work for various model architectures (like BERT, T5, etc.).
3. `self.critic_backbone`: the backbone module fetched this way from the value model. The wrapper calls it directly and feeds its last hidden state into the value head (`value_model.score`) to get per-token value estimates.
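A quick sanity check of the `base_model_prefix`/`getattr` pattern, assuming `transformers` is installed (the printed names are what I would expect for GPT-2):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
# base_model_prefix names the attribute that holds the transformer backbone
print(model.base_model_prefix)                      # -> "transformer"
backbone = getattr(model, model.base_model_prefix)  # equivalent to model.transformer
print(type(backbone).__name__)                      # -> "GPT2Model"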
Let's simulate a setup for Reinforcement Learning from Human Feedback (RLHF).
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
# --- Policy Model (Actor) ---
# This is a standard causal LM that generates text.
policy_model = AutoModelForCausalLM.from_pretrained("gpt2")
print(f"Policy model is: {type(policy_model)}")
# A `base_model_prefix` of 'transformer' means its backbone is at `policy_model.transformer`
print(f"Policy model's base_model_prefix: '{policy_model.base_model_prefix}'")
# --- Value Model (Critic) ---
# We create a separate model that will learn to predict a scalar "value" or "score".
# It shares the same core architecture but will have a different head.
value_model = AutoModelForCausalLM.from_pretrained("gpt2")
# Define a custom "value head" that predicts a single number from the hidden states.
class ValueHead(nn.Module):
def __init__(self, hidden_size):
super().__init__()
# A simple linear layer to map hidden state to a scalar value
self.value_head = nn.Linear(hidden_size, 1, bias=False)
def forward(self, hidden_states):
# hidden_states shape: (batch_size, seq_len, hidden_size)
# We return a value for each token in the sequence.
return self.value_head(hidden_states)
# Attach our custom head to the value model as the `.score` attribute.
value_model.score = ValueHead(value_model.config.hidden_size)
# We must also tell the wrapper where to find the backbone.
value_model.base_model_prefix = "transformer"
# NOTE: TRL's own wrapper runs the critic backbone and the policy as two separate forward passes.
# The variant below instead reuses the policy's hidden states for the value head (and makes sure
# `output_hidden_states=True` is set), so the expensive backbone runs only once when both models share a base.
class PolicyAndValueWrapper(nn.Module):
def __init__(self, policy, value_model):
super().__init__()
self.policy = policy
self.value_model = value_model
# The critic backbone is implicitly shared if policy and value models are the same base.
# This design assumes a shared backbone for efficiency.
def forward(self, **kwargs):
# Request hidden states from the model's forward pass.
kwargs['output_hidden_states'] = True
# Run the policy model ONCE to get both logits and hidden states.
# This is efficient because the expensive backbone computation is not repeated.
policy_output = self.policy(**kwargs)
# The `hidden_states` is a tuple of all layer outputs. We take the last one.
last_hidden_state = policy_output.hidden_states[-1]
# Pass the shared hidden states to the value model's scoring head.
value_estimates = self.value_model.score(last_hidden_state)
return policy_output, value_estimates
# Instantiate the wrapper
model_wrapper = PolicyAndValueWrapper(policy_model, value_model)
# --- Numerical Example ---
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt")
# inputs['input_ids'] has shape (1, 5): one GPT-2 token ID per word of the prompt
# Perform a forward pass
with torch.no_grad():
# The wrapper takes the same inputs as a standard Hugging Face model
policy_output, value_estimates = model_wrapper(**inputs)
print("\n--- Outputs ---")
# 1. Policy Output (from the full policy model)
print(f"Policy Logits Shape: {policy_output.logits.shape}")
# -> torch.Size([1, 5, 50257]) (batch_size, sequence_length, vocab_size)
# 2. Value Estimates (from shared backbone + score head)
print(f"Value Estimates Shape: {value_estimates.shape}")
# -> torch.Size([1, 5, 1]) (batch_size, sequence_length, 1)
print(f"Value for each token:\n{value_estimates.squeeze()}")
# -> tensor([0.1521, 0.1833, 0.1685, 0.1587, 0.1764]) (Example values)
In PPO, for every step, you need both the action probabilities (from the policy) and the state value (from the critic).
The naive, inefficient way:
policy_output = policy_model(inputs)   # full transformer pass
value_output = value_model(inputs)     # another full transformer pass
The efficient wrapper way: run the backbone once and feed its hidden states to both the language-model head and the value head.
The `PolicyAndValueWrapper` is the natural place for this optimization pattern. TRL's version bundles the two models so that one `accelerator.prepare(model)` call covers both, and the variant demonstrated above goes a step further by reusing the policy's hidden states for the value head, so the shared backbone runs only once, which is critical for performant training of large language models with actor-critic methods.
Purpose: This section of the code integrates Hugging Face's `peft` library, allowing users to fine-tune massive language models using a fraction of the memory and computational power. Instead of training all the billions of parameters in a model, PEFT techniques like LoRA (Low-Rank Adaptation) freeze the original model and inject small, trainable "adapter" layers. This makes fine-tuning accessible on consumer hardware.
# 1. Check if PEFT library is installed if a config is provided
if not is_peft_available() and peft_config is not None:
raise ImportError(...)
# 2. Main PEFT logic block
elif is_peft_available() and peft_config is not None:
# If the model is already a PEFT model, merge the old adapters first
# This gives a clean slate before applying the new config.
if isinstance(self.policy_model, PeftModel):
self.policy_model = self.policy_model.merge_and_unload()
# KEY LINE: Apply the new PEFT config (e.g., LoRA) to the base model
self.policy_model = get_peft_model(self.policy_model, peft_config)
# Compatibility fix for 4-bit models trained in bfloat16
if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(self.policy_model)
# 3. Set flags for later use in the trainer
self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
self.model_adapter_name = args.model_adapter_name
self.ref_adapter_name = args.ref_adapter_name
merge_and_unload(): This is for edge cases. If you pass a model that has *already* been modified with PEFT, this function merges the existing adapter's weights into the base model and removes the adapter layers, effectively "baking in" the old changes to create a standard model again.
get_peft_model(): This is the core of PEFT. It takes the original, frozen transformer model and injects the small, trainable adapter layers (e.g., LoRA layers) into the places specified in the `peft_config`.
Let's see the dramatic impact on the number of trainable parameters.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, PeftModel
# Load a standard GPT-2 model
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
# --- Before PEFT ---
total_params = sum(p.numel() for p in base_model.parameters())
trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
print(f"--- Base Model ---")
print(f"Total Parameters: {total_params / 1e6:.2f}M")
print(f"Trainable Parameters: {trainable_params / 1e6:.2f}M (100%)")
# Output:
# --- Base Model ---
# Total Parameters: 124.44M
# Trainable Parameters: 124.44M (100%)
# Define the LoRA configuration
lora_config = LoraConfig(
r=16, # Rank of the update matrices. Lower = fewer parameters.
lora_alpha=32, # A scaling factor.
target_modules=["c_attn"], # Target only the attention query, key, value projections.
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# Apply the LoRA config to the base model
peft_model = get_peft_model(base_model, lora_config)
# --- After PEFT ---
peft_total_params = sum(p.numel() for p in peft_model.parameters())
peft_trainable_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
print(f"\n--- PEFT Model (LoRA Applied) ---")
print(f"Total Parameters: {peft_total_params / 1e6:.2f}M")
print(f"Trainable Parameters: {peft_trainable_params / 1e6:.2f}M")
print(f"Trainable %: {peft_trainable_params / peft_total_params * 100:.4f}%")
print("\nModel structure with LoRA layers:")
peft_model.print_trainable_parameters()
# Output:
# --- PEFT Model (LoRA Applied) ---
# Total Parameters: 125.18M
# Trainable Parameters: 0.79M
# Trainable %: 0.6291%
#
# Model structure with LoRA layers:
# trainable params: 786,432 || all params: 125,178,240 || trainable%: 0.6282
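For completeness, here is a minimal sketch of the merge_and_unload() path described above, assuming `peft` and `transformers` are installed; it only shows that merging hands back a plain transformers model:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, PeftModel

peft_model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("gpt2"),
    LoraConfig(r=8, target_modules=["c_attn"], task_type="CAUSAL_LM"),
)
print(isinstance(peft_model, PeftModel))   # -> True

# Fold the (here untrained) adapter weights into the base weights and drop the adapter layers.
merged = peft_model.merge_and_unload()
print(isinstance(merged, PeftModel))       # -> False: a standard GPT2LMHeadModel again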
In PPO, you need a `policy_model` (which you are training) and a `ref_model` (a frozen reference to calculate KL divergence against). Without PEFT, you would need to load two full models into memory.
With PEFT, you only need one base model in memory!
The `PPOTrainer`'s `null_ref_context` manager handles this adapter-switching automatically. This dramatically reduces memory requirements, making RLHF accessible to many more users.
The PEFT support block is a powerful feature: instead of training the entire model, it freezes the base weights and uses `get_peft_model` to inject small, efficient adapters, resulting in a model where over 99% of the parameters are frozen, drastically reducing the memory and compute needed for fine-tuning while still achieving strong performance.
Purpose: This section of the `__init__` method decides how to create the `ref_model` (reference model). The reference model is a crucial component in PPO training for RLHF. It's a frozen version of the original language model used to calculate a KL-divergence penalty, which prevents the policy model from deviating too much from sensible language and improves training stability.
# If a reference model is explicitly passed by the user, use it.
if ref_model:
self.ref_model = ref_model
# KEY LOGIC: If using a PEFT model, we DON'T need a separate reference model in memory.
# Setting it to `None` signals the trainer to use a special adapter-toggling strategy.
elif self.is_peft_model:
self.ref_model = None
# Otherwise (not using PEFT and no ref_model passed), create a full, memory-intensive copy.
else:
self.ref_model = create_reference_model(self.policy_model)
It enables a massive memory-saving strategy. Instead of loading two multi-billion parameter models into memory (one for the policy, one for reference), we load only one. This single model plays both roles: with its adapters enabled it acts as the trainable policy, and with its adapters disabled it behaves exactly like the frozen base model and serves as the reference.
This switching is handled automatically later in the trainer by the `null_ref_context` context manager, making the process seamless.
Let's prove that we can get both policy and reference outputs from a single PEFT model object, demonstrating why a second model copy is unnecessary.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
import torch
# Load the original, base gpt2 model
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Create a PEFT model by applying LoRA adapters. This is our `policy_model`.
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], task_type="CAUSAL_LM", init_lora_weights=False)  # init_lora_weights=False gives random (non-zero) adapter weights so their effect on the logits is visible in this demo
policy_model = get_peft_model(base_model, lora_config)
# Dummy input for demonstration
inputs = tokenizer("The goal of PPO is to", return_tensors="pt")
print("--- 1. Acting as POLICY MODEL (Adapters ON by default) ---")
# After `get_peft_model`, the LoRA adapters are active by default.
# There is no `enable_adapter` method; they are enabled unless in a `disable_adapter` context.
with torch.no_grad():
policy_logits = policy_model(**inputs).logits
# The output is influenced by the small, trainable LoRA layers.
print(f"Sample policy logit for the last token: {policy_logits[0, -1, 200].item():.4f}")
# Example output might be something like: 1.9873
print("\n--- 2. Acting as REFERENCE MODEL (Adapters OFF) ---")
# We temporarily disable the adapters using a context manager.
# The model now behaves exactly like the original base model.
with policy_model.disable_adapter():
with torch.no_grad():
ref_logits_from_peft_model = policy_model(**inputs).logits
print(f"Sample ref logit (from policy model): {ref_logits_from_peft_model[0, -1, 200].item():.4f}")
# Example output: Sample ref logit (from policy model): 2.3145
print("\n--- 3. Verifying with ORIGINAL BASE MODEL ---")
# For proof, let's run the original base model that never saw the adapters.
with torch.no_grad():
original_base_logits = base_model(**inputs).logits
print(f"Sample logit from original base model: {original_base_logits[0, -1, 200].item():.4f}")
# Example output: Sample logit from original base model: 2.3145
# --- Verification ---
are_they_equal = torch.allclose(ref_logits_from_peft_model, original_base_logits)
print(f"\nAre the reference logits and original logits identical? -> {are_they_equal}")
# Output: Are the reference logits and original logits identical? -> True
The example proves it: by simply disabling the adapters, the `policy_model` produces the exact same output as the original `base_model`. We successfully simulated having a reference model without ever creating a second copy.
Setting `self.ref_model = None` is the key that unlocks this massive ~50% memory saving, making large-scale RLHF dramatically more accessible.
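Rough, back-of-the-envelope arithmetic behind that saving (a sketch with assumed numbers: a 7B-parameter policy stored in bf16, i.e. 2 bytes per parameter; adapter size depends on the LoRA config):

params = 7e9                 # assumed policy size: 7B parameters
bytes_per_param = 2          # bf16 weights
one_copy_gb = params * bytes_per_param / 1e9

print(f"policy + separate reference model: ~{2 * one_copy_gb:.0f} GB of weights")
print(f"single PEFT model reused for both: ~{one_copy_gb:.0f} GB + a few hundred MB of adapters")
# -> roughly 28 GB vs 14 GB, before optimizer states, gradients, and activations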
Purpose: This context manager is the mechanism that brings the memory-saving strategy (discussed in Q3) to life. Its job is to temporarily make the policy model behave like the reference model *just for the moment when the reference logits are needed*. It intelligently handles two main PEFT scenarios: single-adapter training and multi-adapter training.
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
# This is the main scenario: using PEFT with a single adapter.
# `disable_adapter()` is itself a context manager that turns adapters off inside the `with`
# block and automatically turns them back on upon exit.
with (
self.accelerator.unwrap_model(self.model.policy).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
# This handles the advanced scenario: using two different adapters.
# It activates the specified reference adapter upon entering the `with` block.
if self.ref_adapter_name:
self.model.policy.set_adapter(self.ref_adapter_name)
# This is where the code inside the `with` block runs (e.g., the forward pass).
yield
# After the code runs, switch back to the main policy adapter.
if self.ref_adapter_name:
self.model.policy.set_adapter(self.model_adapter_name or "default")
The context manager has to cover three situations:
- Single-adapter PEFT training: `is_peft_model` is `True` and `ref_adapter_name` is `None`. The policy's adapter is temporarily disabled so the frozen base model acts as the reference.
- Two-adapter PEFT training: `is_peft_model` is `True` and `ref_adapter_name` is provided (e.g., 'my-ref-adapter'). The context switches to the reference adapter and switches back afterwards.
- No PEFT: `is_peft_model` is `False`. The `nullcontext()` branch does nothing, because a separate `ref_model` is used instead.
Let's create a PEFT model with two distinct adapters and a dummy trainer to see how the context manager works in both PEFT scenarios.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import torch
from contextlib import contextmanager, nullcontext
# --- Setup ---
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Create a PEFT model and add the first adapter, which is named "default" automatically
peft_model = get_peft_model(base_model, LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], task_type=TaskType.CAUSAL_LM, init_lora_weights=False))  # init_lora_weights=False: random (non-zero) adapter weights so the two adapters produce visibly different logits
# Now, add a SECOND, different adapter and name it "ref_adapter"
peft_model.add_adapter("ref_adapter", LoraConfig(r=4, lora_alpha=8, target_modules=["c_attn"], task_type=TaskType.CAUSAL_LM, init_lora_weights=False))
# --- Dummy Trainer Class to hold state and the context manager logic ---
class DummyTrainer:
def __init__(self, model, is_peft, model_adapter, ref_adapter):
self.model = type("obj", (object,), {"policy": model})()
self.is_peft_model = is_peft
self.model_adapter_name = model_adapter
self.ref_adapter_name = ref_adapter
# This dummy accelerator has an `unwrap_model` method that mimics the real one by returning the model passed to it.
# The lambda accepts `s` (the dummy object's self) and `m` (the model) and simply returns the model.
self.accelerator = type("obj", (object,), {"unwrap_model": lambda s, m: m})()
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with (
self.accelerator.unwrap_model(self.model.policy).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.policy.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.policy.set_adapter(self.model_adapter_name or "default")
inputs = tokenizer("To be or not to be", return_tensors="pt")
print("--- SCENARIO A: Switching between 'default' and 'ref_adapter' ---")
trainer_with_named_adapters = DummyTrainer(model=peft_model, is_peft=True, model_adapter="default", ref_adapter="ref_adapter")
# Set the initial active adapter to the policy ('default')
peft_model.set_adapter("default")
print(f"Adapter before context: '{peft_model.active_adapter}'")
with torch.no_grad(): policy_logits = peft_model(**inputs).logits
# Use the context manager to switch to the reference adapter
with trainer_with_named_adapters.null_ref_context():
print(f"Adapter INSIDE context: '{peft_model.active_adapter}'")
with torch.no_grad(): ref_logits = peft_model(**inputs).logits
print(f"Adapter AFTER context: '{peft_model.active_adapter}'")
# --- Verification ---
are_they_equal = torch.allclose(policy_logits, ref_logits, atol=1e-4)
print(f"\nAre policy and ref logits the same? -> {are_they_equal}")
print(f"Sample 'default' adapter logit: {policy_logits[0, -1, 100].item():.4f}")
print(f"Sample 'ref_adapter' logit: {ref_logits[0, -1, 100].item():.4f}")
# --- SCENARIO A: Switching between 'default' and 'ref_adapter' ---
# Adapter before context: 'default'
# Adapter INSIDE context: 'ref_adapter'
# Adapter AFTER context: 'default'
#
# Are policy and ref logits the same? -> False
# Sample 'default' adapter logit: -5.7001
# Sample 'ref_adapter' logit: -5.7029
print("\n--- SCENARIO B: Disabling the 'default' adapter to get base model output ---")
trainer_with_one_adapter = DummyTrainer(model=peft_model, is_peft=True, model_adapter="default", ref_adapter=None)
# Activate the policy adapter
peft_model.set_adapter("default")
print(f"Adapter before context: '{peft_model.active_adapter}'")
with torch.no_grad(): policy_logits_2 = peft_model(**inputs).logits
# Use context manager, which will now disable adapters instead of switching
with trainer_with_one_adapter.null_ref_context():
print(f"Active adapters INSIDE context: {peft_model.active_adapters}")
with torch.no_grad(): ref_logits_2 = peft_model(**inputs).logits
print(f"Adapter AFTER context: '{peft_model.active_adapter}'")
# Get logits from a FRESH copy of the base model for comparison (`get_peft_model` wrapped
# the original `base_model` object in place, so it now carries the adapters).
fresh_base = AutoModelForCausalLM.from_pretrained("gpt2")
with torch.no_grad(): base_model_logits = fresh_base(**inputs).logits
# --- Verification ---
are_they_equal_2 = torch.allclose(ref_logits_2, base_model_logits)
print(f"\nAre disabled-adapter and base-model logits the same? -> {are_they_equal_2}")
print(f"Sample policy logit: {policy_logits_2[0,-1,100].item():.4f}")
print(f"Sample disabled-adapter logit: {ref_logits_2[0,-1,100].item():.4f}")
print(f"Sample base-model logit: {base_model_logits[0,-1,100].item():.4f}")
# --- SCENARIO B: Disabling the 'default' adapter to get base model output ---
# Adapter before context: 'default'
# Adapter layers INSIDE context: disabled (the model now behaves like the base model)
# Adapter AFTER context: 'default'
#
# Are disabled-adapter and base-model logits the same? -> True
# Sample policy logit: -5.7001
# Sample disabled-adapter logit: -5.6983
# Sample base-model logit: -5.6983
The `null_ref_context` method is a powerful utility that makes PPO training with PEFT both flexible and efficient. It correctly handles the two primary ways you might use adapters for the reference model: either by disabling the policy adapter to fall back to the base model, or by switching to a completely different adapter designated for reference calculations. This all happens automatically, ensuring the right model state is used at the right time.
Purpose: This entire block of code doesn't do any training itself. Instead, it's the critical setup phase. It prepares all the necessary variables, configurations, data loaders, and tracking systems needed before the main PPO training loop can begin. It's like a pre-flight checklist for the training process.
def repeat_generator():
while True:
yield from dataloader
iter_dataloader = iter(repeat_generator())
What it does: Unlike traditional training that goes through a dataset epoch by epoch, PPO training runs for a fixed number of "updates" or "episodes". It constantly needs fresh batches of data. This code creates an infinite data generator. The `while True` loop ensures that whenever the `dataloader` runs out of data, it just starts over from the beginning. This way, the training loop can simply call `next(iter_dataloader)` forever without ever getting a "StopIteration" error.
Example: Imagine your `dataloader` has just two batches: `["prompt A", "prompt B"]` and `["prompt C", "prompt D"]`. The `iter_dataloader` would yield batch 1, batch 2, batch 1, batch 2, and so on forever.
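A tiny self-contained version of that pattern (a plain Python list stands in for the real dataloader):

batches = [["prompt A", "prompt B"], ["prompt C", "prompt D"]]

def repeat_generator():
    while True:
        yield from batches

it = iter(repeat_generator())
print([next(it) for _ in range(5)])
# -> [['prompt A', 'prompt B'], ['prompt C', 'prompt D'], ['prompt A', 'prompt B'],
#     ['prompt C', 'prompt D'], ['prompt A', 'prompt B']]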
generation_config = GenerationConfig(
max_new_tokens=args.response_length, # e.g., 50
temperature=(args.temperature + 1e-7), # e.g., 0.7
top_k=0.0,
top_p=1.0,
do_sample=True,
)
What it does: This object tells the language model exactly *how* it should generate text.
max_new_tokens: The maximum length of the response to generate.
do_sample=True: This tells the model to sample instead of being "greedy". A greedy model always picks the single word with the highest probability, which can be repetitive; sampling picks from a distribution of possible words.
temperature: Controls the randomness of the sampling. A high temperature (e.g., > 1.0) makes the model's choices more random and creative (and more likely to make mistakes). A low temperature (e.g., 0.7) makes the output safer and more focused. As the temperature approaches 0, sampling approaches greedy decoding.
top_k and top_p: Other ways to control sampling; setting them to 0.0 and 1.0 respectively effectively disables them in favor of temperature-based sampling.
Numerical Example (Temperature): Imagine the model has to choose the next word and its top 3 choices have these raw scores (logits): `[2.0, 1.5, 0.5]`.
import torch
import torch.nn.functional as F
logits = torch.tensor([2.0, 1.5, 0.5])
# Low temperature (more confident, less random)
probs_low_temp = F.softmax(logits / 0.5, dim=-1)
# -> tensor([0.7054, 0.2595, 0.0351]) - Strongly favors the first word.
# High temperature (less confident, more random)
probs_high_temp = F.softmax(logits / 1.5, dim=-1)
# -> tensor([0.4797, 0.3437, 0.1765]) - Might pick any of the top 3.
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
approxkl_stats = torch.zeros(stats_shape, device=device)
pg_loss_stats = torch.zeros(stats_shape, device=device)
vf_loss_stats = torch.zeros(stats_shape, device=device)
# ... and others
What it does: The trainer needs to keep track of many important metrics to see how well it's learning. This code creates empty "storage containers" (tensors full of zeros) to hold these statistics. Each container's size (`stats_shape`) is designed to hold a value for every single optimization step within a PPO update.
approxkl_stats: Stores the approximate KL divergence, a measure of how much the policy is changing.
pg_loss_stats: Stores the policy gradient loss (the "actor's" loss).
vf_loss_stats: Stores the value function loss (the "critic's" loss).
self.state.global_step = 0
self.state.episode = 0
self.state.max_steps = args.num_total_batches
# ...
if args.save_steps < 1:
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
What it does: This initializes the `TrainerState` object, which is like the main scoreboard for the entire training run.
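The fractional logging/eval/save settings in the snippet above are resolved into absolute step counts at this point; a one-line check with made-up numbers:

import math

max_steps = 400        # stands in for args.num_total_batches
save_steps = 0.25      # given as a ratio of the total run
resolved = math.ceil(max_steps * save_steps) if save_steps < 1 else save_steps
print(resolved)        # -> 100: checkpoint every 100 updates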
on_train_begin(...): This is a call to any special functions (callbacks) that need to run right before training starts, like setting up a connection to a logging service like Weights & Biases.
This whole section is the essential "boot-up" sequence for the trainer. It ensures that data will always be available, the model knows how to generate text, empty containers are ready to record performance, and the main training scoreboard is initialized and ready to go.
Purpose: This section implements the core PPO algorithm - the heart of the training process. PPO (Proximal Policy Optimization) is a policy gradient method that learns to improve a language model's responses by maximizing expected rewards while preventing the policy from changing too drastically.
The fundamental goal is to maximize the expected reward:
$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]$$
Where:
- $\tau$ is a trajectory: a prompt together with the response the policy generates for it,
- $\pi_\theta$ is the policy (the language model with parameters $\theta$),
- $R(\tau)$ is the total reward assigned to that trajectory.
The policy gradient is:
$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[\sum_{t=0}^T \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot A_t]$$
Where $A_t$ is the advantage function (how much better action $a_t$ is compared to average).
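In this trainer, $A_t$ is estimated with Generalized Advantage Estimation (GAE): the backward recursion $A_t = \delta_t + \gamma\lambda A_{t+1}$ with $\delta_t = r_t + \gamma V_{t+1} - V_t$, the same loop as step 6 of `train()` above. A standalone sketch with toy numbers (the reward, value, and coefficient values are made up):

import torch

gamma, lam = 1.0, 0.95                         # toy discount / GAE parameters
rewards = torch.tensor([[0.0, 0.0, 1.0]])      # toy per-token rewards
values  = torch.tensor([[0.2, 0.4, 0.6]])      # toy value estimates from the critic

lastgaelam = 0.0
advantages_reversed = []
gen_length = rewards.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]   # TD error delta_t
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values                  # regression targets for the value head
print(advantages)  # -> tensor([[0.7510, 0.5800, 0.4000]])
print(returns)     # -> tensor([[0.9510, 0.9800, 1.0000]])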
PPO uses importance sampling to reuse data from an old policy $\pi_{\theta_{old}}$ to update a new policy $\pi_\theta$:
$$r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$$
The surrogate objective becomes:
$$L^{CPI}(\theta) = \mathbb{E}_t[r_t(\theta) \cdot A_t]$$
To prevent large policy updates, PPO clips the ratio:
$$L^{CLIP}(\theta) = \mathbb{E}_t[\min(r_t(\theta) \cdot A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot A_t)]$$
Where $\epsilon$ is the clipping parameter (typically 0.2).
# Generate responses using current policy
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
query_responses, logitss = batch_generation(
unwrapped_model.policy,
queries,
args.local_rollout_forward_batch_size,
processing_class.pad_token_id,
generation_config,
)
# Extract just the response part (excluding the input query)
response = query_response[:, context_length:]
# Compute log probabilities for the generated tokens
logprob = selective_log_softmax(logits, response)
Mathematical Meaning: This computes $\log \pi_\theta(a_t|s_t)$ for each token in the response. The model generates text and we calculate how likely each generated token was according to the current policy.
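For reference, a plain-PyTorch equivalent of what `selective_log_softmax` computes here (the TRL helper is just a more memory-frugal way to get the same numbers, as far as the math goes):

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 4, 10
logits = torch.randn(batch, seq_len, vocab)            # per-position scores over the vocabulary
response = torch.randint(0, vocab, (batch, seq_len))   # the token IDs that were actually generated

# log pi_theta(a_t | s_t): log-softmax over the vocab, then pick each generated token's entry
per_token_logprob = torch.gather(
    F.log_softmax(logits, dim=-1), dim=-1, index=response.unsqueeze(-1)
).squeeze(-1)
print(per_token_logprob.shape)   # -> torch.Size([2, 4]): one log-probability per generated token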
# Get the frozen reference policy's log probabilities (π_ref), used for the KL penalty
if ref_policy is None:
# PEFT case: temporarily disable adapters to get base model behavior
with self.null_ref_context():
ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
else:
# Separate reference model case
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_logprob = selective_log_softmax(ref_logits, response)
Mathematical Meaning: This computes $\log \pi_{\text{ref}}(a_t|s_t)$ - the frozen reference policy's log probabilities for the same tokens. In the trainer these feed the KL penalty that is folded into the reward; the importance sampling ratio $r_t(\theta)$ is instead computed against the rollout-time log probabilities (`logprobs`) saved during generation, which play the role of $\pi_{\theta_{old}}$.
import torch
import torch.nn.functional as F
import numpy as np
# Simulated scenario: Model completing "The weather is"
query = "The weather is"
response_tokens = ["sunny", "and", "warm"] # Generated response
vocab_size = 50257 # GPT-2 vocab size
# Simulate token IDs
sunny_id, and_id, warm_id = 19989, 290, 5814
# Example logits from current policy (higher = more likely)
policy_logits = torch.tensor([
[2.1, 1.8, 0.9], # logits for ["sunny", "and", "warm"]
])
# Example logits from the reference policy (chosen so the two distributions
# genuinely differ, not just a uniform shift of the policy logits)
ref_logits = torch.tensor([
    [2.3, 1.6, 0.7], # reference logits for same tokens
])
# Convert to probabilities and then log probabilities
policy_probs = F.softmax(policy_logits, dim=-1)
ref_probs = F.softmax(ref_logits, dim=-1)
policy_log_probs = torch.log(policy_probs)
ref_log_probs = torch.log(ref_probs)
print("Policy Probabilities:", policy_probs)
print("Reference Probabilities:", ref_probs)
print("Policy Log Probs:", policy_log_probs)
print("Reference Log Probs:", ref_log_probs)
# Output:
# Policy Probabilities: tensor([[0.4897, 0.3628, 0.1475]])
# Reference Probabilities: tensor([[0.5888, 0.2924, 0.1189]])
# Policy Log Probs: tensor([[-0.7139, -1.0139, -1.9139]])
# Reference Log Probs: tensor([[-0.5297, -1.2297, -2.1297]])
# Calculate importance sampling ratios: r_t = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)
# In log space: log(r_t) = log π_θ(a_t|s_t) - log π_θ_old(a_t|s_t)
# (In this toy demo the reference logprobs stand in for the old policy's logprobs;
# in the actual trainer the ratio uses the logprobs saved at rollout time.)
log_ratios = policy_log_probs - ref_log_probs
ratios = torch.exp(log_ratios)
print("Log Ratios:", log_ratios)
print("Ratios:", ratios)
# Example advantage values (how good each token choice was)
advantages = torch.tensor([[0.5, -0.2, 0.8]]) # Positive = good, negative = bad
# Unclipped surrogate loss: L^CPI = r_t * A_t
unclipped_loss = ratios * advantages
print("Unclipped Surrogate Loss:", unclipped_loss)
# PPO clipped loss with ε = 0.2
epsilon = 0.2
clipped_ratios = torch.clamp(ratios, 1 - epsilon, 1 + epsilon)
clipped_loss = clipped_ratios * advantages
print("Clipped Ratios:", clipped_ratios)
print("Clipped Surrogate Loss:", clipped_loss)
# Final PPO objective: min(unclipped, clipped); the loss used for gradient descent is the negative mean of this
ppo_loss = torch.min(unclipped_loss, clipped_loss)
print("Final PPO Loss:", ppo_loss)
# Output:
# Log Ratios: tensor([[-0.1842,  0.2158,  0.2158]])
# Ratios: tensor([[0.8318, 1.2409, 1.2409]])
# Unclipped Surrogate Loss: tensor([[ 0.4159, -0.2482,  0.9927]])
# Clipped Ratios: tensor([[0.8318, 1.2000, 1.2000]])
# Clipped Surrogate Loss: tensor([[ 0.4159, -0.2400,  0.9600]])
# Final PPO Loss: tensor([[ 0.4159, -0.2482,  0.9600]])
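The KL penalty that shapes the reward (step 4 in `train()`) can be reproduced with the same toy numbers. This follows the k1/k3 formulas from the source above; the `kl_coef` value is just illustrative:

import torch

logprobs = torch.tensor([[-0.7139, -1.0139, -1.9139]])       # current policy (toy values from above)
ref_logprobs = torch.tensor([[-0.5297, -1.2297, -2.1297]])   # frozen reference policy

logr = ref_logprobs - logprobs           # log(pi_ref / pi_theta), as in the trainer
kl_k1 = -logr                            # "k1": straightforward, unbiased estimator
kl_k3 = (logr.exp() - 1) - logr          # "k3": lower variance, also unbiased, always >= 0
kl_coef = 0.05                           # illustrative penalty coefficient (args.kl_coef)
non_score_reward = -kl_coef * kl_k1      # per-token penalty folded into the reward
print(kl_k1)             # per-token k1 values can be negative; only their expectation is the KL
print(kl_k3)
print(non_score_reward)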
The PPO algorithm is brilliant because it:
- reuses rollout data safely through the importance-sampling ratio,
- clips that ratio so no single update can push the policy far from the one that generated the data, and
- keeps the model anchored to a frozen reference policy via the KL penalty, preserving fluent language.
This code section implements the core mathematical foundation of PPO: generating responses, computing probability ratios between current and reference policies, and preparing the data needed for the clipped surrogate objective. The beauty lies in how it transforms abstract mathematical concepts into practical, working code that can train language models to be more helpful and aligned with human preferences.