最近在训练一个图像篡改检测网络时,为了提升模型的鲁棒性,我对数据集进行了随机 JPEG 压缩作为预处理手段。本以为这只是一个常规的数据增强操作,没想到却引发了一场艰难的 Debug 之旅——模型训练过程中 Loss 突然变成了 NaN。
经过一系列常规排查无果后,最终锁定了混合精度训练与数据分布之间的问题。在此记录下排查思路和解决方法。
训练设置:我使用MMSeg框架的自动混合精度AmpOptimWrapper。
2. 曲折的排查过程(试错)
遇到 Loss Nan,我首先按照常规经验进行了一系列排查,但均未解决问题:
- 梯度裁剪:怀疑梯度爆炸,加入了
clip_grad_norm,无效。 - 降低学习率:怀疑步长过大导致跳出最优解,将 LR 降低了 10 倍甚至更多,无效。
- 数据清洗:怀疑数据集中存在损坏的图片导致读取错误,编写脚本扫描了所有数据,未发现异常。
- 检查网络结构:检查了自定义层是否有除以零等逻辑错误,未发现异常。
3. 根因分析:FP16 与数据分布的冲突
在排除了上述常规原因后,我又在github上面查找相关问题解决办法,在一条评论中发现有人说改用fp32精度,尝试过后确实没有报错。我又查找了为什么fp16为什么会报错的原因:
- JPEG 压缩的副作用:高强度的随机 JPEG 压缩会在图像中引入复杂的压缩伪影(Artifacts)。这导致输入数据的分布变得极其不规律,可能产生某些极端的像素值或特征值。
- Attention 机制的数值溢出:在 Transformer 或类似的 Attention 模块计算中(通常包含
Softmax(Q @ K^T / scale)),如果输入特征的数值差异过大,点积后的结果会非常大。 - FP16 的局限性:
- FP16(半精度浮点数)的最大表示范围仅为 65504。
- 当 Attention 中的数值或中间梯度超过这个范围(Overflow)时,在 FP16 下就会直接变成
inf(无穷大)。 - 随后的计算(如
inf * 0或inf - inf)就会导致NaN的产生,并迅速传播到整个网络。
4. 解决方案
针对精度不够导致溢出的问题,主要有两种解决方案:
方案一:切换至 FP32
最直接的方法是关闭混合精度训练,全程使用 FP32 (float32),在我的框架中将AmpOptimWrapper改为OptimWrapper。
- 优点:数值稳定性极高,几乎不会出现溢出问题(。
- 缺点:显存占用显著增加。
方案二:使用 BF16
如果显卡支持 BF16 ,可以使用。
- 原理:BF16 牺牲了精度(尾数位)来换取与 FP32 相同的指数位。这意味着 BF16 拥有和 FP32 一样宽广的动态范围,不容易溢出。
- 优点:解决了 FP16 易溢出的问题,同时保持了较低的显存占用和较快的计算速度。