Direct Preference Optimization (DPO)
什么是 DPO
DPO 是一种用于大语言模型(LLM)对齐的新方法,它的主要目的是让 AI 模型的输出更好地符合人类偏好。 它是 RLHF(基于人类反馈的强化学习)的一种替代方案,该方法于 2023 年提出, 显著简化了语言模型与人类偏好的对齐过程
在 DPO 之前提出之前
在 DPO 提出之前,RLHF 的训练过程通常分为两个阶段:首先是监督微调(SFT),然后是强化学习(RL)。
第一步是训练 reward model。
训练数据是同一个 prompt 的 2 个回答,让人或 GPT4 标注哪个回答更好,reward model 会去优化如下的 loss
其中 就是 reward model 用来给回答打分。 是训练数据集, 是 prompt, 和 分别是好的回答和不好的回答。也就是说,要尽可能让好的回答的得分比不好的回答高,拉大他们之间的差别。
第二步是用 RL 算法来提升模型的得分。 。
loss 如下:
其中是我们在训练的 LLM,是训练的初始值。这个 loss 意思是希望 LLM 输出的回答的评分能尽可能高,同时不要偏离太多,保证它还能正常做回答,不要训成一个评分很高但是回答乱码的东西。
DPO 的贡献
DPO 的公式推导
如果我们归一化一下分母,即取
TIP
注意这里是关于的函数与无关,这里很重要因为我们的 loss 是不需要计算关于输入的损失的 所以下面变换时可以忽略
也就可以构造出一个新的概率分布:
我们将归一化后的新概率分布替换原来的分子那么我们可以将优化目标重新表示为:
TIP
这里有一个非常重要的 point 我们将 表达式中的 巧妙的转换成其他新 loss 的 KL 散度的形式,
The key point is
我们不需要再面对公式中 sample 操作,这样我们可以进行梯度下降 backprop 同时我们将 loss 转换成了和的 KL 散度。这非常的 amazing 啊! 我们目前只需要最小化新的概率分布和 就完成了之前的复杂的 RL 算法
现在的问题是如何 get的概率分布,我们可以通过的定义式 get 到的关系,那么我们做一些简单的变换可以得到
这时候我们将的表达式直接带入到上面的 reward model 的 loss function
就会得到下面的表达式
OK 现在我们只差最后一步,在上面推导出的表达式 中我们可以发现我们实现需要的是最小化和的散度, 那么当我们令散度等于 0 时实际我么就是在优化
所以我们不如直接将 loss 直接写成这样
这非常的 amazing 啊!我们竟然可以通过这个 trick 将两阶段的 train 简化成一阶段,同时还不用优化了 sampling 的操作让损失可以直接 backprop
实际应用
可以用于大语言模型的微调和对齐
在实践中比传统 RLHF 方法更容易实现和部署
已经在多个开源项目中得到应用和验证
结论
通过以上的推导,我们可以看到 DPO 方法通过将 RLHF 的两阶段训练简化为单阶段训练,大大简化了训练过程。 同时使用 KL 散度的形式来优化训练过程使得模型可以直接 backprop 而不需要用 trick 去解决 sample 的问题并使得训练内存需求得以缓解。