Variance-Based RL Sample Selection

Nikola Balic (@nibzard)· emerging

问题

并非所有训练样本对强化学习的价值都是均等的:

  • 零方差样本:模型每次得到的分数完全一致(要么始终正确,要么始终错误)→ 无学习信号
  • 计算资源浪费:在模型不存在不确定性的样本上训练,会浪费成本高昂的RL探索资源
  • 数据利用率低下:在训练预算有限的情况下,需要最大化每个样本的学习价值
  • 训练潜力不明:难以判断数据集是否能支撑有效的RL训练

Theo在FinQA基准测试上开展基线评估时发现,约85%的样本存在零方差(模型要么始终答对,要么始终答错),这意味着仅约15%的样本能真正为学习过程贡献价值。

方案

为每个样本执行多次基线评估以识别差异,优先选择高差异样本用于训练。

差异图方法论:

  1. 基线评估:在每个样本上运行基础模型3-5次
  2. 可视化差异:绘制结果以识别存在差异的样本
  3. 样本分类
    • 始终正确(差异=0):模型已掌握这类样本
    • 始终错误(差异=0):模型无法学习这类样本(难度过高或需要其他方法)
    • 有时正确(差异>0):RL的首选训练样本
  4. 聚焦训练:优先或仅使用高差异样本

理解差异图:

分数
1.0 ●━━━━━━━●━━━━━━●━━━━━━━●    ← 始终正确(无需学习)
    ┃       ┃      ┃       ┃
0.5 ┃   ●━━━●━━━●  ┃   ●━━━●━━━●    ← 高差异(此处可学习!)
    ┃   ┃   ▼       ┃   ┃
0.0 ●━━━●━━━━━━●━━━●━━━━━━●━━━━━━●    ← 始终错误(无需学习)
    └───┴───┴───┴───┴───┴───┴───→
        样本索引

    ● = 最佳分数(图表中显示为红色叉号)
    ━ = 平均分数(蓝色粗线)
    ▼ = 差异范围(蓝色细线)

实现代码:

import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

class VarianceAnalyzer:
    """
    分析基线差异以识别高价值训练样本
    """
    def __init__(self, agent, dataset, n_runs=3):
        self.agent = agent
        self.dataset = dataset
        self.n_runs = n_runs
        self.results = defaultdict(list)

    def run_baseline_evals(self):
        """
        在每个样本上多次运行智能体(agent)
        """
        print(f"为每个样本执行{self.n_runs}次评估...")

        for sample_idx, sample in enumerate(self.dataset):
            for run_idx in range(self.n_runs):
                score = self.agent.evaluate(sample)
                self.results[sample_idx].append(score)

            if sample_idx % 10 == 0:
                print(f"已完成 {sample_idx}/{len(self.dataset)} 个样本")

        return self.results

    def compute_variance_metrics(self):
        """
        计算每个样本的差异统计指标
        """
        metrics = []

        for sample_idx in sorted(self.results.keys()):
            scores = self.results[sample_idx]

            metrics.append({
                'sample_idx': sample_idx,
                'mean_score': np.mean(scores),
                'best_score': np.max(scores),
                'worst_score': np.min(scores),
                'variance': np.var(scores),
                'std_dev': np.std(scores),
                'scores': scores
            })

        return metrics

    def plot_variance(self, metrics, title="基线差异分析"):
        """
        创建差异可视化图(类似Theo的图表样式)
        """
        sample_indices = [m['sample_idx'] for m in metrics]
        mean_scores = [m['mean_score'] for m in metrics]
        best_scores = [m['best_score'] for m in metrics]
        std_devs = [m['std_dev'] for m in metrics]

        plt.figure(figsize=(14, 6))

        # 绘制带误差棒(代表差异)的平均分数
        plt.errorbar(
            sample_indices,
            mean_scores,
            yerr=std_devs,
            fmt='o',
            linewidth=2,
            markersize=3,
            label='平均值 ± 标准差',
            color='cornflowerblue',
            elinewidth=1
        )

        # 叠加最佳分数点
        plt.scatter(
            sample_indices,
            best_scores,
            marker='x',
            s=50,
            color='red',
            label='最佳分数',
            alpha=0.7
        )

        plt.xlabel('样本索引')
        plt.ylabel('分数')
        plt.title(title)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        return plt

    def identify_high_variance_samples(self, metrics, variance_threshold=0.01):
        """
        筛选具有显著差异的样本
        """
        high_variance = [
            m for m in metrics
            if m['variance'] > variance_threshold
            and 0 < m['mean_score'] < 1.0  # 排除始终正确或始终错误的样本
        ]

        print(f"\n差异分析结果:")
        print(f"  总样本数: {len(metrics)}")
        print(f"  高差异样本数: {len(high_variance)} "
              f"({100*len(high_variance)/len(metrics):.1f}%)")
        print(f"  始终正确的样本数: {sum(1 for m in metrics if m['best_score'] == 1.0 and m['variance'] == 0)}")
        print(f"  始终错误的样本数: {sum(1 for m in metrics if m['best_score'] == 0.0)}")

        return high_variance

    def compute_improvement_potential(self, metrics):
        """
        计算如果模型始终能达到N次运行中的最佳性能,性能可提升的幅度
        """
        current_avg = np.mean([m['mean_score'] for m in metrics])
        best_of_n_avg = np.mean([m['best_score'] for m in metrics])

        potential_gain = best_of_n_avg - current_avg

        print(f"\n性能提升潜力:")
        print(f"  当前平均分数: {current_avg:.3f}")
        print(f"  {self.n_runs}次运行最佳成绩的平均分: {best_of_n_avg:.3f}")
        print(f"  潜在提升幅度: {potential_gain:.3f} "
              f"(相对提升 {100*potential_gain/current_avg:.1f}%)")

        return {
            'current': current_avg,
            'best_of_n': best_of_n_avg,
            'potential_gain': potential_gain
        }


