🚀 PPO Trainer Code Analysis

Interactive exploration of the Proximal Policy Optimization trainer implementation

821 lines of code · 4 Q&A items · ~15 methods
trl/trl/trainer/ppo_trainer.py
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import os
import textwrap
import time
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    BaseImageProcessor,
    DataCollatorWithPadding,
    FeatureExtractionMixin,
    GenerationConfig,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    Trainer,
    TrainerCallback,
    TrainerControl,
    is_wandb_available,
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_peft_available, is_rich_available

from ..core import masked_mean, masked_whiten
from ..models import create_reference_model
from ..models.utils import unwrap_model_for_generation
from .ppo_config import PPOConfig
from .utils import (
    OnlineTrainerState,
    batch_generation,
    disable_dropout_in_model,
    empty_cache,
    exact_div,
    first_true_indices,
    forward,
    generate_model_card,
    get_comet_experiment_url,
    get_reward,
    log_table_to_comet_experiment,
    peft_module_casting_to_bf16,
    prepare_deepspeed,
    print_rich_table,
    selective_log_softmax,
    truncate_response,
)


if is_peft_available():
    from peft import PeftConfig, PeftModel, get_peft_model

if is_wandb_available():
    import wandb


INVALID_LOGPROB = 1.0


# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29
# we did this so we can do a single `model = accelerator.prepare(model)`
class PolicyAndValueWrapper(nn.Module):
    def __init__(self, policy, value_model) -> None:
        super().__init__()
        self.policy = policy
        self.value_model = value_model
        self.critic_backbone = getattr(value_model, value_model.base_model_prefix)

    def forward(self, **kwargs):
        output = self.critic_backbone(**kwargs)
        logits = self.value_model.score(output.hidden_states[-1])
        return self.policy(**kwargs), logits


class PPOTrainer(Trainer):
    _tag_names = ["trl", "ppo"]

    def __init__(
        self,
        args: PPOConfig,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ],
        model: nn.Module,
        ref_model: Optional[nn.Module],
        reward_model: nn.Module,
        train_dataset: Dataset,
        value_model: nn.Module,
        data_collator: Optional[DataCollatorWithPadding] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        # less commonly used
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: Optional[list[TrainerCallback]] = None,
        peft_config: Optional["PeftConfig"] = None,
    ) -> None:
        if ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you must make a copy of it, or `None` if you use peft."
            )

        self.args = args
        self.processing_class = processing_class
        self.policy_model = model

        # Define the collator if not provided
        if data_collator is None:
            data_collator = DataCollatorWithPadding(self.processing_class)

        # Handle stop token settings: update policy model's generation_config to use provided stop token
        if args.stop_token and args.stop_token_id:
            raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
        elif args.stop_token:
            if args.stop_token == "eos":
                self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
            else:
                raise ValueError(
                    f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
                )
        else:
            self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id  # None or int

        # Check that the kl estimator is valid
        if self.args.kl_estimator not in {"k1", "k3"}:
            raise ValueError(
                "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, "
                "appears to be a strictly better estimator). See "
                "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
            )

        # peft support
        if not is_peft_available() and peft_config is not None:
            raise ImportError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(self.policy_model, PeftModel):
                self.policy_model = self.policy_model.merge_and_unload()

            # get peft model with the given config
            self.policy_model = get_peft_model(self.policy_model, peft_config)
            if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(self.policy_model)

        self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
        self.model_adapter_name = args.model_adapter_name
        self.ref_adapter_name = args.ref_adapter_name

        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model:
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(self.policy_model)

        self.reward_model = reward_model
        self.train_dataset = train_dataset
        self.train_dataset_len = len(train_dataset)
        self.value_model = value_model
        self.data_collator = data_collator
        self.eval_dataset = eval_dataset
        self.optimizer, self.lr_scheduler = optimizers
        self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47

        #########
        # calculate various batch sizes
        #########
        if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
            args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
        accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
        self.accelerator = accelerator
        args.world_size = accelerator.num_processes
        args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
        args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
        args.batch_size = int(args.local_batch_size * args.world_size)
        args.mini_batch_size = exact_div(
            args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
        )
        args.local_mini_batch_size = exact_div(
            args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
        )
        if args.whiten_rewards:
            assert args.local_mini_batch_size >= 8, (
                f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
            )
        # `per_rank_rollout_batch_size` is our `args.local_batch_size`
        # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
        args.num_total_batches = math.ceil(
            args.total_episodes / args.batch_size
        )  # we may train for more than `total_episodes`
        time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
        time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
        args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
        self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
        if args.num_sample_generations > 0:
            self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
        self.local_dataloader_batch_size = args.local_batch_size

        #########
        # setup model, optimizer, and others
        #########
        for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
            if module is not None:
                disable_dropout_in_model(module)
        self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
        self.model.config = self.policy_model.config  # needed for pushing to hub
        self.create_optimizer_and_scheduler(
            num_training_steps=args.num_total_batches
        )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level

        #########
        ### trainer specifics
        #########
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
        self.control = TrainerControl()
        self.state = OnlineTrainerState(
            is_local_process_zero=self.is_local_process_zero(),
            is_world_process_zero=self.is_world_process_zero(),
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ],
        )
        self.current_flos = 0
        self.hp_search_backend = None
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
        # Create distant repo and output directory if needed
        self.hub_model_id = None
        if self.args.push_to_hub:
            self.init_hf_repo()
        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        #########
        ### setup dataloader
        #########
        self.dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.local_dataloader_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
            drop_last=True,  # needed; otherwise the last batch will be of ragged shape
        )
        # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
        # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
        torch.manual_seed(args.seed)
        self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
        torch.manual_seed(self.local_seed)  # reset the local seed again

        self.eval_dataloader = DataLoader(
            self.eval_dataset,
            batch_size=args.per_device_eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=True,
        )  # no need to shuffle eval dataset
        self.eval_dataloader = accelerator.prepare(self.eval_dataloader)

        if self.is_deepspeed_enabled:
            self.reward_model = prepare_deepspeed(
                self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
            )

            if self.ref_model is None:
                if not self.is_peft_model:
                    raise ValueError("No reference model and model is not a Peft model.")
            else:
                self.ref_model = prepare_deepspeed(
                    self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
                )
        else:
            if self.ref_model is None:
                if not self.is_peft_model:
                    raise ValueError("No reference model and model is not a Peft model.")
            else:
                self.ref_model = self.ref_model.to(self.accelerator.device)
            self.reward_model = self.reward_model.to(self.accelerator.device)

    def get_train_dataloader(self) -> DataLoader:
        return self.dataloader

    def get_eval_dataloader(self) -> DataLoader:
        return self.eval_dataloader

    @contextmanager
    def null_ref_context(self):
        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
        with (
            self.accelerator.unwrap_model(self.model.policy).disable_adapter()
            if self.is_peft_model and not self.ref_adapter_name
            else nullcontext()
        ):
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.ref_adapter_name)
            yield
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.model_adapter_name or "default")

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        backup_model = self.model
        self.model = self.model.policy  # save only the policy

        if self.is_deepspeed_enabled:
            backup_deepspeed = self.deepspeed
            self.deepspeed = self.model

        super().save_model(output_dir, _internal_call)

        self.model = backup_model

        if self.is_deepspeed_enabled:
            self.deepspeed = backup_deepspeed

    def train(self):
        args = self.args
        accelerator = self.accelerator
        optimizer = self.optimizer
        model = self.model
        ref_policy = self.ref_model
        reward_model = self.reward_model
        processing_class = self.processing_class
        dataloader = self.dataloader
        device = accelerator.device

        def repeat_generator():
            while True:
                yield from dataloader

        iter_dataloader = iter(repeat_generator())
        generation_config = GenerationConfig(
            max_new_tokens=args.response_length,
            temperature=(args.temperature + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )

        accelerator.print("===training policy===")
        start_time = time.time()
        stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
        approxkl_stats = torch.zeros(stats_shape, device=device)
        pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
        pg_loss_stats = torch.zeros(stats_shape, device=device)
        vf_loss_stats = torch.zeros(stats_shape, device=device)
        vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
        entropy_stats = torch.zeros(stats_shape, device=device)
        ratio_stats = torch.zeros(stats_shape, device=device)
        model.train()

        # trainer state initialization
        self.state.global_step = 0
        self.state.episode = 0
        self.state.max_steps = args.num_total_batches
        self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
            if args.logging_steps < 1:
                self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
            else:
                self.state.logging_steps = args.logging_steps
        if args.eval_steps is not None:
            if args.eval_steps < 1:
                self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
            else:
                self.state.eval_steps = args.eval_steps
        if args.save_steps is not None:
            if args.save_steps < 1:
                self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
            else:
                self.state.save_steps = args.save_steps
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model
            self.model_wrapped = self.model

        for update in range(1, args.num_total_batches + 1):
            self.state.episode += 1 * args.batch_size
            data = next(iter_dataloader)
            with torch.no_grad():
                queries = data["input_ids"].to(device)
                context_length = queries.shape[1]
                responses = []
                postprocessed_responses = []
                logprobs = []
                ref_logprobs = []
                scores = []
                sequence_lengths = []
                values = []
                with unwrap_model_for_generation(
                    self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
                ) as unwrapped_model:
                    query_responses, logitss = batch_generation(
                        unwrapped_model.policy,
                        queries,
                        args.local_rollout_forward_batch_size,
                        processing_class.pad_token_id,
                        generation_config,
                    )

                for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
                    query = queries[i : i + args.local_rollout_forward_batch_size]
                    query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
                    response = query_response[:, context_length:]
                    logits = logitss[i : i + args.local_rollout_forward_batch_size]
                    logprob = selective_log_softmax(logits, response)
                    del logits
                    empty_cache()

                    if ref_policy is None:
                        with self.null_ref_context():
                            ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
                    else:
                        ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                    ref_logits = ref_output.logits[:, context_length - 1 : -1]
                    ref_logits /= args.temperature + 1e-7
                    ref_logprob = selective_log_softmax(ref_logits, response)
                    del ref_output, ref_logits
                    empty_cache()

                    # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                    postprocessed_response = response
                    if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                        postprocessed_response = truncate_response(
                            self.stop_token_id, processing_class.pad_token_id, response
                        )

                    # Response Processing 2. run reward model on the truncated responses
                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                    sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
                    unwrapped_value_model = accelerator.unwrap_model(model).value_model
                    full_value, _, _ = get_reward(
                        unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
                    )
                    value = full_value[:, context_length - 1 : -1].squeeze(-1)
                    _, score, _ = get_reward(
                        reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                    )

                    responses.append(response)
                    postprocessed_responses.append(postprocessed_response)
                    logprobs.append(logprob)
                    ref_logprobs.append(ref_logprob)
                    sequence_lengths.append(sequence_length)
                    scores.append(score)
                    values.append(value)
                responses = torch.cat(responses, 0)
                postprocessed_responses = torch.cat(postprocessed_responses, 0)
                logprobs = torch.cat(logprobs, 0)
                ref_logprobs = torch.cat(ref_logprobs, 0)
                sequence_lengths = torch.cat(sequence_lengths, 0)
                scores = torch.cat(scores, 0)
                values = torch.cat(values, 0)
                del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
                empty_cache()
                gc.collect()

                # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
                # Completions not passing that filter will receive a lower score.
                contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
                if self.args.missing_eos_penalty is not None:
                    scores[~contain_eos_token] -= self.args.missing_eos_penalty
                # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")

                # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
                response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
                padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
                logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
                ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
                sequence_lengths_p1 = sequence_lengths + 1
                padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
                values = torch.masked_fill(values, padding_mask_p1, 0)

                # 4. compute rewards
                # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
                logr = ref_logprobs - logprobs
                kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr  # Else statement is k3
                non_score_reward = -args.kl_coef * kl
                rewards = non_score_reward.clone()
                actual_start = torch.arange(rewards.size(0), device=rewards.device)
                actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
                rewards[[actual_start, actual_end]] += scores

                # 5. whiten rewards
                if args.whiten_rewards:
                    rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                    rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

                # 6. compute advantages and returns
                lastgaelam = 0
                advantages_reversed = []
                gen_length = responses.shape[1]
                for t in reversed(range(gen_length)):
                    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                    delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
                    lastgaelam = delta + args.gamma * args.lam * lastgaelam
                    advantages_reversed.append(lastgaelam)
                advantages = torch.stack(advantages_reversed[::-1], axis=1)
                returns = advantages + values
                advantages = masked_whiten(advantages, ~padding_mask)
                advantages = torch.masked_fill(advantages, padding_mask, 0)
                empty_cache()

            # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
            for ppo_epoch_idx in range(args.num_ppo_epochs):
                b_inds = np.random.permutation(args.local_batch_size)
                minibatch_idx = 0
                for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
                    mini_batch_end = mini_batch_start + args.local_mini_batch_size
                    mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                    gradient_accumulation_idx = 0
                    for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                        with accelerator.accumulate(model):
                            micro_batch_end = micro_batch_start + args.per_device_train_batch_size
                            micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
                            mb_advantage = advantages[micro_batch_inds]
                            mb_responses = responses[micro_batch_inds]
                            mb_query_responses = query_responses[micro_batch_inds]
                            mb_logprobs = logprobs[micro_batch_inds]
                            mb_return = returns[micro_batch_inds]
                            mb_values = values[micro_batch_inds]

                            output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
                            logits = output.logits[:, context_length - 1 : -1]
                            logits /= args.temperature + 1e-7
                            new_logprobs = selective_log_softmax(logits, mb_responses)
                            new_logprobs = torch.masked_fill(
                                new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
                            )
                            vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                            vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
                            vpredclipped = torch.clamp(
                                vpred,
                                mb_values - args.cliprange_value,
                                mb_values + args.cliprange_value,
                            )
                            vf_losses1 = torch.square(vpred - mb_return)
                            vf_losses2 = torch.square(vpredclipped - mb_return)
                            vf_loss_max = torch.max(vf_losses1, vf_losses2)
                            vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
                            vf_clipfrac = masked_mean(
                                (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
                            )
                            logprobs_diff = new_logprobs - mb_logprobs
                            ratio = torch.exp(logprobs_diff)
                            pg_losses = -mb_advantage * ratio
                            pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
                            pg_loss_max = torch.max(pg_losses, pg_losses2)
                            pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
                            loss = pg_loss + args.vf_coef * vf_loss
                            accelerator.backward(loss)
                            optimizer.step()
                            optimizer.zero_grad()
                            with torch.no_grad():
                                pg_clipfrac = masked_mean(
                                    (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
                                )
                                prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                                entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                                approxkl = 0.5 * (logprobs_diff**2).mean()
                                approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
                                pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                    pg_clipfrac
                                )
                                pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
                                vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
                                vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
                                    vf_clipfrac
                                )
                                entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
                                ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
                        gradient_accumulation_idx += 1
                    minibatch_idx += 1
                    # del everything and empty cache
                    # fmt: off
                    del (
                        output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
                        vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                        pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                        mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                    )
                    # fmt: on
                    empty_cache()
            with torch.no_grad():
                mean_kl = kl.sum(1).mean()
                mean_entropy = (-logprobs).sum(1).mean()
                mean_non_score_reward = non_score_reward.sum(1).mean()
                rlhf_reward = mean_non_score_reward + scores.mean()
                eps = int(self.state.episode / (time.time() - start_time))
                metrics = {}
                metrics["eps"] = eps
                metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
                metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
                metrics["objective/non_score_reward"] = (
                    self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
                )
                metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
                metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
                metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
                metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
                metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
                metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
                metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
                metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
                metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
                metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
                metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
                metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
                metrics["episode"] = self.state.episode
                self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
                self.state.global_step += 1
                self.log(metrics)

            self.lr_scheduler.step()
            self.control = self.callback_handler.on_step_end(args, self.state, self.control)
            if self.control.should_save:
                self._save_checkpoint(model, trial=None)
                self.control = self.callback_handler.on_save(self.args, self.state, self.control)
            del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
            empty_cache()
            gc.collect()

            if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
                self.generate_completions(sampling=True)
                empty_cache()
            del (
                query_responses,
                responses,
                postprocessed_responses,
                logprobs,
                ref_logprobs,
                values,
                sequence_lengths,
                contain_eos_token,
                sequence_lengths_p1,
                response_idxs,
                padding_mask,
                padding_mask_p1,
                rewards,
                actual_start,
                actual_end,
                advantages,
                returns,
            )
            empty_cache()

        # HF trainer specifics
        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
        if self.control.should_save:
            self._save_checkpoint(model, trial=None, metrics=None)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def generate_completions(self, sampling: bool = False):
        args = self.args
        processing_class = self.processing_class
        generation_config = GenerationConfig(
            max_new_tokens=self.args.response_length,
            temperature=(0.01 + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )

        table = defaultdict(list)
        with unwrap_model_for_generation(
            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
        ) as unwrapped_model:
            for batch in self.eval_dataloader:
                query = batch["input_ids"]
                with torch.no_grad():
                    context_length = query.shape[1]
                    query_response, _ = batch_generation(
                        unwrapped_model.policy,
                        query,
                        query.shape[0],
                        processing_class.pad_token_id,
                        generation_config,
                    )
                    response = query_response[:, context_length:]
                    postprocessed_response = response
                    if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                        postprocessed_response = truncate_response(
                            self.stop_token_id, processing_class.pad_token_id, response
                        )
                    table["query"].extend(
                        gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
                    )
                    table["model response"].extend(
                        gather_object(processing_class.batch_decode(postprocessed_response))
                    )

                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                    _, score, _ = get_reward(
                        self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
                    )
                    table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())

                if sampling:
                    break
        df = pd.DataFrame(table)

        if self.accelerator.is_main_process:
            if is_rich_available():
                print_rich_table(df.iloc[0 : 0 + 5])
            if "wandb" in args.report_to:
                import wandb

                if wandb.run is not None:
                    wandb.log({"completions": wandb.Table(dataframe=df)})

            if "comet_ml" in args.report_to:
                log_table_to_comet_experiment(
                    name="completions.csv",
                    table=df,
                )

    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial):
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # normalize `tags` to a mutable set
        if tags is None:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")

        tags.update(self._tag_names)

        citation = textwrap.dedent("""\
        @article{mziegler2019fine-tuning,
            title        = {{Fine-Tuning Language Models from Human Preferences}},
            author       = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
            year         = 2019,
            eprint       = {arXiv:1909.08593}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="PPO",
            trainer_citation=citation,
            paper_title="Fine-Tuning Language Models from Human Preferences",
            paper_id="1909.08593",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
    

🔍 Code Overview

Main Components

  • PPOTrainer - The main trainer class that inherits from Transformers' Trainer
  • PolicyAndValueWrapper - Wrapper class combining policy and value models
  • Key Methods:
    • __init__ - Initialization and setup
    • train - Main training loop
    • generate_completions - Generate and evaluate completions
    • null_ref_context - Context manager for reference model handling

Key Features

  • 🎯 PPO Algorithm - Proximal Policy Optimization implementation
  • 🔄 Multi-GPU Support - Via Accelerate and DeepSpeed
  • 📊 Comprehensive Logging - Metrics tracking and visualization
  • 🔧 PEFT Integration - Parameter-efficient fine-tuning support
  • 🎮 Flexible Configuration - Extensive customization options

🔧 Key Methods Breakdown

1. Initialization (__init__)

Sets up the trainer with models, datasets, and configuration. Handles PEFT integration, batch size calculations, and accelerator setup.
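
A minimal sketch of the batch-size arithmetic in `__init__`, with assumed example values (2 GPUs, per_device_train_batch_size=4, gradient_accumulation_steps=8, num_mini_batches=2); these numbers are illustrative, not defaults:
# Hypothetical settings, mirroring the formulas in PPOTrainer.__init__
world_size = 2                          # number of GPUs / processes
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
num_mini_batches = 2

local_batch_size = per_device_train_batch_size * gradient_accumulation_steps   # 32 rollouts per rank
micro_batch_size = per_device_train_batch_size * world_size                    # 8
batch_size = local_batch_size * world_size                                     # 64 rollouts per PPO update
mini_batch_size = batch_size // num_mini_batches                               # 32 (must divide evenly)
local_mini_batch_size = local_batch_size // num_mini_batches                   # 16 (>= 8 needed if whiten_rewards)

print(batch_size, mini_batch_size, local_mini_batch_size)  # 64 32 16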

2. Training Loop (train)

Main PPO training algorithm implementation:

  • Generate responses using the policy model
  • Compute rewards using the reward model
  • Calculate advantages using GAE (Generalized Advantage Estimation); a worked numerical sketch follows this list
  • Update policy and value networks using PPO loss
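
A worked numerical sketch of the GAE recursion used in `train`, for a single three-token response with assumed gamma=1.0 and lam=0.95 (illustrative values; the real ones come from `PPOConfig`):
import torch

gamma, lam = 1.0, 0.95                     # assumed config values
rewards = torch.tensor([[0.0, 0.0, 0.7]])  # per-token rewards; the reward-model score lands on the last token
values  = torch.tensor([[0.2, 0.3, 0.4]])  # value-head predictions for the same tokens

lastgaelam = 0.0
advantages_reversed = []
gen_length = rewards.shape[1]
for t in reversed(range(gen_length)):
    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]   # TD error at step t
    lastgaelam = delta + gamma * lam * lastgaelam               # discounted sum of TD errors
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values

print(advantages)  # approximately tensor([[0.4658, 0.3850, 0.3000]])
print(returns)     # approximately tensor([[0.6658, 0.6850, 0.7000]])
The trainer then whitens these advantages over the non-padding tokens before running the PPO updates.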

3. Completion Generation (generate_completions)

Generates sample completions for evaluation and monitoring training progress.

4. Reference Model Context (null_ref_context)

Context manager for handling reference model when using PEFT adapters.
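
A condensed stand-alone sketch of the same idea (hypothetical helper name; the real method on the trainer additionally switches to and from a named reference adapter when one is configured):
from contextlib import nullcontext
import torch

def reference_logits(policy_model, inputs, is_peft_model=True, ref_adapter_name=None):
    # With PEFT, the "reference model" is just the policy with its adapters switched off.
    ctx = (
        policy_model.disable_adapter()
        if is_peft_model and not ref_adapter_name
        else nullcontext()
    )
    with ctx, torch.no_grad():
        return policy_model(**inputs).logits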

💬 Q&A Section

Q1: Explain PolicyAndValueWrapper in full detail with a numerical LLM example - what do getattr, critic_backbone, and base_model_prefix do?

🧠 PolicyAndValueWrapper Deep Dive

Purpose: This class wraps the policy model (the actor, which generates text) and the value model (the critic, which estimates how good each generated state is) into a single `nn.Module`. As the comment in the source notes, this is done so the trainer can call `model = accelerator.prepare(model)` once and drive both networks with one optimizer. Its `forward` runs the value model's backbone on the inputs, feeds the last hidden state to the scalar value head, and returns the policy's output together with the per-token value estimates. The rest of this answer walks through the pieces and then sketches a shared-backbone variant that can save compute when actor and critic use the same base architecture.

🔧 Key Components Explained:

Class Definition (as implemented in the file above, with explanatory comments added)
class PolicyAndValueWrapper(nn.Module):
    def __init__(self, policy, value_model) -> None:
        super().__init__()
        self.policy = policy            # actor: the language model being trained to generate text
        self.value_model = value_model  # critic: a model with a scalar `.score` head
        # Grab the critic's core transformer (its "backbone") by name, e.g.
        # `value_model.transformer` for GPT-2, without hardcoding that name.
        self.critic_backbone = getattr(value_model, value_model.base_model_prefix)

    def forward(self, **kwargs):
        # Run the critic's backbone; the trainer's `forward` helper passes
        # `output_hidden_states=True`, so `output.hidden_states` is populated.
        output = self.critic_backbone(**kwargs)
        # Feed the last hidden state to the scalar value head to get per-token values.
        logits = self.value_model.score(output.hidden_states[-1])
        # Return the policy's own forward pass alongside the value estimates.
        return self.policy(**kwargs), logits

🎯 What Each Component Does:

1. base_model_prefix:

  • This is a string attribute on Hugging Face `PreTrainedModel` classes that tells you the name of the underlying core transformer model (the part without the task-specific head).
  • Examples: For GPT2LMHeadModel, it's "transformer". For BertForSequenceClassification, it's "bert". For LlamaForCausalLM, it's "model".

2. getattr(object, 'attribute_name'):

  • This is a standard Python function. getattr(x, 'y') is the same as writing x.y.
  • It's used here for flexibility. Instead of hardcoding value_model.transformer, it uses the base_model_prefix to dynamically fetch the correct backbone, making the wrapper work for various model architectures (like BERT, T5, etc.).

3. self.critic_backbone:

  • This variable stores the result of the `getattr` call: it holds the core transformer layers (embeddings, attention blocks, layer norms) of the value model, but not its final value prediction head.
  • In the wrapper's `forward`, this backbone is run on the inputs and its last hidden state is passed to `value_model.score`, so a single call to the wrapper yields both the policy output and the value estimates (a short check of these attributes follows this list).
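
A quick check of these three pieces on a real model; here `AutoModelForSequenceClassification` (whose GPT-2 variant already ships a scalar `score` head) stands in for the value model:
from transformers import AutoModelForSequenceClassification

value_model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)

print(value_model.base_model_prefix)             # 'transformer'
backbone = getattr(value_model, value_model.base_model_prefix)
print(backbone is value_model.transformer)       # True: the same object, just looked up by name
print(type(backbone).__name__)                   # 'GPT2Model' (embeddings + attention blocks, no head)
print(hasattr(value_model, "score"))             # True: the scalar head lives outside the backbone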

🚀 Real LLM Example: GPT-2 for PPO Training

Let's simulate a setup for Reinforcement Learning from Human Feedback (RLHF).

Step 1: Setup the Models
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn

# --- Policy Model (Actor) ---
# This is a standard causal LM that generates text.
policy_model = AutoModelForCausalLM.from_pretrained("gpt2")
print(f"Policy model is: {type(policy_model)}")
# A `base_model_prefix` of 'transformer' means its backbone is at `policy_model.transformer`
print(f"Policy model's base_model_prefix: '{policy_model.base_model_prefix}'")

# --- Value Model (Critic) ---
# We create a separate model that will learn to predict a scalar "value" or "score".
# It shares the same core architecture but will have a different head.
value_model = AutoModelForCausalLM.from_pretrained("gpt2")

# Define a custom "value head" that predicts a single number from the hidden states.
class ValueHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        # A simple linear layer to map hidden state to a scalar value
        self.value_head = nn.Linear(hidden_size, 1, bias=False)
    
    def forward(self, hidden_states):
        # hidden_states shape: (batch_size, seq_len, hidden_size)
        # We return a value for each token in the sequence.
        return self.value_head(hidden_states)

# Attach our custom head to the value model as the `.score` attribute.
value_model.score = ValueHead(value_model.config.hidden_size)
# GPT-2 already exposes `base_model_prefix = "transformer"`, which the real TRL wrapper
# uses via `getattr` to locate this backbone; nothing extra is needed here.
Step 2: Use a Shared-Backbone Variant of the Wrapper
# Note: the class below is NOT the wrapper from the file above; it is an alternative,
# shared-backbone sketch that reuses the policy's hidden states for the value head,
# which works here because the actor and critic share the same base architecture.
class PolicyAndValueWrapper(nn.Module):
    def __init__(self, policy, value_model):
        super().__init__()
        self.policy = policy
        self.value_model = value_model
        # The critic backbone is implicitly shared if policy and value models are the same base.
        # This design assumes a shared backbone for efficiency.
    
    def forward(self, **kwargs):
        # Request hidden states from the model's forward pass.
        kwargs['output_hidden_states'] = True
        
        # Run the policy model ONCE to get both logits and hidden states.
        # This is efficient because the expensive backbone computation is not repeated.
        policy_output = self.policy(**kwargs)
        
        # The `hidden_states` is a tuple of all layer outputs. We take the last one.
        last_hidden_state = policy_output.hidden_states[-1]
        
        # Pass the shared hidden states to the value model's scoring head.
        value_estimates = self.value_model.score(last_hidden_state)
        
        return policy_output, value_estimates

# Instantiate the wrapper
model_wrapper = PolicyAndValueWrapper(policy_model, value_model)

# --- Numerical Example ---
tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt")
# inputs['input_ids'] has shape (1, 5): one GPT-2 token ID per piece of the prompt

# Perform a forward pass
with torch.no_grad():
    # The wrapper takes the same inputs as a standard Hugging Face model
    policy_output, value_estimates = model_wrapper(**inputs)

print("\n--- Outputs ---")
# 1. Policy Output (from the full policy model)
print(f"Policy Logits Shape: {policy_output.logits.shape}")
# -> torch.Size([1, 5, 50257]) (batch_size, sequence_length, vocab_size)

# 2. Value Estimates (from shared backbone + score head)
print(f"Value Estimates Shape: {value_estimates.shape}")
# -> torch.Size([1, 5, 1]) (batch_size, sequence_length, 1)
print(f"Value for each token:\n{value_estimates.squeeze()}")
# -> tensor([0.1521, 0.1833, 0.1685, 0.1587, 0.1764]) (Example values)

🔄 Why a Shared Backbone Can Be More Efficient

In PPO, every step needs both the action probabilities (from the policy) and the state value (from the critic).

The Two-Pass Way (what the TRL wrapper in the file does):

  1. The policy is run on the inputs to get logits (full transformer pass through the actor).
  2. The critic's backbone is run on the same inputs, and its last hidden state feeds `value_model.score` (another full transformer pass).
  3. Trade-off: the actor and critic keep fully separate weights, but the expensive transformer layers run twice on the exact same input.

The Shared-Backbone Variant (the sketch above):

  1. A single forward pass through `policy_model` with `output_hidden_states=True` produces both the final logits and all the internal hidden states.
  2. The final logits are used for the policy objective.
  3. The hidden states are immediately reused by the `value_model.score` head to get the value estimate.
  4. Benefit: the transformer backbone is computed only ONCE, at the cost of the critic no longer having backbone weights of its own.

💡 Key Takeaway

The PolicyAndValueWrapper is more than a container. In TRL it bundles the actor and critic so a single `model = accelerator.prepare(model)` call (and one optimizer) covers both, and its `forward` returns policy logits and value estimates together. The shared-backbone variant sketched above shows how the same bundling idea can be pushed further into a genuine compute saving when the actor and critic share a base architecture, which matters for performant training of large language models with actor-critic methods.

Q2: What does this PEFT support part of the code do? Please add a numerical LLM based example as well.

🚀 PEFT (Parameter-Efficient Fine-Tuning) Support Explained

Purpose: This section of the code integrates Hugging Face's `peft` library, allowing users to fine-tune massive language models using a fraction of the memory and computational power. Instead of training all the billions of parameters in a model, PEFT techniques like LoRA (Low-Rank Adaptation) freeze the original model and inject small, trainable "adapter" layers. This makes fine-tuning accessible on consumer hardware.

🔧 Code Breakdown:

PEFT Integration Logic
# 1. Check if PEFT library is installed if a config is provided
if not is_peft_available() and peft_config is not None:
    raise ImportError(...)

# 2. Main PEFT logic block
elif is_peft_available() and peft_config is not None:
    # If the model is already a PEFT model, merge the old adapters first
    # This gives a clean slate before applying the new config.
    if isinstance(self.policy_model, PeftModel):
        self.policy_model = self.policy_model.merge_and_unload()

    # 🔥 KEY LINE: Apply the new PEFT config (e.g., LoRA) to the base model
    self.policy_model = get_peft_model(self.policy_model, peft_config)
    
    # Compatibility fix for 4-bit models trained in bfloat16
    if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
        peft_module_casting_to_bf16(self.policy_model)

# 3. Set flags for later use in the trainer
self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
self.model_adapter_name = args.model_adapter_name
self.ref_adapter_name = args.ref_adapter_name

🎯 What Each Step Does:

  1. Prerequisite Check: Ensures the `peft` library is installed if the user intends to use it.
  2. Model Preparation:
    • merge_and_unload(): This is for edge cases. If you pass a model that has *already* been modified with PEFT, this function merges the existing adapter's weights into the base model and removes the adapter layers, effectively "baking in" the old changes to create a standard model again (a short sketch follows this list).
    • get_peft_model(): This is the core of PEFT. It takes the original, frozen transformer model and injects the small, trainable adapter layers (e.g., LoRA layers) into the places specified in the `peft_config`.
  3. State Tracking: Sets boolean flags like `is_peft_model` that other parts of the trainer (like the saving logic or reference model handling) use to change their behavior accordingly.
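
A minimal sketch of the merge_and_unload round trip (illustrative; the point is only that the result is a plain transformers model again):
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"], task_type="CAUSAL_LM"))
print(isinstance(peft_model, PeftModel))  # True: LoRA adapter layers are attached

merged = peft_model.merge_and_unload()    # fold the (here untrained) LoRA deltas back into the base weights
print(isinstance(merged, PeftModel))      # False: a plain GPT2LMHeadModel, ready for a fresh peft_config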

🚀 Real LLM Example: Applying LoRA to GPT-2

Let's see the dramatic impact on the number of trainable parameters.

Step 1: Setup the Base Model
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, PeftModel

# Load a standard GPT-2 model
base_model = AutoModelForCausalLM.from_pretrained("gpt2")

# --- Before PEFT ---
total_params = sum(p.numel() for p in base_model.parameters())
trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
print(f"--- Base Model ---")
print(f"Total Parameters: {total_params / 1e6:.2f}M")
print(f"Trainable Parameters: {trainable_params / 1e6:.2f}M (100%)")
# Output:
# --- Base Model ---
# Total Parameters: 124.44M
# Trainable Parameters: 124.44M (100%)
Step 2: Apply LoRA with `get_peft_model`
# Define the LoRA configuration
lora_config = LoraConfig(
    r=16,  # Rank of the update matrices. Lower = fewer parameters.
    lora_alpha=32,  # A scaling factor.
    target_modules=["c_attn"], # Target only the attention query, key, value projections.
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply the LoRA config to the base model
peft_model = get_peft_model(base_model, lora_config)

# --- After PEFT ---
peft_total_params = sum(p.numel() for p in peft_model.parameters())
peft_trainable_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)

print(f"\n--- PEFT Model (LoRA Applied) ---")
print(f"Total Parameters: {peft_total_params / 1e6:.2f}M")
print(f"Trainable Parameters: {peft_trainable_params / 1e6:.2f}M")
print(f"Trainable %: {peft_trainable_params / peft_total_params * 100:.4f}%")
print("\nModel structure with LoRA layers:")
peft_model.print_trainable_parameters()
# Example output (approximate; exact counts can vary slightly across peft versions):
# --- PEFT Model (LoRA Applied) ---
# Total Parameters: ~125.03M
# Trainable Parameters: ~0.59M
# Trainable %: ~0.47%
#
# Model structure with LoRA layers:
# trainable params: 589,824 || all params: 125,029,632 || trainable%: ~0.47

🧠 Why This is Crucial for PPO

In PPO, you need a `policy_model` (which you are training) and a `ref_model` (a frozen reference to calculate KL divergence against). Without PEFT, you would need to load two full models into memory.

With PEFT, you only need one base model in memory!

  • The `policy_model` is the base model with the trainable LoRA adapters enabled.
  • The `ref_model` is the exact same base model, but with the adapters temporarily disabled using `peft_model.disable_adapter()`.

The `PPOTrainer`'s `null_ref_context` manager handles this adapter-switching automatically. This dramatically reduces memory requirements, making RLHF accessible to many more users.

💡 Key Takeaway

The PEFT support block is a powerful feature: rather than fine-tuning the entire model, it uses `get_peft_model` to freeze the base weights and inject small, trainable adapters, resulting in a model where over 99% of the parameters are frozen. This drastically reduces the memory and compute needed for fine-tuning while still achieving strong performance.

Q3: What does this section of the code do, especially the `self.ref_model = None` part for PEFT? Please provide a numerical LLM-based example.

🧠 The Magic of `ref_model = None`: PEFT and Memory Optimization

Purpose: This section of the `__init__` method decides how to create the `ref_model` (reference model). The reference model is a crucial component in PPO training for RLHF. It's a frozen version of the original language model used to calculate a KL-divergence penalty, which prevents the policy model from deviating too much from sensible language and improves training stability.

🔧 Code Breakdown:

Reference Model Initialization Logic
# If a reference model is explicitly passed by the user, use it.
if ref_model:
    self.ref_model = ref_model

# 🔥 KEY LOGIC: If using a PEFT model, we DON'T need a separate reference model in memory.
# Setting it to `None` signals the trainer to use a special adapter-toggling strategy.
elif self.is_peft_model:
    self.ref_model = None

# Otherwise (not using PEFT and no ref_model passed), create a full, memory-intensive copy.
else:
    self.ref_model = create_reference_model(self.policy_model)

🎯 Why is `self.ref_model = None` so important?

It enables a massive memory-saving strategy. Instead of loading two multi-billion parameter models into memory (one for the policy, one for reference), we load only one. This single model plays both roles:

  • As the Policy Model: The base model with the trainable PEFT adapters enabled.
  • As the Reference Model: The exact same base model, but with its adapters temporarily disabled.

This switching is handled automatically later in the trainer by the `null_ref_context` context manager, making the process seamless.

๐Ÿš€ Real LLM Example: The Two-in-One Model

Let's prove that we can get both policy and reference outputs from a single PEFT model object, demonstrating why a second model copy is unnecessary.

Step 1: Create a Base Model and a PEFT Policy Model
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
import torch

# Load the original, base gpt2 model
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Create a PEFT model by applying LoRA adapters. This is our `policy_model`.
# `init_lora_weights=False` randomly initializes the LoRA weights so that even the untrained
# adapters change the outputs (the default zero-init of lora_B makes the LoRA delta exactly zero).
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], task_type="CAUSAL_LM", init_lora_weights=False)
policy_model = get_peft_model(base_model, lora_config)

# Dummy input for demonstration
inputs = tokenizer("The goal of PPO is to", return_tensors="pt")
Step 2: Get Logits as the Policy (Adapters Enabled by Default)
print("--- 1. Acting as POLICY MODEL (Adapters ON by default) ---")
# After `get_peft_model`, the LoRA adapters are active by default.
# They stay enabled unless explicitly turned off, e.g. inside a `disable_adapter()` context.
with torch.no_grad():
    policy_logits = policy_model(**inputs).logits

# The output is influenced by the small, trainable LoRA layers.
print(f"Sample policy logit for the last token: {policy_logits[0, -1, 200].item():.4f}")
# Example output might be something like: 1.9873
Step 3: Get Logits as the Reference (Adapters Disabled)
print("\n--- 2. Acting as REFERENCE MODEL (Adapters OFF) ---")
# We temporarily disable the adapters using a context manager.
# The model now behaves exactly like the original base model.
with policy_model.disable_adapter():
    with torch.no_grad():
        ref_logits_from_peft_model = policy_model(**inputs).logits

print(f"Sample ref logit (from policy model): {ref_logits_from_peft_model[0, -1, 200].item():.4f}")
# Example output: Sample ref logit (from policy model): 2.3145
Step 4: Verify with a Fresh Copy of the Base Model
print("\n--- 3. Verifying with a FRESH BASE MODEL ---")
# Note: `get_peft_model` injects the LoRA layers into `base_model` in place, so we reload
# an untouched copy of gpt2 to get a model that truly never saw the adapters.
fresh_base_model = AutoModelForCausalLM.from_pretrained("gpt2")
with torch.no_grad():
    original_base_logits = fresh_base_model(**inputs).logits

print(f"Sample logit from fresh base model: {original_base_logits[0, -1, 200].item():.4f}")
# Example output: Sample logit from fresh base model: 2.3145

# --- Verification ---
are_they_equal = torch.allclose(ref_logits_from_peft_model, original_base_logits)
print(f"\nAre the reference logits and original logits identical? -> {are_they_equal}")
# Output: Are the reference logits and original logits identical? -> True

๐Ÿ’ฐ Memory Impact Conclusion

The example proves it: by simply disabling the adapters, the `policy_model` produces the exact same output as an untouched copy of the base model. We get reference-model behavior without keeping a second full model around for training (the fresh copy above was loaded only to verify the claim).

  • Without PEFT (e.g., 7B model @ fp16):
    • Policy Model Memory: ~14 GB
    • Reference Model Memory: ~14 GB
    • Total: ~28 GB
  • With PEFT (e.g., 7B model @ fp16):
    • Policy Model (Base + Adapters) Memory: ~14 GB + ~10 MB
    • Reference Model Memory: 0 GB (reused from policy)
    • Total: ~14.01 GB

Setting `self.ref_model = None` is the key that unlocks this massive ~50% memory saving, making large-scale RLHF dramatically more accessible.
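
A quick back-of-envelope check of those numbers (a minimal sketch: the 7B parameter count, fp16 storage at 2 bytes per parameter, and the ~5M-parameter LoRA adapter are illustrative assumptions; optimizer states, gradients, and activations are ignored):

# Weights-only memory estimate with illustrative numbers
base_params = 7e9      # assumed 7B-parameter base model
lora_params = 5e6      # assumed ~5M LoRA parameters (~10 MB in fp16)
bytes_per_param = 2    # fp16
GB = 1e9

policy = base_params * bytes_per_param / GB
reference = base_params * bytes_per_param / GB
adapters = lora_params * bytes_per_param / GB

print(f"Without PEFT: policy + reference = {policy + reference:.2f} GB")
print(f"With PEFT:    policy + adapters  = {policy + adapters:.2f} GB")
# Without PEFT: policy + reference = 28.00 GB
# With PEFT:    policy + adapters  = 14.01 GB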

Q4: Explain the `null_ref_context` method. What is it doing with adapters? Please provide a numerical LLM example.

๐Ÿ”„ `null_ref_context`: The Smart Adapter Switch

Purpose: This context manager is the mechanism that brings the memory-saving strategy (discussed in Q3) to life. Its job is to temporarily make the policy model behave like the reference model *just for the moment when the reference logits are needed*. It intelligently handles two main PEFT scenarios: single-adapter training and multi-adapter training.

๐Ÿ”ง Code Breakdown:

The `null_ref_context` method
@contextmanager
def null_ref_context(self):
    """Context manager for handling null reference model (that is, peft adapter manipulation)."""
    # This is the main scenario: using PEFT with a single adapter.
    # `disable_adapter()` is itself a context manager that turns adapters off inside the `with`
    # block and automatically turns them back on upon exit.
    with (
        self.accelerator.unwrap_model(self.model.policy).disable_adapter()
        if self.is_peft_model and not self.ref_adapter_name
        else nullcontext()
    ):
        # This handles the advanced scenario: using two different adapters.
        # It activates the specified reference adapter upon entering the `with` block.
        if self.ref_adapter_name:
            self.model.policy.set_adapter(self.ref_adapter_name)
        
        # This is where the code inside the `with` block runs (e.g., the forward pass).
        yield
        
        # After the code runs, switch back to the main policy adapter.
        if self.ref_adapter_name:
            self.model.policy.set_adapter(self.model_adapter_name or "default")

๐ŸŽฏ How it Works:

  • Scenario A (Most Common): A single PEFT adapter is used.
    • is_peft_model is `True`.
    • ref_adapter_name is `None`.
    • The `disable_adapter()` context manager is activated. It turns off the policy adapter, making the model behave exactly like its original base version. When the block is exited, it automatically re-enables the policy adapter.
  • Scenario B (Advanced): Two PEFT adapters are used.
    • is_peft_model is `True`.
    • ref_adapter_name is provided (e.g., 'my-ref-adapter').
    • The code explicitly switches the active adapter to the `ref_adapter_name` upon entering the context. After the code inside the `with` block finishes, it switches the active adapter back to the `model_adapter_name`.
  • Scenario C (No PEFT):
    • is_peft_model is `False`.
    • The code does nothing, as a separate, full reference model is already in memory. The `nullcontext()` is a placeholder that does nothing.

๐Ÿš€ Real LLM Example: Two Scenarios in Action

Let's create a PEFT model with two distinct adapters and a dummy trainer to see how the context manager works. Note that the demo below runs the two PEFT cases in the opposite order from the scenario list above: first the two-adapter switch, then the single-adapter disable.

Step 1: Setup a Model with Two LoRA Adapters
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
import torch
from contextlib import contextmanager, nullcontext

# --- Setup ---
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Create a PEFT model and add the first adapter, which is automatically named "default".
# `init_lora_weights=False` randomly initializes the LoRA weights so that the untrained
# adapters actually change the outputs (the default zero-init would make both deltas zero).
peft_model = get_peft_model(base_model, LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], task_type=TaskType.CAUSAL_LM, init_lora_weights=False))

# Now, add a SECOND, different adapter and name it "ref_adapter"
peft_model.add_adapter("ref_adapter", LoraConfig(r=4, lora_alpha=8, target_modules=["c_attn"], task_type=TaskType.CAUSAL_LM, init_lora_weights=False))

# --- Dummy Trainer Class to hold state and the context manager logic ---
class DummyTrainer:
    def __init__(self, model, is_peft, model_adapter, ref_adapter):
        self.model = type("obj", (object,), {"policy": model})()
        self.is_peft_model = is_peft
        self.model_adapter_name = model_adapter
        self.ref_adapter_name = ref_adapter
        # This dummy accelerator mimics the real one: `unwrap_model` simply returns the model it is given.
        # The lambda takes `s` (the dummy object itself, like `self`) and `m` (the model).
        self.accelerator = type("obj", (object,), {"unwrap_model": lambda s, m: m})()

    @contextmanager
    def null_ref_context(self):
        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
        with (
            self.accelerator.unwrap_model(self.model.policy).disable_adapter()
            if self.is_peft_model and not self.ref_adapter_name
            else nullcontext()
        ):
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.ref_adapter_name)
            yield
            if self.ref_adapter_name:
                self.model.policy.set_adapter(self.model_adapter_name or "default")

inputs = tokenizer("To be or not to be", return_tensors="pt")
Scenario A: Switching Between Two Named Adapters
print("--- SCENARIO A: Switching between 'default' and 'ref_adapter' ---")
trainer_with_named_adapters = DummyTrainer(model=peft_model, is_peft=True, model_adapter="default", ref_adapter="ref_adapter")

# Set the initial active adapter to the policy ('default')
peft_model.set_adapter("default")
print(f"Adapter before context: '{peft_model.active_adapter}'")
with torch.no_grad(): policy_logits = peft_model(**inputs).logits

# Use the context manager to switch to the reference adapter
with trainer_with_named_adapters.null_ref_context():
    print(f"Adapter INSIDE context: '{peft_model.active_adapter}'")
    with torch.no_grad(): ref_logits = peft_model(**inputs).logits

print(f"Adapter AFTER context:  '{peft_model.active_adapter}'")

# --- Verification ---
are_they_equal = torch.allclose(policy_logits, ref_logits, atol=1e-4)
print(f"\nAre policy and ref logits the same? -> {are_they_equal}")
print(f"Sample 'default' adapter logit: {policy_logits[0, -1, 100].item():.4f}")
print(f"Sample 'ref_adapter' logit:     {ref_logits[0, -1, 100].item():.4f}")

# --- SCENARIO A: Switching between 'default' and 'ref_adapter' ---
# Adapter before context: 'default'
# Adapter INSIDE context: 'ref_adapter'
# Adapter AFTER context:  'default'
#
# Are policy and ref logits the same? -> False
# Sample 'default' adapter logit: -5.7001
# Sample 'ref_adapter' logit:     -5.7029
Scenario B: Disabling a Single Adapter
print("\n--- SCENARIO B: Disabling the 'default' adapter to get base model output ---")
trainer_with_one_adapter = DummyTrainer(model=peft_model, is_peft=True, model_adapter="default", ref_adapter=None)

# Activate the policy adapter
peft_model.set_adapter("default")
print(f"Adapter before context: '{peft_model.active_adapter}'")
with torch.no_grad(): policy_logits_2 = peft_model(**inputs).logits

# Use context manager, which will now disable adapters instead of switching
with trainer_with_one_adapter.null_ref_context():
    print(f"Active adapters INSIDE context: {peft_model.active_adapters}")
    with torch.no_grad(): ref_logits_2 = peft_model(**inputs).logits

print(f"Adapter AFTER context:  '{peft_model.active_adapter}'")

# `get_peft_model` wrapped `base_model` in place, so reload an untouched copy of gpt2
# to get true base-model logits for comparison
fresh_base_model = AutoModelForCausalLM.from_pretrained("gpt2")
with torch.no_grad(): base_model_logits = fresh_base_model(**inputs).logits

# --- Verification ---
are_they_equal_2 = torch.allclose(ref_logits_2, base_model_logits)
print(f"\nAre disabled-adapter and base-model logits the same? -> {are_they_equal_2}")
print(f"Sample policy logit:           {policy_logits_2[0,-1,100].item():.4f}")
print(f"Sample disabled-adapter logit: {ref_logits_2[0,-1,100].item():.4f}")
print(f"Sample base-model logit:       {base_model_logits[0,-1,100].item():.4f}")

# --- SCENARIO B: Disabling the 'default' adapter to get base model output ---
# Adapter before context: 'default'
# Adapter INSIDE context: 'default' (LoRA layers disabled)
# Adapter AFTER context:  'default'
#
# Are disabled-adapter and base-model logits the same? -> True
# Sample policy logit:           -5.7001
# Sample disabled-adapter logit: -5.6983
# Sample base-model logit:       -5.6983

๐Ÿ’ก Key Takeaway

The `null_ref_context` method is a powerful utility that makes PPO training with PEFT both flexible and efficient. It correctly handles the two primary ways you might use adapters for the reference model: either by disabling the policy adapter to fall back to the base model, or by switching to a completely different adapter designated for reference calculations. This all happens automatically, ensuring the right model state is used at the right time.

Q5: Explain the beginning of the `train` method (lines 347-405). What is all this setup doing? Please explain for a beginner with numerical LLM examples.

๐Ÿš€ Setting the Stage: Pre-Training Initialization

Purpose: This entire block of code doesn't do any training itself. Instead, it's the critical setup phase. It prepares all the necessary variables, configurations, data loaders, and tracking systems needed before the main PPO training loop can begin. It's like a pre-flight checklist for the training process.

๐Ÿ”ง Code Breakdown Step-by-Step:

Part 1: The Infinite Data Loader
def repeat_generator():
    while True:
        yield from dataloader

iter_dataloader = iter(repeat_generator())

What it does: Unlike traditional training that goes through a dataset epoch by epoch, PPO training runs for a fixed number of "updates" or "episodes". It constantly needs fresh batches of data. This code creates an infinite data generator. The `while True` loop ensures that whenever the `dataloader` runs out of data, it just starts over from the beginning. This way, the training loop can simply call `next(iter_dataloader)` forever without ever getting a "StopIteration" error.

Example: Imagine your `dataloader` has just two batches: `["prompt A", "prompt B"]` and `["prompt C", "prompt D"]`. The `iter_dataloader` would yield:

  1. Batch 1: `["prompt A", "prompt B"]`
  2. Batch 2: `["prompt C", "prompt D"]`
  3. Batch 1 again: `["prompt A", "prompt B"]`
  4. ...and so on, forever.
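
A minimal, self-contained sketch of this pattern (using a plain Python list of two batches in place of a real `DataLoader`):

# Stand-in "dataloader" with two batches
dataloader = [["prompt A", "prompt B"], ["prompt C", "prompt D"]]

def repeat_generator():
    while True:
        yield from dataloader  # start over whenever the batches run out

iter_dataloader = iter(repeat_generator())

for step in range(5):
    print(f"step {step}: {next(iter_dataloader)}")
# step 0: ['prompt A', 'prompt B']
# step 1: ['prompt C', 'prompt D']
# step 2: ['prompt A', 'prompt B']
# step 3: ['prompt C', 'prompt D']
# step 4: ['prompt A', 'prompt B']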
Part 2: Configuring How the AI Writes (GenerationConfig)
generation_config = GenerationConfig(
    max_new_tokens=args.response_length, # e.g., 50
    temperature=(args.temperature + 1e-7), # e.g., 0.7
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
)

What it does: This object tells the language model exactly *how* it should generate text.

  • max_new_tokens: The maximum length of the response to generate.
  • do_sample=True: This tells the model to be creative instead of "greedy". A greedy model always picks the single word with the highest probability, which can be repetitive. Sampling means it picks from a distribution of possible words.
  • temperature: Controls the "craziness" of the sampling. A high temperature (e.g., > 1.0) makes the model's choices more random and creative (and more likely to make mistakes). A low temperature (e.g., 0.7) makes the output safer and more focused. As the temperature approaches 0, sampling approaches greedy decoding, which is why the code adds a tiny `1e-7` to avoid dividing by exactly zero.
  • top_k and top_p: Other ways to control sampling, but setting them to 0.0 and 1.0 respectively effectively disables them in favor of temperature-based sampling.

Numerical Example (Temperature): Imagine the model has to choose the next word and its top 3 choices have these raw scores (logits): `[2.0, 1.5, 0.5]`.

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.5, 0.5])

# Low temperature (more confident, less random)
probs_low_temp = F.softmax(logits / 0.5, dim=-1)
# -> tensor([0.7054, 0.2595, 0.0351]) - Almost certainly picks the first word.

# High temperature (less confident, more random)
probs_high_temp = F.softmax(logits / 1.5, dim=-1)
# -> tensor([0.4798, 0.3438, 0.1765]) - Might pick any of the top 3.

Part 3: Preparing for Statistics Tracking
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
approxkl_stats = torch.zeros(stats_shape, device=device)
pg_loss_stats = torch.zeros(stats_shape, device=device)
vf_loss_stats = torch.zeros(stats_shape, device=device)
# ... and others

What it does: The trainer needs to keep track of many important metrics to see how well it's learning. This code creates empty "storage containers" (tensors full of zeros) to hold these statistics. Each container's size (`stats_shape`) is designed to hold a value for every single optimization step within a PPO update.

  • approxkl_stats: Stores the KL divergence, a measure of how much the policy is changing.
  • pg_loss_stats: Stores the policy gradient loss (the "actor's" loss).
  • vf_loss_stats: Stores the value function loss (the "critic's" loss).
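
To make the shape concrete, here is a tiny sketch with made-up config values (not the trainer's defaults):

import torch

# Hypothetical config values, for illustration only
num_ppo_epochs = 4
num_mini_batches = 4
gradient_accumulation_steps = 2

stats_shape = (num_ppo_epochs, num_mini_batches, gradient_accumulation_steps)
approxkl_stats = torch.zeros(stats_shape)

print(stats_shape)             # (4, 4, 2)
print(approxkl_stats.numel())  # 32 -> one slot per optimization step in a PPO update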
Part 4: Initializing the Training 'Scoreboard' (Trainer State)
self.state.global_step = 0
self.state.episode = 0
self.state.max_steps = args.num_total_batches
# ...
if args.save_steps < 1:
    self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
else:
    self.state.save_steps = args.save_steps
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

What it does: This initializes the `TrainerState` object, which is like the main scoreboard for the entire training run.

  • It resets the `global_step` and `episode` counters to zero.
  • It calculates the absolute step numbers for logging, evaluating, and saving. For example, if you set `save_steps=0.25` (a ratio) and there are `max_steps=1000`, it calculates that it should save a checkpoint every `ceil(1000 * 0.25) = 250` steps (see the sketch after this list).
  • on_train_begin(...): This is a call to any special functions (callbacks) that need to run right before training starts, like setting up a connection to a logging service like Weights & Biases.
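
A tiny sketch of that ratio-vs-integer handling (the 1000-step run and both `save_steps` values are made-up numbers):

import math

max_steps = 1000  # assumed total number of PPO updates

for save_steps in (0.25, 250):
    # A value below 1 is treated as a fraction of max_steps, otherwise as an absolute step count
    resolved = math.ceil(max_steps * save_steps) if save_steps < 1 else save_steps
    print(f"save_steps={save_steps} -> checkpoint every {resolved} steps")
# save_steps=0.25 -> checkpoint every 250 steps
# save_steps=250 -> checkpoint every 250 steps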

๐Ÿ’ก Key Takeaway

This whole section is the essential "boot-up" sequence for the trainer. It ensures that data will always be available, the model knows how to generate text, empty containers are ready to record performance, and the main training scoreboard is initialized and ready to go.

Q6: Explain the PPO algorithm's mathematical foundations and how this code implements them. Include policy gradient theory, probability ratio theory, and detailed numerical examples.

๐Ÿงฎ PPO Mathematical Foundations & Implementation

Purpose: This section implements the core PPO algorithm - the heart of the training process. PPO (Proximal Policy Optimization) is a policy gradient method that learns to improve a language model's responses by maximizing expected rewards while preventing the policy from changing too drastically.

๐Ÿ“Š Mathematical Foundation

1. Policy Gradient Theorem

The fundamental goal is to maximize the expected reward:

$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]$$

Where:

  • $\theta$ = policy parameters
  • $\pi_\theta$ = policy (our language model)
  • $\tau$ = trajectory (sequence of states and actions)
  • $R(\tau)$ = total reward for trajectory

The policy gradient is:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[\sum_{t=0}^T \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot A_t]$$

Where $A_t$ is the advantage function (how much better action $a_t$ is compared to average).
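
To make the gradient term concrete, here is a minimal sketch for a single step with a toy 3-token policy; the logits, chosen token, and advantage are made-up numbers, and autograd supplies $\nabla_\theta \log \pi_\theta(a_t|s_t)$:

import torch
import torch.nn.functional as F

# Toy policy over a 3-token vocabulary, parameterized directly by its logits
logits = torch.tensor([2.0, 1.0, 0.5], requires_grad=True)
action = 0         # the token that was sampled
advantage = 1.5    # made-up advantage for that token

log_prob = F.log_softmax(logits, dim=-1)[action]

# One-step REINFORCE-style objective: A_t * log pi_theta(a_t | s_t)
objective = advantage * log_prob
objective.backward()

# Gradient of the (to-be-maximized) objective: positive for the chosen token, negative for the rest
print(logits.grad)
# tensor([ 0.5572, -0.3468, -0.2104])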

2. Importance Sampling & Probability Ratios

PPO uses importance sampling to reuse data from an old policy $\pi_{\theta_{old}}$ to update a new policy $\pi_\theta$:

$$r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$$

The surrogate objective becomes:

$$L^{CPI}(\theta) = \mathbb{E}_t[r_t(\theta) \cdot A_t]$$

3. PPO Clipping

To prevent large policy updates, PPO clips the ratio:

$$L^{CLIP}(\theta) = \mathbb{E}_t[\min(r_t(\theta) \cdot A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \cdot A_t)]$$

Where $\epsilon$ is the clipping parameter (typically 0.2).

๐Ÿ”ง Code Implementation Breakdown

Step 1: Response Generation (Rollout Phase)
# Generate responses using current policy
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
    query_responses, logitss = batch_generation(
        unwrapped_model.policy,
        queries,
        args.local_rollout_forward_batch_size,
        processing_class.pad_token_id,
        generation_config,
    )

# Extract just the response part (excluding the input query).
# (In the trainer this runs inside a loop over mini-batches, where `query_response`
# is a slice of `query_responses`.)
response = query_response[:, context_length:]

# Compute log probabilities for the generated tokens
logprob = selective_log_softmax(logits, response)

Mathematical Meaning: This computes $\log \pi_\theta(a_t|s_t)$ for each token in the response. The model generates text and we calculate how likely each generated token was according to the current policy.
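
For intuition, here is a minimal sketch of what a per-token log-probability gather such as `selective_log_softmax(logits, response)` boils down to (toy shapes and random logits; the actual helper in `trl` may differ in implementation details):

import torch
import torch.nn.functional as F

# Toy batch: 1 sequence, 3 response positions, a 5-token vocabulary
logits = torch.randn(1, 3, 5)         # policy logits at each response position
response = torch.tensor([[2, 0, 4]])  # token ids that were actually generated

# log pi_theta(a_t | s_t): log-softmax over the vocab, then pick the generated token's entry
all_logprobs = F.log_softmax(logits, dim=-1)                                             # (1, 3, 5)
per_token_logprob = torch.gather(all_logprobs, -1, response.unsqueeze(-1)).squeeze(-1)   # (1, 3)

print(per_token_logprob.shape)  # torch.Size([1, 3]) -- one log-prob per generated token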

Step 2: Reference Policy Computation
# Get reference policy probabilities (ฯ€_ฮธ_old)
if ref_policy is None:
    # PEFT case: temporarily disable adapters to get base model behavior
    with self.null_ref_context():
        ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
else:
    # Separate reference model case
    ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)

ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_logprob = selective_log_softmax(ref_logits, response)

Mathematical Meaning: This computes $\log \pi_{\theta_{old}}(a_t|s_t)$ - the reference policy's log probabilities for the same tokens. This is crucial for calculating the importance sampling ratio.

๐Ÿš€ Numerical Example: Complete PPO Step

Example Setup
import torch
import torch.nn.functional as F
import numpy as np

# Simulated scenario: Model completing "The weather is"
query = "The weather is"
response_tokens = ["sunny", "and", "warm"]  # Generated response
vocab_size = 50257  # GPT-2 vocab size

# Simulate token IDs
sunny_id, and_id, warm_id = 19989, 290, 5814

# Example logits for the three generated tokens (higher = more likely), treating them as a
# tiny toy vocabulary (the real code computes log-probs over the full vocab and gathers
# the generated token's entry)
policy_logits = torch.tensor([
    [2.1, 1.8, 0.9],  # current policy's logits for ["sunny", "and", "warm"]
])

# Example logits from the reference policy (slightly different per token; note that shifting
# every logit by the same constant would leave the softmax, and hence the ratios, unchanged)
ref_logits = torch.tensor([
    [2.0, 1.9, 0.7],  # reference logits for the same tokens
])

# Convert to probabilities and then log probabilities
policy_probs = F.softmax(policy_logits, dim=-1)
ref_probs = F.softmax(ref_logits, dim=-1)

policy_log_probs = torch.log(policy_probs)
ref_log_probs = torch.log(ref_probs)

print("Policy Probabilities:", policy_probs)
print("Reference Probabilities:", ref_probs)
print("Policy Log Probs:", policy_log_probs)
print("Reference Log Probs:", ref_log_probs)

# Output:
# Policy Probabilities: tensor([[0.4897, 0.3628, 0.1475]])
# Reference Probabilities: tensor([[0.4593, 0.4156, 0.1252]])
# Policy Log Probs: tensor([[-0.7139, -1.0139, -1.9139]])
# Reference Log Probs: tensor([[-0.7781, -0.8781, -2.0781]])
Computing Probability Ratios
# Calculate importance sampling ratios: r_t = ฯ€_ฮธ(a_t|s_t) / ฯ€_ฮธ_old(a_t|s_t)
# In log space: log(r_t) = log ฯ€_ฮธ(a_t|s_t) - log ฯ€_ฮธ_old(a_t|s_t)
log_ratios = policy_log_probs - ref_log_probs
ratios = torch.exp(log_ratios)

print("Log Ratios:", log_ratios)
print("Ratios:", ratios)

# Example advantage values (how good each token choice was)
advantages = torch.tensor([[0.5, -0.2, 0.8]])  # Positive = good, negative = bad

# Unclipped surrogate loss: L^CPI = r_t * A_t
unclipped_loss = ratios * advantages
print("Unclipped Surrogate Loss:", unclipped_loss)

# PPO clipped loss with ฮต = 0.2
epsilon = 0.2
clipped_ratios = torch.clamp(ratios, 1 - epsilon, 1 + epsilon)
clipped_loss = clipped_ratios * advantages

print("Clipped Ratios:", clipped_ratios)
print("Clipped Surrogate Loss:", clipped_loss)

# Final per-token PPO objective: min(unclipped, clipped).
# (The training loss is the negative of this, since the objective is maximized.)
ppo_loss = torch.min(unclipped_loss, clipped_loss)
print("Final PPO Loss:", ppo_loss)

# Output:
# Log Ratios: tensor([[ 0.0642, -0.1358,  0.1642]])
# Ratios: tensor([[1.0663, 0.8730, 1.1784]])
# Unclipped Surrogate Loss: tensor([[ 0.5331, -0.1746,  0.9427]])
# Clipped Ratios: tensor([[1.0663, 0.8730, 1.1784]])
# Clipped Surrogate Loss: tensor([[ 0.5331, -0.1746,  0.9427]])
# Final PPO Loss: tensor([[ 0.5331, -0.1746,  0.9427]])

๐ŸŽฏ Key Insights from the Example

  • Ratio โ‰ˆ 1.0: The policy hasn't changed much from the reference, which is good for stability.
  • Positive Advantage: For "sunny" and "warm", the advantage is positive, meaning these were good choices. The loss encourages the policy to make these tokens more likely.
  • Negative Advantage: For "and", the advantage is negative, meaning this was a poor choice. The loss will make this token less likely.
  • No Clipping: Since all ratios are within [0.8, 1.2], no clipping occurred in this example.

๐Ÿ“ˆ Why This Works

The PPO algorithm is brilliant because it:

  1. Reuses Data: Instead of throwing away old experiences, it uses importance sampling to learn from them multiple times.
  2. Prevents Catastrophic Updates: The clipping mechanism prevents the policy from changing too drastically, maintaining training stability.
  3. Balances Exploration vs. Exploitation: The KL penalty (computed from the gap between policy and reference log-probabilities) ensures the model doesn't deviate too far from sensible language; a minimal sketch follows this list.
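
A minimal sketch of the per-token KL penalty idea from point 3 (the log-probabilities are rounded from the toy example above, the coefficient is made up, and the exact reward shaping in the trainer may differ in detail):

import torch

# Made-up per-token log-probs for a 3-token response
logprobs = torch.tensor([-0.71, -1.01, -1.91])      # current policy
ref_logprobs = torch.tensor([-0.78, -0.88, -2.08])  # frozen reference

kl_coef = 0.05  # illustrative KL coefficient

# Per-token KL estimate and the penalty it contributes to the reward
per_token_kl = logprobs - ref_logprobs  # log pi - log pi_ref
kl_penalty = -kl_coef * per_token_kl

print(per_token_kl)  # tensor([ 0.0700, -0.1300,  0.1700])
print(kl_penalty)    # tensor([-0.0035,  0.0065, -0.0085])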

๐Ÿ’ก Key Takeaway

This code section implements the core mathematical foundation of PPO: generating responses, computing probability ratios between current and reference policies, and preparing the data needed for the clipped surrogate objective. The beauty lies in how it transforms abstract mathematical concepts into practical, working code that can train language models to be more helpful and aligned with human preferences.