统一FP8:超越混合精度,实现稳定加速的MoE RL训练

我们实现了RL中全FP8采样和训练流程。实验显示,对于MoE模型,使用BF16训练结合FP8 rollout时,模型越大,训练-推理不一致性越严重。相比之下,统一FP8用于训练和rollout,能有效消除量化误差导致的训练-推理不一致,提升RL训练的速度与稳定性。本文详述FP8硬件基础、格式选择、尺度计算及量化策略,支持Qwen3-4B和Qwen3-30B-A3B的miles框架即插即用,由InfiXAI、Ant Group AQ、SGLang RL和Miles团队联合完成。(128字)

TL;DR:我们实现了RL中全FP8采样和训练流程。实验表明,对于MoE模型,使用BF16训练结合FP8 rollout时,模型规模越大,训练-推理不一致性越严重。统一FP8用于训练和rollout,能有效消除量化误差引发的训练-推理不一致,提升RL训练速度与稳定性。

SGLang RL团队与Miles社区在RL训练稳定性和加速方面进行了有趣探索,包括对齐SGLang和FSDP后端以实现严格零KL散度,以及Speculative Decoding结合在线SFT用于draft模型。

在此基础上,我们分享一项平衡稳定性和性能的新进展——端到端FP8 RL训练与采样pipeline。miles框架已完全支持Qwen3-4B和Qwen3-30B-A3B的FP8 RL训练(详见),开箱即用。

本工作由InfiXAI Team、Ant Group AQ Team、SGLang RL Team、Miles Team联合完成。特别感谢Verda Cloud提供算力赞助,以及NVIDIA在Transformer Engine (TE)上的技术支持。

FP8训练的硬件基础

Tensor Cores与低精度支持

低精度计算是硬件-软件协同设计的瑰宝。其硬件基础是Tensor Cores,一种专为大规模矩阵乘累加设计的GPU硬件加速单元,这是深度学习核心计算。相较传统CUDA核心,Tensor Cores对低精度格式(如FP16、BF16、FP8)提供更高吞吐量。其演进从基本FMA指令和DP4A矢量化起步,Volta架构首次引入专用Tensor Cores,此后Ampere、Hopper和Blackwell持续推进:

  • 规模扩展:单操作处理更大矩阵,提升计算-内存比。
  • 精度降低:持续支持FP/BF16、FP8等更低精度格式。
ArchFP64F16INT8INT4FP8MXFP
Volta✅ FP16
Turing✅ FP16
Ampere✅ FP16/BF16
Hopper✅ FP16/BF16
(累加仅支持FP22)
Blackwell✅ FP16/BF16✅ MXFP(8/6/4)
NVFP4
Blackwell Ultra✅ (reduced FLOPs)✅ FP16/BF16✅ (reduced FLOPS)✅ MXFP(8/6/4)
NVFP4

图源:zartbotSemiAnalysis

这一趋势使低精度存储与计算更具吸引力。具体优势包括:

  1. 显著降低内存占用:FP8理论上将模型权重和激活内存减半,缓解VRAM压力。
  2. 理论2×计算吞吐:H100 SXM上,FP8 Tensor Cores达1979 TFLOPS,是BF16(989 TFLOPS)的两倍。
  3. 缓解内存带宽瓶颈:数据更紧凑,减少HBM到计算核心传输。

FP8格式

FP8是一种8位浮点格式,相较FP32(32位)和FP16/BF16(16位),将存储与传输成本降至1/4或1/2,缓解VRAM与带宽瓶颈,提升训练推理性能。目前两大格式:

  • E4M3:4位指数+3位尾数。动态范围小但精度高。
  • E5M2:5位指数+2位尾数。动态范围大但精度低。

FP8 E4M3 vs E5M2

图源:OCP whitepaper

此设计在最大化硬件吞吐的同时,保持足够数值范围与精度。

FP8尺度选择

维度FP32 Scale(全精度缩放因子)E8M0 Scale(仅指数缩放)
格式定义FP32 (IEEE 754单精度浮点)E8M0 (8位指数,0位尾数)
数值特性任意精度实数表示仅支持2的幂,如1、2、0.5;无法表示1.5等
核心思想高精度管理缩放因子,确保训练数值稳定将缩放因子纳入低精度,利用位运算高效
主要优势1. 高精度、稳定训练:精确捕捉动态范围,减少量化误差,防散度。
2. 广泛支持:NVIDIA Transformer Engine默认,生态成熟
1. 极硬件友好:缩放为简单位移,快速低能耗。
2. 统一pipeline:全8位运行,简化硬件设计
主要劣势1. 存储开销:每量化张量需额外FP32尺度,耗VRAM。
2. 计算开销:尺度计算与转换需FP32
1. 精度损失风险:强制四舍五入到2幂引入噪声,反向传播累积致散度。
2. 动态范围分辨率有限:难精细适应复杂张量分布
总结行业最常见、安全方案牺牲精度换极致硬件效率

综合评估后,我们选择FP32作为训练尺度精度。原因:

  1. 精度对齐与训练稳定:FP32尺度精细捕捉张量动态范围,使FP8训练损失曲线接近BF16基线。
  2. 与推理生态一致:主流推理模型也用FP32量化尺度。
  3. 实际硬件收益
    - Hopper (H100/H800):支持FP8 Tensor Cores,但无E8M0专用单元。
    - Blackwell (B100/B200):引入MXFP8,支持E8M0类块级缩放(arXiv:2506.08027)。

因此,在当前H系列集群下,强制E8M0非但无明显加速,还引入软件模拟开销与精度风险。

FP8量化

常见量化策略包括per-tensorper-blockper-token。无论粒度,量化通常分两步:

FP8 quantization flow

图源:InfiR2: A Comprehensive FP8 Training Recipe for Reasoning-Enhanced Language Models

步骤1:计算缩放因子 S

取张量(或块)最大绝对值 max|X|,除以FP8最大可表示值 V_max

S = max|X| / V_max

步骤2:计算量化值 Q

用S将原张量X各元素x除以S并四舍五入:

Q(x) = round(x / S)

因FP8精度低于FP16/BF16,实际需权衡稳定与效率,前向/后向常采用不同策略与粒度:

  • Activations:通常per-token量化。激活常含显著异常值,细粒度可局部化异常影响,保留整体精度。
  • Weights:通常per-block量化。收敛后权重分布平滑(近高斯),少异常值,但对量化误差敏感。块状(如block_size × block_size)平衡精度、硬件优化、效率与内存节省。
  • Gradients:通常per-token量化。