# 使用示例
analyzer = VarianceAnalyzer(
    agent=my_agent,
    dataset=validation_set,
    n_runs=3
)

# 执行基线评估
results = analyzer.run_baseline_evals()

# 分析差异指标
metrics = analyzer.compute_variance_metrics()

# 生成可视化图
analyzer.plot_variance(metrics)

# 识别高价值样本
high_var_samples = analyzer.identify_high_variance_samples(metrics)

# 计算性能提升潜力
potential = analyzer.compute_improvement_potential(metrics)

# 使用高差异样本进行

如何使用

步骤1:基线评估(训练前)

在训练集与验证集的每个样本上,运行基础模型3-5次:

# 每个样本运行多次
for sample in dataset:
    for run in range(3):
        score = agent.evaluate(sample)
        record_score(sample.id, score)

步骤2:生成方差图

通过可视化分析数据特征:

  • X轴:样本索引
  • Y轴:分数(0-1)
  • 红色叉号:多次运行得到的最佳分数
  • 蓝色条形:平均分数(粗线)与方差(细线)

步骤3:解读结果

适合RL训练的正向指标:

  • 高方差样本占比15-30%:具备充足的学习机会
  • Best-of-N 远高于均值:模型通过RL具备明确的提升潜力
  • 方差分布于整个数据集:未集中在少数样本中

警告信号:

  • 高方差占比<10%:数据集可能过易或过难
  • Best-of-N ≈ 均值:模型一致性极强(提升潜力极低)
  • 方差全部集中在尾部:多数样本无法提供有效学习信号

步骤4:设置计算乘数

计算乘数用于控制训练过程中的探索行为:

  • 低方差(10-15%):使用2-4的计算乘数,以提升探索度
  • 中等方差(15-30%):使用1-2的计算乘数
  • 高方差(>30%):计算乘数设为1即可满足需求

步骤5:训练期间监控

跟踪方差的演变趋势:

  • 训练初期:随着模型学习,方差应逐步降低
  • 平台期:模型探索新策略时,方差可能出现回升
  • 收敛期:方差应在较低水平趋于稳定

权衡

优点

  • 数据高效性:将训练聚焦于真正对学习有贡献的样本
  • 可预测性:在开展高成本训练前预估改进潜力
  • 诊断性:判断你的任务是否适配RL
  • 超参数指导:为计算倍率与训练时长的决策提供参考依据

缺点

  • 前期成本高:训练前需完成3-5倍于基线水平的评估工作
  • 小样本局限:当样本量较少(<50)时,方差估计可能存在噪声
  • 无法确保成功:高方差是必要条件,但并非充分条件
  • 方差动态变化:训练过程中方差会发生改变,初始分析的结论可能不再成立

参考文献

关键词

本文献汇总了OpenAI Build Hour系列中与Agent强化微调(RFT)相关的活动资源链接,包含2025年11月的差异分析演示、Prashant主讲的往期专场,同时关联Agent强化微调、推理时缩放等技术模式。

  • 相关技术模式:Agent强化微调、推理时缩放

来源摘要

正在获取来源并生成中文摘要…

来源: https://youtu.be/1s_7RMG4O4U

← 返回社区