大语言模型框架-Megatron-LM源码分析

大语言模型框架-Megatron-LM源码分析

大语言模型框架-Megatron-LM源码分析

原创 MLOps社区  2023年11月11日 11:44北京

Megatron-LM是NVIDIA开源的大语言模型框架,是很多披露的大语言模型的训练使用的源头框架,很多公司基于其二次开发新的语言模型系统,例如Megatron-LM-DeepSpeed。

Megatron核心解决问题就是提供多种分布式切分并行策略,让大语言模型能够部署在多卡分布式环境下。本文将针对,张量并行,流水并行,数据并行的实现展开源码分析。MoE我们可以当成一种特殊的稀疏化结构,就不在本章进行介绍。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

图片来自DeepSpeed(本文不介绍ZeRO,感兴趣读者可参考相关论文)

1 张量并行

张量并行分为行切和列切并行(指的是对输入矩阵切法),具体读者可以参考Megatron论文,其实现方式是继承实现Linear层,进而实现其中的并行策略,只需要替换模型中的Linear即可,后面我们也会看到MoE也是这种实现技巧。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

图来源Megatron-LM

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)通过这部分代码进行并行partition划分,worldsize是配置的tensorparallel的卡数,将完整input切成这么多份数,在每个执行这个代码的rank进行权重创建。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

如果输入不是并行切分好的,通过scatter去拿这部分权重对应的输入数据。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

scatter[图来源PyTorch官网]

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

forward实现是配合异步allreduce进而将计算和comm通信并发执行。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

async_grad_allreduce (bool required): Do the allreduce of input
        gradients asyncronously with the computation of weight
        gradients. If sequence_parallel is True, this must be
        False, as no all reduce is performed.是对在BP阶段输入的gradients是否进行异步计算

这个linear前向是标准的torch matmul,除非sequence有并行设置才会进行一定的allgather通信。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

之前异步都被RowLinear配置false,所以BP核心是fused kernel。且可以使用低精度16bit的内核。  gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA
        extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with
        --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\"
        ". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.
        Defaults to False.


www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

如果不选fuse kernel则执行执行矩阵乘完成BP反向传播计算。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

有意思的是Apex库也用的这个linear实现,库之前二次开发逐渐成为当前常用方式

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

这个fusekernel最终还是调用的cublas的gemm

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

FP16使用的BF16,且也是cublas gemm,并可以利用tensor core的加速。也有版本可以选择    at::Half* A,类型的bit 16类型内核。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

下面是上面的kernel优化目的Gradient accumulation fusion的介绍。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

所以C矩阵在kernel的输出是32 bit,输入是两个16bit矩阵保证了输出累计的数据精度。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

接下来我们再看列切分配置,其类似行切分。好处是省去下一步的GEMM之前的allreduce通信,所以attention和mlp的第一层gemm megatron选择列切,之后再行切下一个阶段gemm。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

其中gather_output由模型设计者决定,进而和后面的层进行配合,看输出是当前层是否需要聚合还是不需要聚合收集。这点是比较有意思,当前属于人工硬编码选择配置好策略。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

初始化决定列切切分维度,这个world size 通过tensor parallel的输入shell进行配置。

一般配置是8路,也就是张量并行在一个8卡的server,如果用户是4卡server则配置4即可。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

和Row并行的区别是用户配置是否进行一次all gather output

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

allgather原因是列切完,我们产生了两块局部结果。如下文所示是两个partition 1 和 2。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

all gather [图来源PyTorch官网]

2 流水并行

