从RLHF到DPO:直接偏好优化(Direct Preference Optimization)的数学推导与稳定性优势

好的,下面开始我们的讲座。

从RLHF到DPO:直接偏好优化(Direct Preference Optimization)的数学推导与稳定性优势

大家好!今天我们来深入探讨一个在大型语言模型(LLM)对齐领域非常重要的技术——直接偏好优化(Direct Preference Optimization,简称DPO)。DPO作为一种RLHF(Reinforcement Learning from Human Feedback)的替代方案,近年来受到了广泛关注,其数学推导的简洁性和训练的稳定性是其主要优势。我们将从RLHF的背景出发,逐步推导出DPO的数学公式,并探讨其稳定性的来源。

1. RLHF的背景与挑战

RLHF旨在使LLM的行为与人类的价值观和偏好对齐。传统的RLHF流程通常包含三个阶段:

  1. 预训练模型(Pre-trained Model): 利用大规模文本数据预训练一个基础的LLM。
  2. 奖励模型(Reward Model): 收集人类对不同模型输出的偏好数据,训练一个奖励模型,该模型预测给定模型输出的质量。
  3. 强化学习(Reinforcement Learning): 使用强化学习算法(如PPO)优化预训练模型,使其生成的输出能够最大化奖励模型的预测值。

RLHF虽然有效,但也存在一些挑战:

  • 复杂性: 需要训练奖励模型,并使用强化学习算法优化策略,整个流程较为复杂。
  • 不稳定性: 强化学习训练过程可能不稳定,需要精细的调参。
  • 奖励模型偏差: 奖励模型可能存在偏差,导致模型优化方向与人类偏好不一致。

2. DPO的数学推导

DPO旨在直接从人类偏好数据中优化模型,而无需显式地训练奖励模型。其核心思想是将奖励函数隐式地嵌入到模型优化过程中。

2.1 偏好数据

假设我们有一组人类偏好数据,包含三元组 (x, yw, yl),其中:

  • x:输入文本提示(prompt)。
  • yw:模型对于输入 x 的胜出(preferred/winner)的响应。
  • yl:模型对于输入 x 的落败(dispreferred/loser)的响应。

我们的目标是训练一个模型 πθ(y|x),使其生成 yw 的概率高于生成 yl 的概率,从而与人类偏好对齐。

2.2 Bradley-Terry模型

DPO基于Bradley-Terry模型来模拟人类偏好。Bradley-Terry模型假设,对于给定的输入 x,胜出响应 yw 优于落败响应 yl 的概率由以下公式给出:

P(yw ≻ yl | x) = exp(r(x, yw)) / (exp(r(x, yw)) + exp(r(x, yl)))

其中,r(x, y) 是一个奖励函数,表示响应 y 对于输入 x 的质量。

2.3 RLHF的目标函数

在RLHF中,我们通常使用强化学习算法(如PPO)来优化模型 πθ(y|x),使其最大化以下目标函数:

JRLHF(θ) = Ex,y~πθ [r(x, y) – β KL(πθ(y|x), πref(y|x))]

其中:

  • πθ(y|x) 是待优化的策略模型。
  • πref(y|x) 是一个参考模型(通常是预训练模型),用于限制策略模型的偏离程度。
  • r(x, y) 是奖励模型对 (x, y) 的奖励预测。
  • β 是一个超参数,控制KL散度的强度。

2.4 最优策略的推导

我们可以推导出最优策略 π*θ(y|x) 与奖励函数 r(x, y) 之间的关系。为了最大化 JRLHF(θ),我们可以对 πθ(y|x) 求导并令其等于零。经过一系列数学推导,可以得到:

π*θ(y|x) ∝ πref(y|x) exp(r(x, y) / β)

这个公式表明,最优策略与参考模型和奖励函数相关。奖励越高,模型生成该响应的概率越高;参考模型则起到正则化的作用,防止模型过度偏离。

2.5 DPO的损失函数

DPO的核心思想是直接优化模型,而无需显式地训练奖励模型。我们可以将最优策略的表达式代入Bradley-Terry模型,得到:

P(yw ≻ yl | x) = exp(r(x, yw)) / (exp(r(x, yw)) + exp(r(x, yl)))
= exp(β log(πθ(yw|x) / πref(yw|x))) / (exp(β log(πθ(yw|x) / πref(yw|x))) + exp(β log(πθ(yl|x) / πref(yl|x))))
= π
θ(yw|x) / (πθ(yw|x) + πθ(yl|x))

