基于策略梯度的人类反馈强化学习(RLHF)实现教程

前言

人类反馈强化学习(Reinforcement Learning from Human Feedback, RLHF)是近年来备受关注的技术,特别是在大语言模型(LLM)的训练中发挥了重要作用。本文将从最基础的角度出发,通过一个简单的MNIST数字识别任务,展示如何实现一个最小化的RLHF系统。

一、问题背景

传统的监督学习依赖于大量标注数据,但在某些场景下:

  1. 标注成本高昂

  2. 难以定义明确的损失函数

  3. 需要根据人类偏好进行优化

RLHF的核心思想是:让模型与环境(人类)交互,根据人类反馈的奖励信号来优化策略。这与传统的监督学习有本质区别:

  • 监督学习:直接拟合标签,使用交叉熵等损失函数

  • RLHF:通过奖励信号,使用策略梯度方法优化期望奖励

二、数学原理

2.1 策略梯度基础

在强化学习中,我们有一个策略网络 ,表示在状态 下选择动作 的概率。我们的目标是最大化期望奖励:

其中 表示一个轨迹(trajectory), 是轨迹的总奖励。

策略梯度定理告诉我们,目标函数关于参数的梯度为:

2.2 REINFORCE算法

REINFORCE是最基础的策略梯度算法,其更新公式为:

其中:

  • 是学习率

  • 是策略的对数概率梯度

  • 是奖励信号

2.3 优势函数(Advantage)

为了减少方差,我们通常使用优势函数 ,其中:

  • 是动作价值函数

  • 是状态价值函数(baseline)

在我们的简化实现中,使用 ,其中 baseline 可以设为0或历史奖励的平均值。

2.4 应用到分类任务

对于分类任务,我们可以将:

  • 状态 :输入图像

  • 动作 :预测的类别(0-9)

  • 策略 :模型输出的softmax概率分布

  • 奖励 :人类反馈(+1表示正确,-1表示错误)

损失函数为:

注意负号是因为我们要最大化奖励,而优化器通常是最小化损失。

三、代码实现

3.1 模型结构

 
class MNISTNet(nn.Module):
 
    def __init__(self):
 
        super(MNISTNet, self).__init__()
 
        self.fc1 = nn.Linear(28 * 28, 128)
 
        self.fc2 = nn.Linear(128, 10)
 
        self.relu = nn.ReLU()
 
    def forward(self, x):
 
        x = x.view(-1, 28 * 28)
 
        x = self.relu(self.fc1(x))
 
        x = self.fc2(x)
 
        return x
 

这是一个简单的两层全连接网络,输入是28×28的MNIST图像,输出是10个类别的logits。

3.2 预训练阶段

首先使用监督学习预训练模型:

 
def main():
 
    # ... 数据加载 ...
 
    model = MNISTNet().to(device)
 
    optimizer = optim.Adam(model.parameters(), lr=0.001)
 
    criterion = nn.CrossEntropyLoss()
 
    for epoch in range(epochs):
 
        train(model, train_loader, optimizer, criterion, device)
 
        accuracy = test(model, test_loader, device)
 
        print(f'Epoch {epoch+1}/{epochs}, Test Accuracy: {accuracy:.2f}%')
 
    torch.save(model.state_dict(), 'checkpoint.pth')
 

这一步是标准的监督学习,为后续的RLHF提供初始策略。

3.3 RLHF核心实现

RLHF的核心在于策略梯度更新:

 
def rlhf_loop():
 
    # 加载预训练模型
 
    model = MNISTNet().to(device)
 
    model.load_state_dict(torch.load('checkpoint.pth', map_location=device))
 
    model.train()
 
    optimizer = optim.Adam(model.parameters(), lr=0.0005)
 
    while True:
 
        # 1. 模型预测
 
        output = model(data)
 
        probs = torch.softmax(output, dim=1)
 
        _, predicted = torch.max(output, 1)
 
        predicted_num = predicted.item()
 
        # 2. 计算策略的对数概率
 
        dist = torch.distributions.Categorical(probs)
 
        log_prob = dist.log_prob(predicted)
 
        # 3. 显示图片,获取人类反馈
 
        # ... 显示图片 ...
 
        feedback = input('反馈 (y=正确/n=错误/q=退出): ')
 
        # 4. 计算奖励和优势
 
        reward = 1.0 if feedback in ['y', 'yes'] else -1.0
 
        baseline = 0.0
 
        advantage = reward - baseline
 
        # 5. 策略梯度更新
 
        optimizer.zero_grad()
 
        policy_loss = -log_prob * advantage
 
        policy_loss.backward()
 
        optimizer.step()
 