其核心代码在p2p_communication中实现有4个重要参数。先沟通改下需要传递的tensor shape,"""Communicate tensor shapes between stages. Used to communicate
    tensor shapes before the actual tensor communication happens.
    This is required when the sequence lengths across micro batches
    are not uniform.

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

下面是真正的传递核心,使用的P2POp,isend传递是异步传递,同时通过函数去确定相邻的rank,这样写这个函数就不用管拓扑了。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

然后将这些操作符做批量通信发送,一批完成和自己上游和下游的send recv异步通信

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

调用在_communicate进行

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

当前方式是相当于有了microbatch的发送方式,也就是既有下个microbatch的前向,也有当前batch的反向,一批次做异步通信。适合整体都已经运行起来了,已经不是第一个batch的场景。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

这是在backward中调用的上面的API

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

流水并行的切分靠的是这个函数获取自己的这个rank到底是哪个model chunk,相当于静态切分好自己是属于哪块,类似张量并行。当前并行策略都是可以考虑这种分配方式,静态编译好partition和rank映射关系,启动后获取这个关系决定自己的通信方式。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

前向传播通过以下函数进行p2p通信。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

底层还是通过_communicate调用实现类似上面的bp过程。当前相当于只需要给下个rank send所以tensor_send_prev为空。如果多个microbatch配置是第一个batch前向,或者没配置多个microbatch场景使用。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

3 数据并行

数据并行一般有几点优化:将梯度通信和BP计算进行overlap,同时可以使用低精度做梯度聚合,将grad组成成小桶聚合。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

图来自PyTorch,桶聚合相当于batching通信张量

通过hook注册梯度更新事件

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

Overlap的核心是当bucket的 paramer累计到都有gradient,触发allreduce同步。BP计算该做自己的继续做,产生好的gradient,这部分只要ready就同时触发allreduce与进行的BP就无关了,但是没产生的gradient还不能进行计算。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

grad_buffer中是核心逻辑。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析
www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

allreduce,图片来源PyTorch

核心逻辑通过GradBuffer进行聚合成连续buffer,再拆解。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

本质是allreduce并可以选择是否是异步。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

在finish中不同等待异步allreduce

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

此处聚合wait所有的handler同步通信。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

 config.finalize_model_grads_func = finalize_model_grads
在没有流水并行下,执行刚才的同步通信等待。

www.zeeklog.com - 大语言模型框架-Megatron-LM源码分析

AI Infra96

AI Infra · 目录

上一篇大语言模型内核源码分析-4 Paging推理内核下一篇大语言模型内核源码分析-4IO-Awareness内核

Read more

【贪心算法】day1

【贪心算法】day1

📝前言说明: * 本专栏主要记录本人的贪心算法学习以及LeetCode刷题记录,按专题划分 * 每题主要记录:(1)本人解法 + 本人屎山代码;(2)优质解法 + 优质代码;(3)精益求精,更好的解法和独特的思想(如果有的话);(4)这个贪心算法正确性的证明 * 文章中的理解仅为个人理解。如有错误,感谢纠错 🎬个人简介:努力学习ing 📋本专栏:C++刷题专栏 📋其他专栏:C语言入门基础,python入门基础,C++学习笔记,Linux 🎀ZEEKLOG主页 愚润泽 你可以点击下方链接,进行其他贪心算法题目的学习 点击链接开始学习贪心day1贪心day2贪心day3贪心day4贪心day5贪心day6贪心day7贪心day8贪心day9贪心day10 也可以点击下面连接,学习其他算法 点击链接开始学习优选专题动态规划递归、搜索与回溯贪心算法 题单获取→ 【贪心算法】题单汇总 题目 * 贪心算法导论 * 860. 柠檬水找零 * 优质解 * 证明 * 2208. 将数组和减半的最少操作次数

By Ne0inhk
极致性能的服务器Redis之Hash类型及相关指令介绍

极致性能的服务器Redis之Hash类型及相关指令介绍

目录 1. Hash介绍 2. hset 3. hget 3. hdel 5. hkeys 6. hvals 编辑 7. hgetall  8. hexists 9. hmget 10. hlen 11. hsetnx 12. hincrby 13. hincrbyfloat 1. Hash介绍 Redis 哈希类型是键值对的集合,字段与值均支持字符串、数字等类型,适合建模用户信息、配置项等对象类数据。其支持单字段 / 多字段的增删改查、字段存在性判断、值自增自减等原子操作,且底层通过压缩列表或哈希表优化存储,空间利用率高、查询效率快,是 Redis 中存储结构化数据的核心类型之一。 在Redis中因为本身就是按照哈希的KV结构来进行存储的,所以当我们想要使用Redis里面的哈希的时候,实际上是哈希的哈希,在后者中,

By Ne0inhk
排序算法的速度美学:快速排序深度漫游

排序算法的速度美学:快速排序深度漫游

目录 一、快速排序思想 二、Hoare版本 1、Hoare版本介绍 2、编码实操 3、时间复杂度分析 4、有序情况优化 4.1 随机选keyi 4.2 三数取中 小贴士: 5、稳定性分析 三、挖坑法 1、挖坑法介绍 2、编码实操 四、lomuto前后指针版本 1、前后指针版本介绍 2、编码实操 3、小区间优化 五、迭代版本(非递归) 1、递归的缺陷 2、非递归思路 3、编码实操 六、三路划分 1、三路划分思想 2、

By Ne0inhk
Flutter 三方库 conduit_password_hash 的鸿蒙化适配指南 - 实现企业级安全密码加盐哈希、支持 Argon2, PBKDF2 与 BCrypt 算法集成

Flutter 三方库 conduit_password_hash 的鸿蒙化适配指南 - 实现企业级安全密码加盐哈希、支持 Argon2, PBKDF2 与 BCrypt 算法集成

欢迎加入开源鸿蒙跨平台社区:https://openharmonycrossplatform.ZEEKLOG.net Flutter 三方库 conduit_password_hash 的鸿蒙化适配指南 - 实现企业级安全密码加盐哈希、支持 Argon2, PBKDF2 与 BCrypt 算法集成 前言 在进行 Flutter for OpenHarmony 的全栈开发时,用户的账户安全是压倒一切的需求。尤其是在构建鸿蒙端侧的本地认证服务或配套的 Dart 服务端时,绝不能以明文存储密码。conduit_password_hash 是一个源自 Conduit 框架的高性能加密库,它提供了多种符合工业安全标准的哈希算法。本文将探讨如何在鸿蒙端利用该库构建牢不可破的密码保护体系。 一、原理解析 / 概念介绍 1.1 基础原理 conduit_password_hash 采用了“慢哈希(Slow

By Ne0inhk