万字综述:全面梳理 FP8 训练和推理技术 -- 附录

万字综述:全面梳理 FP8 训练和推理技术 -- 附录

万字综述:全面梳理 FP8 训练和推理技术 -- 附录

原创 AI闲谈  2024年07月21日 20:02 北京

一、背景

在上一篇文章()中我们通过几篇论文具体介绍了 FP8 的发展历程以及在 AI 模型训练和推理中的应用。然而由于篇幅的原因,部分内容并没有具体展开,这篇文章中我们对其补充,并结合代码来介绍。

二、FP8-LM:FP8 梯度和 AllReduce 通信

我们在介绍  [2310.18313] FP8-LM: Training FP8 Large Language Models 时提到其有 FP8 通信,FP8 优化器,以及 FP8 分布式并行训练 3 个方面的优化,但没有具体介绍 FP8 通信是怎么实现的(这个部分比较晦涩),这里进行补充。

如果有 N 个 GPU 要进行梯度聚合,直接使用 FP8 梯度进行梯度聚合会导致精度降低。

如下图所示为 pre-scaling 方案,在求和之前先分别除以 N,此时容易出现 Underflow 的问题:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

如下图所示为 post-scaling,其主要区别是在求和之后再除以 N,此时容易出现 Overflow 的问题:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

作者提出了相应的优化方案,假设有 4 个 GPU 要进行梯度的聚合,分别有 FP16 的梯度:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

假设最终聚合后的 FP16 梯度为:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

每一个 FP16 的 Tensor 要转换为 FP8,都是由(FP8 的 Tensor,Scale 值)共同表示。上述 FP16 的 Tensor 对应的 FP8 表示为:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

如果想要直接使用 FP8 进行 AllReduce,则需要有一个全局的 Scale 值,否则 Reduce 就不等价了。对应的全局 scale 变量如下所示,其中使用 min 可以避免 Overflow:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

复原后的 FP8 梯度如下所示(等价于 FP16 梯度使用相同的 Scale):

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

通过 AllReduce 聚合后的 FP8 梯度可以表示为:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

所以,相当于 ge 的 FP8 表示如下所示,这里有个 Trick,聚合后的梯度并没有除以 N,而是让 Scale 值乘以 N:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

作者论文中为什么会介绍 μ 值(Auto Scaling)呢?应该是想要控制

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

都尽量在 FP8 的范围内(PS:开源代码里并未实现),比如:

如果发现

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

中有超过 0.001% 超过了 FP8 表示的最大值(PS:这里是因为 Delayed Scaling 导致?如果每个 Tensor 都使用 Just-in-time Scaling 是不是就都不会超过?),则下一次迭代的时候让

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

都变为原来的 1/2,降低 Overflow 的风险。

如果在接下来的 1000 次迭代中都没有出现超过 0.001% 的情况,则让

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

的指数增加 2,降低 Underflow 的风险。

三、FP8 梯度聚合代码实现

上述对应的代码位于 :msamp/nn/distributed.py#L124-L193。如下图所示:

第一步:collect 相应的 Gradient,并初始化 Meta(维护 Scale,amax 相关信息):

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

第二步:分别求每个 Gradient 的 amax(绝对值的最大值),然后 AllReuce 操作获得全局的最大值,这里求 max 是因为 Scale 一般为 FP8 可表示的最大值 / amax,也就等价于求 Scale 的 min:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

第三步:根据全局最大 amax,FP8 可表示的最大值等计算全局 Scale 值,其中的 world_size 也就是 N:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

需要说明的是,这里本来应该是 Auto Scaling,也就是对应上述 μ 值的部分,然而作者实际上并没有集成 Auto Scaling,而是使用了经验值 1/sqrt(N),以缓解 Underflow 和 Overflow。可以参考:https://github.com/Azure/MS-AMP/issues/117:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

第四步:将 FP16 的 Tensor 转换为 FP8 的 Tensor:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

如上图所示存在 Gradient 除以 N 的操作,然而,因为 Gradient 为 ScalingTensor,所以实际除的时候是操作的 Scale 值,如下图所示:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

第五步:对 FP8 Gradient 进行 AllReduce 操作:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