关键点解析:

  1. 使用argmax而非采样:我们使用 torch.max 选择最优动作,而不是从分布中采样。这避免了采样带来的随机性问题,同时保持了策略梯度的形式。

  2. 对数概率计算dist.log_prob(predicted) 计算的是选择动作 的对数概率

  3. 优势函数advantage = reward - baseline,这里baseline设为0,但结构上已经准备好可以扩展为移动平均等更复杂的baseline。

  4. 损失函数policy_loss = -log_prob * advantage,负号是因为我们要最大化期望奖励,而优化器最小化损失。

四、实验结果分析

4.1 实验设置

  • 预训练:使用10000个样本,训练5个epoch

  • RLHF:学习率0.0005,人类反馈奖励+1/-1

  • 测试集:MNIST测试集(10000个样本)

4.2 结果观察

从实验结果可以看到:

  1. 预训练后准确率:约93.21%

  2. RLHF后准确率:根据反馈质量在91-92%之间波动

为什么准确率可能下降?

这实际上反映了RLHF的一个重要特点:它优化的是人类反馈的奖励,而不是准确率。如果:

  • 人类反馈有误(误判正确/错误)

  • 反馈样本有限,无法覆盖所有情况

  • 学习率过大导致过拟合

都会导致模型在测试集上的表现下降。

4.3 与监督学习的对比

| 方法 | 损失函数 | 更新方式 | 优点 | 缺点 |

|------|---------|---------|------|------|

| 监督学习 | 交叉熵 | 直接拟合标签 | 稳定、高效 | 需要大量标注 |

| RLHF | 策略梯度 | 基于奖励信号 | 灵活、可优化偏好 | 方差大、可能不稳定 |

五、关键设计决策

5.1 为什么使用argmax而非采样?

在最初的实现中,我们使用了采样:

 
action = dist.sample()  # 从分布中采样
 

但这样会导致问题:即使模型预测正确,采样也可能选到错误动作,从而得到负奖励。使用argmax可以:

  • 避免随机性带来的噪声

  • 更直接地优化最优策略

  • 保持策略梯度的数学形式

5.2 Baseline的作用

虽然当前实现中baseline=0,但引入baseline的概念很重要:

Baseline可以减少方差,因为:

  • 如果 ,说明这个动作比平均好,应该增加其概率

  • 如果 ,说明这个动作比平均差,应该减少其概率

在实际应用中,可以使用:

  • 移动平均:

  • 价值网络:训练一个价值函数 来估计baseline

5.3 学习率的选择

RLHF的学习率(0.0005)通常比监督学习(0.001)更小,因为:

  • 策略梯度更新可能不稳定

  • 人类反馈是稀疏的,需要更谨慎的更新

  • 避免破坏预训练模型的知识

六、扩展方向

6.1 更复杂的奖励模型

当前实现使用简单的+1/-1奖励,可以扩展为:

  • 连续奖励:根据置信度给出0-1之间的奖励

  • 多级奖励:非常好(+2)、好(+1)、一般(0)、差(-1)

  • 相对奖励:比较两个预测,给出相对偏好

6.2 批量更新

当前实现是逐个样本更新,可以改为:

  • 收集一批反馈后批量更新

  • 使用经验回放(Experience Replay)

  • 减少更新频率,提高稳定性

6.3 价值函数baseline

实现一个价值网络来估计baseline:

 
class ValueNet(nn.Module):
 
    def __init__(self):
 
        # 类似MNISTNet的结构
 
        pass
 
    def forward(self, x):
 
        # 输出标量值
 
        return value
 

然后使用 来计算优势。

6.4 PPO等高级算法

可以进一步实现PPO(Proximal Policy Optimization)等更稳定的算法:

  • 重要性采样

  • 裁剪机制

  • 多步更新

七、总结

本文实现了一个最小化的RLHF系统,展示了:

  1. 策略梯度的基本原理:通过奖励信号优化策略

  2. REINFORCE算法的实现:最基础但有效的策略梯度方法

  3. 在分类任务中的应用:将RLHF思想应用到MNIST识别

核心思想:RLHF不是简单的”人类标注+监督学习”,而是通过策略梯度让模型学习最大化人类反馈的奖励。这使得模型可以:

  • 优化难以用损失函数表达的目标

  • 根据人类偏好进行个性化调整

  • 在交互中持续改进

局限性

  • 方差较大,需要大量反馈

  • 可能不稳定,需要仔细调参

  • 人类反馈的质量直接影响效果

适用场景

  • 需要优化人类偏好的任务

  • 难以定义明确损失函数的场景

  • 需要个性化定制的应用

希望这个简单的实现能帮助读者理解RLHF的核心思想,为进一步的研究和应用打下基础。

参考文献

  1. Williams, R. J. (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4), 229-256.

  2. Schulman, J., et al. (2017). Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347.

  3. Ouyang, L., et al. (2022). Training language models to follow instructions with human feedback. Advances in Neural Information Processing Systems, 35, 27730-27744.