因此,我们可以直接最大化以下似然函数:

LDPO(θ) = Ex,yw,yl [log (πθ(yw|x) / (πθ(yw|x) + πθ(yl|x)))]

为了方便优化,我们可以将其转化为损失函数:

LossDPO(θ) = – Ex,yw,yl [log (πθ(yw|x) / (πθ(yw|x) + πθ(yl|x)))]

更常用的形式是:

LossDPO(θ) = – Ex,yw,yl [log (sigmoid(β * (log πθ(yw|x) – log πref(yw|x) – (log πθ(yl|x) – log πref(yl|x)))))]

其中,sigmoid(x) = 1 / (1 + exp(-x))。

2.6 DPO的算法流程

DPO的算法流程如下:

  1. 数据准备: 收集人类偏好数据,形成三元组 (x, yw, yl)。
  2. 模型初始化: 初始化待优化的策略模型 πθ(y|x) 和参考模型 πref(y|x) (通常是预训练模型)。
  3. 迭代优化:
    • 从偏好数据集中抽取一个batch的三元组 (x, yw, yl)。
    • 计算DPO损失函数 LossDPO(θ)。
    • 使用梯度下降算法更新策略模型 πθ(y|x) 的参数。
  4. 模型评估: 使用验证集评估模型性能,并调整超参数。

3. DPO的稳定性优势

DPO相比于RLHF,在训练稳定性方面具有显著优势,这主要归因于以下几点:

  • 直接优化: DPO直接优化模型,避免了训练奖励模型带来的误差和不稳定性。奖励模型可能存在偏差,导致强化学习算法朝着错误的方向优化。
  • 梯度稳定性: DPO的损失函数是基于交叉熵的,梯度较为稳定,避免了强化学习中常见的梯度爆炸或消失问题。
  • KL散度约束: DPO通过引入参考模型,隐式地约束了策略模型的偏离程度,防止模型过度优化,从而提高了稳定性。

4. 代码实现示例 (PyTorch)

import torch
import torch.nn.functional as F

def dpo_loss(policy_logps, reference_logps, rewards, beta):
    """Compute pairwise DPO loss.

    Args:
        policy_logps: Log probabilities from the policy model.
        reference_logps: Log probabilities from the reference model (e.g., the pre-trained model).
        rewards: Rewards associated with each response.
        beta: DPO hyperparameter controlling the strength of the reference model.

    Returns:
        Pairwise DPO loss.
    """
    pi_logratios = policy_logps - reference_logps
    loss = -F.logsigmoid(beta * (rewards)).mean()
    return loss

def compute_logprobs(model, tokenizer, prompts, responses):
    """Compute log probabilities of responses given prompts using the model.

    Args:
        model: The language model.
        tokenizer: The tokenizer.
        prompts: List of prompts.
        responses: List of responses.

    Returns:
        Log probabilities of the responses given the prompts.
    """
    encoded_prompts = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
    encoded_responses = tokenizer(responses, return_tensors="pt", padding=True, truncation=True).to(model.device)

    with torch.no_grad():
        prompt_outputs = model(**encoded_prompts)
        response_outputs = model(**encoded_responses) #注意此处通常是使用labels参数,计算loss而不是直接预测

        # Extract log probabilities (example - adjust as needed based on your model's output)
        # This part heavily depends on your model's architecture.
        # The following is a placeholder; you'll need to adapt it to your model.
        logprobs = torch.log_softmax(response_outputs.logits, dim=-1)  # Example: Assuming logits are in response_outputs.logits
        # Select log probabilities corresponding to the tokens in the response
        # This requires aligning the logits with the tokens and handling padding correctly.
        # This is a simplified example; you likely need more sophisticated indexing.

        #下面是更常规的计算方式,使用labels来计算loss,并从中提取log_prob
        # encoded_responses['labels'] = encoded_responses['input_ids'].clone() #确保labels存在
        # outputs = model(**encoded_responses)
        # logprobs = -outputs.loss #简化起见,直接使用loss的负数,实际需要更精确的计算每个token的logprob
        # print(f"Model Output keys: {outputs.keys()}") #Debug

        #Simplified example:  Assuming you can get log probabilities for each token and sum them
        # This part needs to be adapted to your specific model's output structure.
        # logprobs = torch.sum(logprobs, dim=-1) #非常简化,需要根据实际情况修改

        #更为复杂的计算方式,需要遍历每个token,并计算其logprob
        # 首先,确保响应长度不大于模型输入长度
        max_len = min(encoded_prompts['input_ids'].shape[1] + encoded_responses['input_ids'].shape[1], model.config.max_position_embeddings)
        # 获取模型输出的logits
        logits = model(**encoded_prompts, **encoded_responses).logits
        # 获取label,即response的token ids
        labels = encoded_responses['input_ids']
        # 计算每个token的log probability
        log_probs = torch.gather(logits[:, :-1, :], dim=2, index=labels[:, 1:].unsqueeze(2)).squeeze(2)
        # 计算总的log probability,mask掉padding
        mask = (labels[:, 1:] != tokenizer.pad_token_id)
        log_probs = (log_probs * mask).sum(dim=1)

    return logprobs

