大语言模型框架-Megatron-LM源码分析

大语言模型框架-Megatron-LM源码分析

大语言模型框架-Megatron-LM源码分析

原创 MLOps社区  2023年11月11日 11:44 北京

Megatron-LM是NVIDIA开源的大语言模型框架,是很多披露的大语言模型的训练使用的源头框架,很多公司基于其二次开发新的语言模型系统,例如Megatron-LM-DeepSpeed。

Megatron核心解决问题就是提供多种分布式切分并行策略,让大语言模型能够部署在多卡分布式环境下。本文将针对,张量并行,流水并行,数据并行的实现展开源码分析。MoE我们可以当成一种特殊的稀疏化结构,就不在本章进行介绍。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

图片来自DeepSpeed(本文不介绍ZeRO,感兴趣读者可参考相关论文)

1 张量并行

张量并行分为行切和列切并行(指的是对输入矩阵切法),具体读者可以参考Megatron论文,其实现方式是继承实现Linear层,进而实现其中的并行策略,只需要替换模型中的Linear即可,后面我们也会看到MoE也是这种实现技巧。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

图来源Megatron-LM

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)通过这部分代码进行并行partition划分,worldsize是配置的tensorparallel的卡数,将完整input切成这么多份数,在每个执行这个代码的rank进行权重创建。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

如果输入不是并行切分好的,通过scatter去拿这部分权重对应的输入数据。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

scatter[图来源PyTorch官网]

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

forward实现是配合异步allreduce进而将计算和comm通信并发执行。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

async_grad_allreduce (bool required): Do the allreduce of input
        gradients asyncronously with the computation of weight
        gradients. If sequence_parallel is True, this must be
        False, as no all reduce is performed.是对在BP阶段输入的gradients是否进行异步计算

这个linear前向是标准的torch matmul,除非sequence有并行设置才会进行一定的allgather通信。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

之前异步都被RowLinear配置false,所以BP核心是fused kernel。且可以使用低精度16bit的内核。  gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA
        extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with
        --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\"
        ". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.
        Defaults to False.


www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

如果不选fuse kernel则执行执行矩阵乘完成BP反向传播计算。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

有意思的是Apex库也用的这个linear实现,库之前二次开发逐渐成为当前常用方式

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

这个fusekernel最终还是调用的cublas的gemm

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

FP16使用的BF16,且也是cublas gemm,并可以利用tensor core的加速。也有版本可以选择    at::Half* A,类型的bit 16类型内核。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

下面是上面的kernel优化目的Gradient accumulation fusion的介绍。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

所以C矩阵在kernel的输出是32 bit,输入是两个16bit矩阵保证了输出累计的数据精度。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

接下来我们再看列切分配置,其类似行切分。好处是省去下一步的GEMM之前的allreduce通信,所以attention和mlp的第一层gemm megatron选择列切,之后再行切下一个阶段gemm。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

其中gather_output由模型设计者决定,进而和后面的层进行配合,看输出是当前层是否需要聚合还是不需要聚合收集。这点是比较有意思,当前属于人工硬编码选择配置好策略。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

初始化决定列切切分维度,这个world size 通过tensor parallel的输入shell进行配置。

一般配置是8路,也就是张量并行在一个8卡的server,如果用户是4卡server则配置4即可。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

和Row并行的区别是用户配置是否进行一次all gather output

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

allgather原因是列切完,我们产生了两块局部结果。如下文所示是两个partition 1 和 2。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

all gather [图来源PyTorch官网]

2 流水并行