四、FP8 训练和推理过程

Transformer 模型中最主要的操作就是矩阵乘,也就是 Linear Layer,如下图所示(来自 [2309.17224] Training and inference of large language models using 8-bit floating point)为一个 Linear 操作的伪代码,其核心思路就是在 FP8 矩阵乘之前需要转换的 Tensor 转换为 FP8 类型,如下图红框所示;然后在矩阵乘之后 Unscale 回 FP16,如下图蓝框所示:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

而在推理阶段只用离线的的对 Weight 转换一次,Forward 的时候只需对 x 进行相应的 Scale 操作:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

五、Scaling 实现方式

在之前的文章中我们介绍过 Scaling 的实现方式,这里在简单概括一下:

Static Scaling:提前离线计算好每个 Tensor 的 Scale,然后一直不变。为了保证精度,这种方式通常用于推理阶段的 Weight,如下图所示:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

Dynamic Scaling:每次都实时计算每个 Tensor 的 Scale,好处是比较精确,不足是这些计算全部是同步的,如下图所示:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

Delayed Scaling:会保存一些之前的多个 Scale 值,计算当前 Tensor 时根据以前的多个 Scale 预估当前 Scale,然后进行 Scaling 操作,同时异步的计算当前的 Scale 值,但是其实现也比较复杂,无状态变为有状态,如下图所示:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

在 Pytorch 的 FP8 实现中(https://github.com/pytorch-labs/float8_experimental/tree/main),早期的测试表明 Delayed Scaling 反而比 Dynamic Scaling 慢,当然也非常接近:

www.zeeklog.com  - 万字综述:全面梳理 FP8 训练和推理技术 -- 附录

六、参考链接

https://arxiv.org/abs/2310.18313

https://github.com/Azure/MS-AMP/blob/main/msamp/nn/distributed.py#L124-L193

https://github.com/Azure/MS-AMP/issues/117

https://arxiv.org/abs/2309.17224

https://github.com/pytorch-labs/float8_experimental/tree/main

Read more

深入理解 Proxy 和 Object.defineProperty

在JavaScript中,对象是一种核心的数据结构,而对对象的操作也是开发中经常遇到的任务。在这个过程中,我们经常会使用到两个重要的特性:Proxy和Object.defineProperty。这两者都允许我们在对象上进行拦截和自定义操作,但它们在实现方式、应用场景和灵活性等方面存在一些显著的区别。本文将深入比较Proxy和Object.defineProperty,包括它们的基本概念、使用示例以及适用场景,以帮助读者更好地理解和运用这两个特性。 1. Object.defineProperty 1.1 基本概念 Object.defineProperty 是 ECMAScript 5 引入的一个方法,用于直接在对象上定义新属性或修改已有属性。它的基本语法如下: javascript 代码解读复制代码Object.defineProperty(obj, prop, descriptor); 其中,obj是目标对象,prop是要定义或修改的属性名,descriptor是一个描述符对象,用于定义属性的特性。 1.2 使用示例 javascript 代码解读复制代码//

By Ne0inhk

Proxy 和 Object.defineProperty 的区别

Proxy 和 Object.defineProperty 是 JavaScript 中两个不同的特性,它们的作用也不完全相同。 Object.defineProperty 允许你在一个对象上定义一个新属性或者修改一个已有属性。通过这个方法你可以精确地定义属性的特征,比如它是否可写、可枚举、可配置等。该方法的使用场景通常是需要在一个对象上创建一个属性,然后控制这个属性的行为。 Proxy 也可以用来代理一个对象,但是相比于 Object.defineProperty,它提供了更加强大的功能。使用 Proxy 可以截获并重定义对象的基本操作,比如访问属性、赋值、函数调用等等。在这些操作被执行之前,可以通过拦截器函数对这些操作进行拦截和修改。因此,通过 Proxy,你可以完全重写一个对象的默认行为。该方法的使用场景通常是需要对一个对象的行为进行定制化,或者需要在对象上添加额外的功能。 对比 以下是 Proxy 和 Object.defineProperty 的一些区别对比: 方面ProxyObject.defineProperty语法使用 new Proxy(target,

By Ne0inhk