【LLM】LLaMA架构(RMSNorm+ KV cache+Rotary Positional Encodings+门控FFN+MoE)
文章目录
- 一、LLaMA架构
- 二、重要组成部分
一、LLaMA架构
1. 基本介绍
《LLaMA: Open and Efficient Foundation Language Models》论文地址:https://arxiv.org/abs/2302.13971
LLaMA并非一个全新的、从零开始设计的架构。它巧妙整合并验证了多个当时业界公认的最高效的Transformer改进方案。例如:
- 预归一化:使用RMSNorm层进行归一化,提高训练稳定性。
- SwiGLU激活函数:替代传统的Relu,提升模型表达能力。
- 旋转位置编码:使用RoPE,能更好地处理长序列。
2. 技术路线图:对比Transformer

从对比图可以看出,LLaMA架构在位置编码和自注意力机制上做了较大的调整。Transformer-Decoder中的位置编码不再是给input embedding做改进,而是给经过QK做编码。为了提升计算效率,LLaMA的自注意机制采用的是KV缓存:加入多头为8个头,此时我们要求将KV分别生成两个矩阵并将这两个矩阵保存,针对不同的Q,都是使用缓存的这两个KV来计算。
3. 思考
3.1 为什么LLaMA使用的是Transformer中的Decoder解码器?
- 任务需求的匹配:自回归生成。Decoder的在训练时使用因果注意力掩码(Causal Attention Mask) ,也称为前瞻掩码(Look-ahead Mask),在计算注意力权重矩阵时,会做一个时间窗静止当前单词对后续单词进行询问,在预测下一个单词的时候,只利用之前的所有词。这本是就是一个天然、高效的文本生成器结构。Encoder模型(BERT)主要擅长理解型任务(eg. 文本分类、文本匹配、语义相似度),而不是生成类任务。这个
架构效率:在先沟通的计算预算下,一个巨型纯Decoder模型可能比一个同等规模的Encoder-Decoder模型在生成任务上表现更好,因为所有参数都聚焦于同一个目标。

Bert是Encoder-only模型,GPT是Decoder-only模型,上图是具体的对比。
3.2 为什么RoPE只给Q和K做位置编码?
- 注意力分数的计算原理: 注意力机制的核心是“注意力分数”矩阵(Q·K T)计算的是查询与键的匹配度或相关性。这个相关性必须包含位置信息,因为一个词与另一个词的相关性高度依赖于它们之间的相对位置(例如,“apple”在“eat”前面和后面含义完全不同)。因此,位置编码的核心目的是为了让模型在计算“谁应该关注谁”(Q·K^T)时,能够感知到位置关系。
- 注意力分数的计算原理:RoPE是一种相对位置编码,它通过旋转矩阵将位置信息注入到Q和K中,得到的点积(Q_i· K_ j ^ T)结果只依赖于相对位置 i - j,而不是绝对位置 i 或 j 。如果我们将RoPE也应用到V上,可能会扭曲V本身所携带的语义信息,且是没必要的,因为softmax分数已经包含了位置感知,这个分数决定了V的权重,旋转V并不会改变“该关注哪个token的决策”。
二、重要组成部分
1. Embedding
在PyTorch,nn.Embedding层是用于处理离散数据的关键组件,主要功能是将输入的整数索引映射到连续的高维向量空间中,即将索引转化为嵌入向量。
import torch import torch.nn as nn # 定义Embedding层 embedding = nn.Embedding(10,3)# num_embeddings=10, embedding_dim=3# 输入索引 input_indices = torch.tensor([1,2,3])# 获取嵌入向量 output = embedding(input_indices)print(output)2. RMSNorm均方根层归一化
2.1 Layer Normalization 和 Batch Normalization 的区别
最重要的区别在于计算均值和方差的方向不同,LN在一次更新迭代中统计同一层内的所有神经元节点的输出分布(同一个样本下);BN是在一个Batch内统计某特定神经元的输出分布(跨样本)。

在NLP任务中会经常处理长度不同的句子,使用LN时可以不需要考虑其他样本的长度是否,如果按照Batch维度进行统计的话,会存在一定问题:为了让样本均衡,一般会对样本进行裁剪或者填补,里面一定有大量为0的特征值,因此在计算特征均值和方差肯定会受到影响。
2.2 LN与RMSNorm的区别

那么问题来了,为什么通过RMSNorm可以起到归一化的作用?
首先,要先回顾下归一化层的作用。归一化层是为了防止梯度爆炸/消失,实现手段是控制尺度,而非严格的中心化。RMS(x) 衡量的是向量的"典型幅度",类似于向量的L2范数(相差一个根号n因子)。经过RMSNorm后,输出向量的RMS值为1:RMS(RMSNorm(x))=1,这就强制输出的尺寸保有一致性。
import torch import torch.nn as nn classRMSNorm(nn.Module):def__init__(self, dim:int, eps:float=1e-8):super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim))def_norm(self, x):# 计算 RMS 归一化因子 rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True)+ self.eps)return x * rms defforward(self, x):# 保持计算精度为float,然后转换回输入类型 output = self._norm(x.float()).type_as(x)return output * self.weight # 正确使用方式 x = torch.randn(2,3,768) rmsnorm = RMSNorm(dim=768, eps=1e-6)# 创建实例 norm = rmsnorm(x)# 通过实例调用print("输出形状:", norm.shape)#输出: torch.Size([2, 3, 768])3. 旋转位置编码Rotary Positional Encodings