其核心代码在p2p_communication中实现有4个重要参数。先沟通改下需要传递的tensor shape,"""Communicate tensor shapes between stages. Used to communicate
    tensor shapes before the actual tensor communication happens.
    This is required when the sequence lengths across micro batches
    are not uniform.

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

下面是真正的传递核心,使用的P2POp,isend传递是异步传递,同时通过函数去确定相邻的rank,这样写这个函数就不用管拓扑了。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

然后将这些操作符做批量通信发送,一批完成和自己上游和下游的send recv异步通信

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

调用在_communicate进行

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

当前方式是相当于有了microbatch的发送方式,也就是既有下个microbatch的前向,也有当前batch的反向,一批次做异步通信。适合整体都已经运行起来了,已经不是第一个batch的场景。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

这是在backward中调用的上面的API

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

流水并行的切分靠的是这个函数获取自己的这个rank到底是哪个model chunk,相当于静态切分好自己是属于哪块,类似张量并行。当前并行策略都是可以考虑这种分配方式,静态编译好partition和rank映射关系,启动后获取这个关系决定自己的通信方式。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

前向传播通过以下函数进行p2p通信。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

底层还是通过_communicate调用实现类似上面的bp过程。当前相当于只需要给下个rank send所以tensor_send_prev为空。如果多个microbatch配置是第一个batch前向,或者没配置多个microbatch场景使用。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

3 数据并行

数据并行一般有几点优化:将梯度通信和BP计算进行overlap,同时可以使用低精度做梯度聚合,将grad组成成小桶聚合。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

图来自PyTorch,桶聚合相当于batching通信张量

通过hook注册梯度更新事件

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

Overlap的核心是当bucket的 paramer累计到都有gradient,触发allreduce同步。BP计算该做自己的继续做,产生好的gradient,这部分只要ready就同时触发allreduce与进行的BP就无关了,但是没产生的gradient还不能进行计算。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

grad_buffer中是核心逻辑。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

allreduce,图片来源PyTorch

核心逻辑通过GradBuffer进行聚合成连续buffer,再拆解。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

本质是allreduce并可以选择是否是异步。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

在finish中不同等待异步allreduce

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

此处聚合wait所有的handler同步通信。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

 config.finalize_model_grads_func = finalize_model_grads
在没有流水并行下,执行刚才的同步通信等待。

www.zeeklog.com  - 大语言模型框架-Megatron-LM源码分析

AI Infra96

AI Infra · 目录

上一篇大语言模型内核源码分析-4 Paging推理内核下一篇大语言模型内核源码分析-4IO-Awareness内核

Read more

深入理解 Proxy 和 Object.defineProperty

在JavaScript中,对象是一种核心的数据结构,而对对象的操作也是开发中经常遇到的任务。在这个过程中,我们经常会使用到两个重要的特性:Proxy和Object.defineProperty。这两者都允许我们在对象上进行拦截和自定义操作,但它们在实现方式、应用场景和灵活性等方面存在一些显著的区别。本文将深入比较Proxy和Object.defineProperty,包括它们的基本概念、使用示例以及适用场景,以帮助读者更好地理解和运用这两个特性。 1. Object.defineProperty 1.1 基本概念 Object.defineProperty 是 ECMAScript 5 引入的一个方法,用于直接在对象上定义新属性或修改已有属性。它的基本语法如下: javascript 代码解读复制代码Object.defineProperty(obj, prop, descriptor); 其中,obj是目标对象,prop是要定义或修改的属性名,descriptor是一个描述符对象,用于定义属性的特性。 1.2 使用示例 javascript 代码解读复制代码//

By Ne0inhk

Proxy 和 Object.defineProperty 的区别

Proxy 和 Object.defineProperty 是 JavaScript 中两个不同的特性,它们的作用也不完全相同。 Object.defineProperty 允许你在一个对象上定义一个新属性或者修改一个已有属性。通过这个方法你可以精确地定义属性的特征,比如它是否可写、可枚举、可配置等。该方法的使用场景通常是需要在一个对象上创建一个属性,然后控制这个属性的行为。 Proxy 也可以用来代理一个对象,但是相比于 Object.defineProperty,它提供了更加强大的功能。使用 Proxy 可以截获并重定义对象的基本操作,比如访问属性、赋值、函数调用等等。在这些操作被执行之前,可以通过拦截器函数对这些操作进行拦截和修改。因此,通过 Proxy,你可以完全重写一个对象的默认行为。该方法的使用场景通常是需要对一个对象的行为进行定制化,或者需要在对象上添加额外的功能。 对比 以下是 Proxy 和 Object.defineProperty 的一些区别对比: 方面ProxyObject.defineProperty语法使用 new Proxy(target,

By Ne0inhk