在生产环境测试 Stable Diffusion 3.5 FP8 时,发现了一些违反直觉的现象。原本认为 FP8 量化会牺牲质量,实测结果却显示几乎无损,同时推理速度翻倍、显存占用减半。这背后的技术逻辑值得深入探讨。
FP8 量化:一个被误解的技术选择
为什么扩散模型'不怕'降精度?
通常认为量化是用精度换性能,但在 SD 3.5 FP8 的实际应用中,这一理解过于简单。
扩散模型的去噪过程本质上是迭代纠错的过程。在 50 步采样中,单步的小误差会被后续步骤自动修正。例如,FP32 到 FP16 肉眼基本看不出区别,而 FP16 到 FP8 的关键在于如何量化,而非精度本身。
真正的突破点是将'全局统一量化'改为'张量自适应量化'。不同层、不同数据采用不同的策略。
E4M3 和 E5M2:两种量化格式的分工
SD 3.5 采用了混合策略:
- E4M3 (4-bit 指数 + 3-bit 尾数):用于权重和激活值,保证数值精度。
- E5M2 (5-bit 指数 + 2-bit 尾数):用于梯度累积,需要更大的动态范围。
示例代码展示了简化版的自适应量化逻辑:
def adaptive_fp8_quantize(tensor, stage):
if stage == "attention":
# 注意力矩阵对异常值敏感
pass
elif stage == "conv":
# 卷积层可以更激进
pass
elif stage == "time_embed":
# 时间嵌入保持 FP16
pass
核心思想是哪里需要精度就给精度,哪里需要范围就给范围,避免一刀切。
三个容易被忽略的技术细节
MMDiT 架构带来的量化挑战
SD 3.5 将文本和图像放在同一个 Transformer 中处理(MMDiT 架构)。由于文本 token 是离散的,图像 latent 是连续的,数值分布完全不同。
解决办法是为文本和图像分别计算缩放因子:
class MMDiTBlockFP8(nn.Module):
def forward(self, img_latent, txt_latent):
# 文本和图像分别缩放
img_scale = compute_scale(img_latent, percentile=99.9)
txt_scale = compute_scale(txt_latent, percentile=99.99)
img_fp8 = quantize(img_latent / img_scale, "E4M3")
txt_fp8 = quantize(txt_latent / txt_scale, )
.attention(img_fp8 * img_scale, txt_fp8 * txt_scale)