这段代码不用记了,记不住的···,知道旋转位置编码的原理即可。
#定义频率计算defprecompute_pos_cis(dim:int, max_position:int, theta:float=10000.0):#频率 freqs =1.0/(theta **(torch.arange(0, dim,2)[:(dim //2)].float()/ dim))#位置编码m m = torch.arange(max_position, device=freqs.device)#频率乘以位置编码、外积 freqs = torch.outer(m, freqs).float()# pos_cis = torch.polar(torch.ones_like(freqs), freqs)return pos_cis #将频率用于q、k矩阵defapply_rotary_emb(xq, xk, pos_cis):defunite_shape(pos_cis, x): ndim = x.ndim assert0<=1< ndim print(pos_cis.shape)print(x.shape[1])print(x.shape[-1])assert pos_cis.shape ==(x.shape[1], x.shape[-1]) shape =[d if i ==1or i == ndim -1else1for i, d inenumerate(x.shape)]return pos_cis.view(*shape) xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1],-1,2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1],-1,2)) pos_cis = unite_shape(pos_cis, xq_) xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)return xq_out.type_as(xq), xk_out.type_as(xk)4. Self-Attention(Grouped Multi-Query Attetion with KV cache)
4.1 基本概念
KV缓存(Key-Value Cache)是Transformer模型自回归生成任务中的一个重要加速技术。在文本生成任务中,每一步生成一个新的token,这个新的token跟之前所有的tokens一起重新输入到模型中,预测下一个token。对于每一步的生成,模型会重新计算所有tokens的注意力分数,这个过程是非常耗时的,因此,为了避免重复计算注意力层中的K和V,在生成后续token时,模型只需要计算token的Query,直接调用缓存中的Key和Vaule。
4.2 工作原理
KV缓存大部分时候适用于推理过程中。
- 初始化:
当模型开始时,模型计算输入序列的Key和Value,并将这些计算结果缓存起来,保存在内存中。大部分时候,每个注意力层都会有一对Key-Value缓存,这个缓存是自回归的每次循环中共用的。还有一种做法是:在多头注意力机制中,只保留一个头或者两个头以上的KV值,并共享给所有头使用。 - 生成过程:
当生成下一个token时,模型不需要重新计算前面已经生成的token的Key和Value始终保持更新,包含所有已生成的tokens。最终会包含所有生成序列的Key和Value。 - 更新缓存:
对于每一个生成步骤,模型还会将当前生成的token的Key和Value加入缓存,确保缓存中的Key和Value始终保持更新,能够包含所有已经生成的tokens。 - 加速效果:
由于每个生成步骤只需要计算当前token的Query,而不需要重新计算整个序列的Key和Value,这大大减少了计算量。随着序列长度增加,缓存的使用能够显著减少时间复杂度,使生成过程更快。
4.3 模型代码
4.3.1 分组查询注意力
多个查询头共享一组K、V头。
场景:假设有
- 查询头数(n_q_heads):8
- KV缓存头数量(n_kv_heads):2
- 每个头的维度(head_dim):64
那么我们需要让:
- 这里每个KV需要被8/2=4个查询头共享
defrepeat_kv(x:torch.Tensor, n_rep:int):''' :param x: tensor , shape (bs, slen, n_kv_heads, head_dim) :param n_rep: 重复次数 :return: ''' bs, slen, n_kv_heads, head_dim = x.shape #bs: 批次大小 (batch size)# slen: 序列长度 (sequence length)# n_kv_heads: KV 头的数量 (number of key-value heads)# head_dim: 每个头的维度大小 (dimension size of each head)if n_rep ==1:return x return( x[:,:,:,None,:].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim))4.3.2 注意力机制
import torch import torch.nn as nn from Config import LMConfig import torch.nn.functional as F import math from typing import Optional, Tuple from dataclasses import dataclass @dataclassclassLMConfig: dim:int=4096 n_heads:int=32 n_kv_heads: Optional[int]=None max_seq_len:int=2048 dropout:float=0.1 flash_attn:bool=Truedef__post_init__(self):if self.n_kv_heads isNone: self.n_kv_heads = self.n_heads #旋转位置编码defapply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor )-> Tuple[torch.Tensor, torch.Tensor]:""" 应用旋转位置编码到查询和键上 """# 将复数转换为实数和虚数部分 xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1],-1,2)) xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1],-1,2))# 应用旋转 freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2)# 添加batch和head维度 xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_complex