从计算角度解读 LLM 内部结构与推理过程
从计算视角深入解析大型语言模型(LLM)的内部结构,涵盖张量形状(Tensor Shape)、Embedding 层、多头注意力机制(MHA)及前馈神经网络(MLP)的输入输出数据结构与计算量分析。详细阐述了各算子的矩阵运算逻辑、维度变换过程及浮点运算次数估算,帮助读者建立对 LLM 推理过程中数据流转与计算负载的直观理解。文中还补充了 PyTorch 代码示例以辅助理解实际计算流程。

从计算视角深入解析大型语言模型(LLM)的内部结构,涵盖张量形状(Tensor Shape)、Embedding 层、多头注意力机制(MHA)及前馈神经网络(MLP)的输入输出数据结构与计算量分析。详细阐述了各算子的矩阵运算逻辑、维度变换过程及浮点运算次数估算,帮助读者建立对 LLM 推理过程中数据流转与计算负载的直观理解。文中还补充了 PyTorch 代码示例以辅助理解实际计算流程。

本文从计算的角度分析每个算子的输入数据结构、输出数据结构,以及统计对应的计算量,以此对 LLM 内部结构有一个更加深刻的认识。
张量(Tensor)是深度学习中用于表示数据的核心结构,它是一个多维数组,可以看作是标量、向量和矩阵的泛化,是深度学习框架如 TensorFlow 和 PyTorch 中的核心数据结构。在机器学习里面,一切数据皆 Tensor,在机器学习模型中,所有的数据都是以 Tensor 的格式流转的。
Tensor Shape 是理解计算过程的重要属性。 定义:张量的形状(Shape)是一个整数数组,描述了张量在每个维度上的大小。
import torch
vector = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
print(vector.shape)
# 输出:torch.Size([5])
tensor_3d = torch.tensor([[[3, 7, 2, 5], [1, 4, 6, 0], [9, 8, 3, 2]], [[4, 1, 7, 3], [5, 0, 2, 8], [6, 9, 4, 1]]])
print(tensor_3d.shape)
# 输出:torch.Size([2, 3, 4])
词表 Embedding 是将词汇表中的单词转换为稠密向量表示的过程,也即是将人类理解的分词词语转换成计算机理解的向量数据。对应到每个模型,每一个大语言模型都有一个词表,这个词表表示了该 LLM 能够输出的不同词语的数量。例如 Llama-2-7b 的词表大小是 128256。
在 LLM 推理过程中,LLM 的第一步,就是将我们输入的文字做分词,并通过 Embedding 转成向量数据。
[B, S], [V, H][B, S, H]其中:
Embedding 过程主要是查表操作,就是把每个字拿去和词表 [V, H] 去做匹配,拿到对应的 Hidden Tensor。其中 Shape = [V, H] 的数据就是预训练得到的,Shape = [B, S] 是输入数据。一般来说 B=1,但是在 LLM 后端处理并发的时候,可能把多个人的 Prompt 同时作为输入,这个时候 B 就会大于 1。
在 MHA 部分,将会有几个算子需要重点关注的:
Normalization 计算的作用是将输入数据规范化到标准分布,以稳定神经网络的计算过程。
[B, S, H][B, S, H]Q/K/V Projection 计算的作用是 Attention,将输入 Q 分别和 Q、K、V 预训练的权重进行运算。其中 QKV 的 Shape 都是 [B,S,H],权重 Shape 是 [H,H]。
[B, S, H], [H, H][B, S, H]这里特别说明下,如果输入的 Tensor 有两个,那表达的意思就是两个 Tensor 进行矩阵乘法,这种运算,要求第一个 Tensor 的最后一维和第二个 Tensor 的倒数第二维一致,得到的结果就是消掉这两维度的结果。
Q/K matmul 是计算 Attention 分数的关键步骤,这里引入一个新的变量 a,它表示多头注意力的头数,由 H 进行分组 a 得到的;同时 K 进行了一个转置,所以 Q、K 的 Shape 有以下变化:
Q: [B, S, H] ⇒ [B, a, S, H/a]
K: [B, S, H] ⇒(转置) [B, H, S] ⇒(多头) [B, a, H/a, S]
[B,a, S, H/a], [B, a, H/a, S][B, a, S, S]同理,[B,a, S, H/a] 的 Q Shape 和 [B, a, H/a, S] 的 K Shape 矩阵相乘,消掉 Q 的 S 维和 K 的 H/a 维,最终得到 [B, a, S, S] Shape 的输出。
[B, a, S, S][B, a, S, S]V Shape 同样做了多头注意力的 a 分组,所以 V 的 Shape 会变成:
[B, S, H] ⇒ [B, a, S, H/a]
[B, a, S, S], [B, a, S, H/a][B, a, S, H/a]O Projection 和 Q/K/V Projection 差不多,不同的是 O Projection 的输入是 SV Matmul 出来的,Shape 不一样。所以 O 的输入会有一个 Reshape 操作:
Reshape([B, a, S, H/a]) => [B, S, H]
[B, S, H], [H, H][B, S, H]从 MHA 算子来看,输入和输出的 Tensor Shape 都是 [B, S, H],包括后面的 MLP 也是,在设计上是挺规整的。在 MHA 内部,主要实现的就是注意力公式的计算。我们通过拆分计算过程,更清晰地理解了 MHA 的内部细节。
有了前面 MHA 的基础之后,MLP 就比较好理解了。MLP 层通常包含以下组件:
重点还是 Gate/Up/Down Projection。
在 MLP 层,有一个新的参数 ffn_dim,它是 MLP 内部的隐藏层大小,和前面介绍的 Hidden Size 不一样。
[B, S, H], [H, ffn_dim][B, S, ffn_dim]输入 [B, S, H] 和 [H, ffn_dim],消掉两个 H 维之后,得到结果 [B, S, ffn_dim]。Gate/Up Projection 过程相当于一个放大的计算逻辑,把 Hidden Size 的隐藏层大小放大到 ffn_dim 的隐藏层大小,扩大数据特征。
Down Projection 做的事相反的计算逻辑,缩小的计算操作,把 ffn_dim 缩放到 Hidden Size 大小,返回原形。
[B, S, ffn_dim], [ffn_dim, H][B, S, H]从计算过程可以看到 MLP 做的主要工作就是一个放大和缩小,把 Hidden Size 放大到 ffn_dim,再缩小到 Hidden Size,这个线性变换的过程有利于提取语义特征,捕捉计算过程的语义信息,增强表达能力。不过,在这里也需要说明一下,不同模型的 MLP 层是有些不一样的,Llama 模型是引入了 Gate 机制,对于更加一般的模型来说,MLP 所做的事情大概就是一个标准的两层全连接网络结构。
可以看到,不管是进入 Transformer Block 还是出来,数据的 Shape 都是 [B, S, H],而这样的数据计算过程还需要经过 Layers 层,每层的计算过程都是一样的,不同的是预训练的权重不一样,也即是 [H, H] 矩阵和 [ffn_dim, H] 矩阵。
从 Transformer Block 出来之后,还需要将 [B, S, H] 反映射成词语,这里和 Embedding 做的事情是类似的但是是相反的,通常涉及一个 Linear 层和一个 Softmax 或 Logits 采样。
本文从计算角度,详细解析了大型语言模型(LLM)的内部结构,并且分析了各个算子的输入输出数据结构及其计算量。结合前面的介绍,相信大家对 LLM 的内部结构已经非常清晰明了。在实际部署中,理解这些计算细节有助于进行显存优化、算子融合以及量化加速等工程实践。
为了更直观地理解上述计算过程,以下是一个简化的 Attention 机制 PyTorch 实现片段:
import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(query, key, value, mask=None):
# query, key, value shape: [B, a, S, H/a]
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, value)
return output
通过此代码可验证各阶段的 Tensor Shape 变化及计算逻辑。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online