基于策略梯度的人类反馈强化学习(RLHF)实现教程
前言
人类反馈强化学习(Reinforcement Learning from Human Feedback, RLHF)是近年来备受关注的技术,特别是在大语言模型(LLM)的训练中发挥了重要作用。本文将从最基础的角度出发,通过一个简单的MNIST数字识别任务,展示如何实现一个最小化的RLHF系统。
一、问题背景
传统的监督学习依赖于大量标注数据,但在某些场景下:
-
标注成本高昂
-
难以定义明确的损失函数
-
需要根据人类偏好进行优化
RLHF的核心思想是:让模型与环境(人类)交互,根据人类反馈的奖励信号来优化策略。这与传统的监督学习有本质区别:
-
监督学习:直接拟合标签,使用交叉熵等损失函数
-
RLHF:通过奖励信号,使用策略梯度方法优化期望奖励
二、数学原理
2.1 策略梯度基础
在强化学习中,我们有一个策略网络 ,表示在状态 下选择动作 的概率。我们的目标是最大化期望奖励:
其中 表示一个轨迹(trajectory), 是轨迹的总奖励。
策略梯度定理告诉我们,目标函数关于参数的梯度为:
2.2 REINFORCE算法
REINFORCE是最基础的策略梯度算法,其更新公式为:
其中:
-
是学习率
-
是策略的对数概率梯度
-
是奖励信号
2.3 优势函数(Advantage)
为了减少方差,我们通常使用优势函数 ,其中:
-
是动作价值函数
-
是状态价值函数(baseline)
在我们的简化实现中,使用 ,其中 baseline 可以设为0或历史奖励的平均值。
2.4 应用到分类任务
对于分类任务,我们可以将:
-
状态 :输入图像
-
动作 :预测的类别(0-9)
-
策略 :模型输出的softmax概率分布
-
奖励 :人类反馈(+1表示正确,-1表示错误)
损失函数为:
注意负号是因为我们要最大化奖励,而优化器通常是最小化损失。
三、代码实现
3.1 模型结构
class MNISTNet(nn.Module):
def __init__(self):
super(MNISTNet, self).__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = x.view(-1, 28 * 28)
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
这是一个简单的两层全连接网络,输入是28×28的MNIST图像,输出是10个类别的logits。
3.2 预训练阶段
首先使用监督学习预训练模型:
def main():
# ... 数据加载 ...
model = MNISTNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
train(model, train_loader, optimizer, criterion, device)
accuracy = test(model, test_loader, device)
print(f'Epoch {epoch+1}/{epochs}, Test Accuracy: {accuracy:.2f}%')
torch.save(model.state_dict(), 'checkpoint.pth')
这一步是标准的监督学习,为后续的RLHF提供初始策略。
3.3 RLHF核心实现
RLHF的核心在于策略梯度更新:
def rlhf_loop():
# 加载预训练模型
model = MNISTNet().to(device)
model.load_state_dict(torch.load('checkpoint.pth', map_location=device))
model.train()
optimizer = optim.Adam(model.parameters(), lr=0.0005)
while True:
# 1. 模型预测
output = model(data)
probs = torch.softmax(output, dim=1)
_, predicted = torch.max(output, 1)
predicted_num = predicted.item()
# 2. 计算策略的对数概率
dist = torch.distributions.Categorical(probs)
log_prob = dist.log_prob(predicted)
# 3. 显示图片,获取人类反馈
# ... 显示图片 ...
feedback = input('反馈 (y=正确/n=错误/q=退出): ')
# 4. 计算奖励和优势
reward = 1.0 if feedback in ['y', 'yes'] else -1.0
baseline = 0.0
advantage = reward - baseline
# 5. 策略梯度更新
optimizer.zero_grad()
policy_loss = -log_prob * advantage
policy_loss.backward()
optimizer.step()
关键点解析:
-
使用argmax而非采样:我们使用
torch.max选择最优动作,而不是从分布中采样。这避免了采样带来的随机性问题,同时保持了策略梯度的形式。 -
对数概率计算:
dist.log_prob(predicted)计算的是选择动作 的对数概率 。 -
优势函数:
advantage = reward - baseline,这里baseline设为0,但结构上已经准备好可以扩展为移动平均等更复杂的baseline。 -
损失函数:
policy_loss = -log_prob * advantage,负号是因为我们要最大化期望奖励,而优化器最小化损失。
四、实验结果分析
4.1 实验设置
-
预训练:使用10000个样本,训练5个epoch
-
RLHF:学习率0.0005,人类反馈奖励+1/-1
-
测试集:MNIST测试集(10000个样本)
4.2 结果观察
从实验结果可以看到:
-
预训练后准确率:约93.21%
-
RLHF后准确率:根据反馈质量在91-92%之间波动
为什么准确率可能下降?
这实际上反映了RLHF的一个重要特点:它优化的是人类反馈的奖励,而不是准确率。如果:
-
人类反馈有误(误判正确/错误)
-
反馈样本有限,无法覆盖所有情况
-
学习率过大导致过拟合
都会导致模型在测试集上的表现下降。
4.3 与监督学习的对比
| 方法 | 损失函数 | 更新方式 | 优点 | 缺点 |
|------|---------|---------|------|------|
| 监督学习 | 交叉熵 | 直接拟合标签 | 稳定、高效 | 需要大量标注 |
| RLHF | 策略梯度 | 基于奖励信号 | 灵活、可优化偏好 | 方差大、可能不稳定 |
五、关键设计决策
5.1 为什么使用argmax而非采样?
在最初的实现中,我们使用了采样:
action = dist.sample() # 从分布中采样
但这样会导致问题:即使模型预测正确,采样也可能选到错误动作,从而得到负奖励。使用argmax可以:
-
避免随机性带来的噪声
-
更直接地优化最优策略
-
保持策略梯度的数学形式
5.2 Baseline的作用
虽然当前实现中baseline=0,但引入baseline的概念很重要:
Baseline可以减少方差,因为:
-
如果 ,说明这个动作比平均好,应该增加其概率
-
如果 ,说明这个动作比平均差,应该减少其概率
在实际应用中,可以使用:
-
移动平均:
-
价值网络:训练一个价值函数 来估计baseline
5.3 学习率的选择
RLHF的学习率(0.0005)通常比监督学习(0.001)更小,因为:
-
策略梯度更新可能不稳定
-
人类反馈是稀疏的,需要更谨慎的更新
-
避免破坏预训练模型的知识
六、扩展方向
6.1 更复杂的奖励模型
当前实现使用简单的+1/-1奖励,可以扩展为:
-
连续奖励:根据置信度给出0-1之间的奖励
-
多级奖励:非常好(+2)、好(+1)、一般(0)、差(-1)
-
相对奖励:比较两个预测,给出相对偏好
6.2 批量更新
当前实现是逐个样本更新,可以改为:
-
收集一批反馈后批量更新
-
使用经验回放(Experience Replay)
-
减少更新频率,提高稳定性
6.3 价值函数baseline
实现一个价值网络来估计baseline:
class ValueNet(nn.Module):
def __init__(self):
# 类似MNISTNet的结构
pass
def forward(self, x):
# 输出标量值
return value
然后使用 来计算优势。
6.4 PPO等高级算法
可以进一步实现PPO(Proximal Policy Optimization)等更稳定的算法:
-
重要性采样
-
裁剪机制
-
多步更新
七、总结
本文实现了一个最小化的RLHF系统,展示了:
-
策略梯度的基本原理:通过奖励信号优化策略
-
REINFORCE算法的实现:最基础但有效的策略梯度方法
-
在分类任务中的应用:将RLHF思想应用到MNIST识别
核心思想:RLHF不是简单的”人类标注+监督学习”,而是通过策略梯度让模型学习最大化人类反馈的奖励。这使得模型可以:
-
优化难以用损失函数表达的目标
-
根据人类偏好进行个性化调整
-
在交互中持续改进
局限性:
-
方差较大,需要大量反馈
-
可能不稳定,需要仔细调参
-
人类反馈的质量直接影响效果
适用场景:
-
需要优化人类偏好的任务
-
难以定义明确损失函数的场景
-
需要个性化定制的应用
希望这个简单的实现能帮助读者理解RLHF的核心思想,为进一步的研究和应用打下基础。
参考文献
-
Williams, R. J. (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4), 229-256.
-
Schulman, J., et al. (2017). Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347.
-
Ouyang, L., et al. (2022). Training language models to follow instructions with human feedback. Advances in Neural Information Processing Systems, 35, 27730-27744.