# Example usage:
if __name__ == '__main__':
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load a pre-trained model and tokenizer (replace with your actual model)
    model_name = "facebook/opt-350m" # or any other suitable model
    model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token # 确保pad token设置正确

    # Example prompts and responses
    prompts = ["Write a short story about a cat.", "Translate 'hello' to French."]
    winner_responses = ["Once upon a time, there was a cat...", "Bonjour"]
    loser_responses = ["A cat walked...", "Guten Tag"]

    # Compute log probabilities
    winner_logps = compute_logprobs(model, tokenizer, prompts, winner_responses)
    loser_logps = compute_logprobs(model, tokenizer, prompts, loser_responses)

    # Example rewards (replace with your actual reward calculation)
    rewards = winner_logps - loser_logps #简单的奖励定义

    # DPO parameters
    beta = 0.1

    # Calculate DPO loss
    loss = dpo_loss(winner_logps, loser_logps, rewards, beta)

    print(f"DPO Loss: {loss.item()}")

代码解释:

  • dpo_loss 函数:计算DPO损失,输入策略模型和参考模型的log概率,以及奖励值。
  • compute_logprobs 函数:根据prompt和response,计算模型生成response的log概率。 这部分代码需要根据你使用的模型架构进行调整,确保能够正确提取log概率。
  • 示例用法:展示了如何加载预训练模型和tokenizer,准备数据,计算log概率,并计算DPO损失。

注意: 上述代码只是一个简化的示例,实际应用中需要根据具体情况进行调整。 特别是compute_logprobs 函数,需要根据模型输出的格式进行适配。

5. DPO的局限性

DPO虽然具有许多优点,但也存在一些局限性:

  • 对偏好数据质量的依赖性: DPO的性能高度依赖于人类偏好数据的质量。如果偏好数据存在噪声或偏差,则可能导致模型性能下降。
  • 探索能力有限: DPO是一种on-policy算法,其探索能力有限。它只能根据已有的偏好数据进行优化,难以发现新的、更好的策略。
  • 超参数敏感: DPO的性能对超参数(如β)比较敏感,需要进行仔细的调参。

6. DPO的变体与扩展

近年来,研究者们提出了许多DPO的变体和扩展,以克服其局限性:

  • IPO (Identity Preference Optimisation): 将参考模型直接设置为预训练模型,简化了DPO的实现。
  • KTO (Kahneman-Tversky Optimisation): 引入前景理论,更好地模拟人类的决策行为。
  • PRO (Pairwise Ranking Optimization): 使用pairwise ranking loss来优化模型,提高了训练的稳定性。

这些变体和扩展在不同的场景下可能具有更好的性能,值得进一步研究。

稳定性和简单性让DPO成为有力的工具

DPO通过直接优化模型,避免了训练奖励模型带来的误差和不稳定性,同时,其梯度稳定性以及KL散度约束也提高了训练的稳定性。尽管存在一些局限性,但DPO凭借其简洁性和稳定性,成为了一个非常有力的LLM对齐工具。

未来研究方向

未来的研究方向包括:如何提高DPO对偏好数据质量的鲁棒性,如何增强DPO的探索能力,以及如何自动调整DPO的超参数。 此外,将DPO与其他技术(如主动学习)结合,也有望进一步提高LLM的对齐效果。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注