从Mistral Nemo到Large2 核心技术详解

从Mistral Nemo到Large2 核心技术详解

从Mistral Nemo到Large2 核心技术详解

作者:Kevin吴嘉文,新加坡管理大学 信息技术硕士
原文:https://zhuanlan.zhihu.com/p/711294388

在本文中,梳理了 Mistral 系列模型(Mistral 7B, Mixtral 8x7B,Mixtral 8x22B,Mistral Nemo, Mistral Large 2)的关键信息,包括它们的主要特点、亮点以及相关资源链接。

Mistral 7B

官方博客:https://mistral.ai/news/announcing-mistral-7b/
mistral 7B 论文:https://arxiv.org/abs/2310.06825

Mistral 7B模型的亮点包括:

Sliding Window Attention

Mistral 采用的 window size 为 4096,而后一共有 32 层layer,那么采用 SWA 之后,理论上在进行 attention 的时候,理论上可以收集到约 131K tokens 的信息。(虽然论文里提到的 window size 是 4096,但 官方提供的 huggingface 上的权重[1] 中 max_position_embeddings 为 32768,且在新一点的版本中,比如 mistral-7b-instruct-v0.2[2] ,都不采用 sliding window 了)

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

由于代用了固定的 attention 窗口大小,因此我们只需要一个大小为 W=window size 的 cache ,在计算第 i 个 token 的 cache 的时候,只需要覆盖 cache 中 i mod M 位置上的 hidden state 即可。

参考 huggingface 的 mistral 实现,Sliding window attention 通过 attention_mask 来控制:

# huggignface mistral attn mask 实现
def _update_causal_mask(
self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values:Cache,
):
# ... 省略部分无关代码
    past_seen_tokens = cache_position[0]if past_key_values isnotNoneelse0
    using_static_cache = isinstance(past_key_values,StaticCache)
    using_sliding_window_cache = isinstance(past_key_values,SlidingWindowCache)

    dtype, device = input_tensor.dtype, input_tensor.device
    min_dtype = torch.finfo(dtype).min
    sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
        target_length = max(sequence_length,self.config.sliding_window)
# StaticCache
elif using_static_cache:
        target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
        target_length =(
            attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length +1
)

if attention_mask isnotNoneand attention_mask.dim()==4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max()!=0:
raiseValueError('Custom 4D attention mask should be passed in inverted form with max==0`')
        causal_mask = attention_mask
else:
        causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
        exclude_mask = torch.arange(target_length, device=device)> cache_position.reshape(-1,1)
ifself.config.sliding_window isnotNone:
ifnot using_sliding_window_cache or sequence_length >self.config.sliding_window:
                exclude_mask.bitwise_or_(
                    torch.arange(target_length, device=device)
<=(cache_position.reshape(-1,1)-self.config.sliding_window)
)
        causal_mask *= exclude_mask
        causal_mask = causal_mask[None,None,:,:].expand(input_tensor.shape[0],1,-1,-1)
if attention_mask isnotNone:
            causal_mask = causal_mask.clone()# copy to contiguous memory for in-place edit
if attention_mask.dim()==2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:,:,:,:mask_length]+ attention_mask[:,None,None,:]
                padding_mask = padding_mask ==0
                causal_mask[:,:,:,:mask_length]= causal_mask[:,:,:,:mask_length].masked_fill(
                    padding_mask, min_dtype
)

return causal_mask

GQA (Grouped Query Attention)

Paper:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
Abs:https://arxiv.org/abs/2305.13245
www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

grouped-query attention 指出,Multi-Query Attention[3] 提高了推理速度的同时,却可能极大地降低回复质量。因此根据上图,GQA 在推理速度和质量之间作了权衡。

以下为 GQA 文中的实验结果,值得注意的是论文中使用原 MHA checkpoint 转换为 GQA 权重后,还进行了额外的预训练:

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

此外 Mistral,Llama2 的部分模型使用 GQA 时,采用的 kv head 数量似乎都是 8。

为什么现在大家都在用 MQA 和 GQA?[4] 文中提到 MQA 和 GQA 能获得巨大加速的一个点在于:GPU 内存强的限制。由于 MQA 和 GQA 都降低了内存中数据的读取量,减少了计算单元的等待时间,因此推理速度的提高比想象中的要快更多。

Mixtral 8*7B

论文:https://arxiv.org/abs/2401.04088
huggingface 模型权重:https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
官方博客:https://mistral.ai/news/mixtral-of-experts/
huggingface 模型代码:https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
混合专家模型基础(推荐):https://huggingface.co/blog/zh/moe

官方给出的评分来看,mixtral 8*7 和 GPT3.5 有的一比。

• 发布时间:23年12月

• 模型大小:8 个 expert MLP 层,一共45B 大小。

• 训练:除了预训练外,Mixtral MOE 后续还开源了一个经过 SFT + DPO 微调的版本。

• 模型效果:

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

• 架构:Mixtral 的 MOE 架构类似于,在 MoE 模型中,只有 FFN 层被视为独立的专家,而模型的其他参数是共享的。大致参数为:

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

参考 huggingface 中的 mixtral 和 mistral 实现对比,差异在于 mixtral 中将传统 transformer decoder layer 中的 FFN 替换为了 block_sparse_moe。

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

主要逻辑

G(x)=Softmax(TopK(x⋅Wgate))final hidden states=∑i=0n−1G(x)i⋅Ei(x)

其中 Ei(x) 为专家对应的网络,具体展示为下面 huggingface 实现中的 MixtralBlockSparseTop2MLP。mixtral 中采用了 8 个 expert,每次推理使用选取 top 2 的 expert 进行推理。比如输入一句话 你好,今天,那么我们每个 token 都会选出 top 2 的 expert 来负责这个 token 的预测,因此在推理 你好,今天 时,有概率所有 expert 都会参与到计算当中,具体可以参考 MixtralSparseMoeBlock 的实现。

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

mixtral 论文中提到专家分配在不同主题(如ArXiv论文、生物学和哲学文档)中没有明显的模式,只有在DM数学中显示出边际上的差异,这可能是由于其数据集的合成性质和有限的自然语言覆盖范围所致。router 在某些句法结构上表现出一定的结构化行为(比如 python 的 self 等),同时连续标记通常被分配给相同的专家。

huggingface 中的 mixtral 核心代码

class MixtralDecoderLayer(nn.Module):
def __init__(self, config:MixtralConfig, layer_idx:int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

self.block_sparse_moe =MixtralSparseMoeBlock(config)
self.input_layernorm =MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm =MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
        hidden_states: torch.Tensor,
        attention_mask:Optional[torch.Tensor]=None,
# 此处省略参数 ..
)->Tuple[torch.FloatTensor,Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states
        hidden_states =self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value =self.self_attn(
# 此处省略参数 
)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states =self.post_attention_layernorm(hidden_states)

# Mixtral 将原本的 hidden_states = self.FFN(hidden_states) 替换为了:
        hidden_states, router_logits =self.block_sparse_moe(hidden_states)

        hidden_states = residual + hidden_states
        outputs =(hidden_states,)

return outputs

huggingface 中 block_sparse_moe 的实现(省略部分次要代码):

class MixtralSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok

self.gate = nn.Linear(self.hidden_dim,self.num_experts, bias=False)
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config)for _ in range(self.num_experts)])

self.jitter_noise = config.router_jitter_noise

def forward(self, hidden_states: torch.Tensor)-> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits =self.gate(hidden_states)# (batch * sequence_length, n_experts)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2,1,0)

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
            expert_layer =self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)

            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
# current_state: shape (n_i, hidden_dim)
# 所有 current_state 的长度 n 总和为 batch * sequence_length
            current_hidden_states = expert_layer(current_state)* routing_weights[top_x, idx,None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

其中:MixtralBlockSparseTop2MLP 长这样:

class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config:MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size

self.w1 = nn.Linear(self.hidden_dim,self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim,self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim,self.ffn_dim, bias=False)

self.act_fn = ACT2FN[config.hidden_act]

def forward(self, hidden_states):
        current_hidden_states =self.act_fn(self.w1(hidden_states))*self.w3(hidden_states)
        current_hidden_states =self.w2(current_hidden_states)
return current_hidden_states

推理部分的话,根据模型参数量 45B 来推理的话,如果用 fp16 的话推理的话,得需要至少 90GB 以上的显存,如果用 4 bit的话,30GB 显存就够了。量化的生成速度,可以参考这个 redis[5] 中的评论,大致为 :

推理精度设备速度 tokens/s
Q4_K_M单卡 4090 + 7950X3D20
Q4_K_M2 x 309048.26

如果有 100+GB 以上显存,可以用 vllm 快速搭建测试 api:

docker run --gpus all \
    -e HF_TOKEN=$HF_TOKEN -p 8000:8000 \
    ghcr.io/mistralai/mistral-src/vllm:latest \
    --host 0.0.0.0 \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tensor-parallel-size 2 # 100+GB 显存 \
    --load-format pt # needed since both `pt` and `safetensors` are available

Nvidia TensorRT-LLM[6] 博客中,记录了 Mixtral 8*7B 的吞吐量测试(input and output sequence lengths of 128):

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

input and output sequence lengths of 128

文中没有给出当 sequence lengths 最大时候的吞吐量,但根据上图数据,可以猜测 2个 H100 部署 8*7B 正常服务用户时,平均吞吐量应该可以大于 7500Tokens/秒,根据 H100 的功耗计算电费成本的话,生成 1M token 需要耗约为 0.02 度电。

Mixtral 8*22B

官方博客:https://mistral.ai/news/mixtral-8x22b/
huggingface 开源模型:https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1

• 架构:架构与 mixtral 8*7B 架构一样,在 huggingface 中使用的都是MixtralForCausalLM ,但 22B 的各方面参数大一点,比较特别的是 context window 从 32k 升级到了 65k, vocab_size 也更大一些。

• 支持 function calling,不过好像没有透露具体的 function calling 训练细节。

• 数学和 coding 能力明显超越 llama2 70B。

• 似乎对中文的支持不是很好。

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

Mistral 团队开源的模型,都比较注重 coding 和 math 的能力,Mixtral 系列的模型在这方便表现也是比较好:

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

Mistral Nemo

官方博客:https://mistral.ai/news/mistral-nemo/
huggingface 模型权重:https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407

Mistral Nemo 使用的也是 MistralForCausalLM 架构,与 mistral 7B 的差别为:Mistral Nemo 的 hidden_size 从 4096 变为 5120;max_position_embeddings 变为 1024000,num_hidden_layers 增加到 40, vocab_size 增加到 131072,不用 sliding window。

此外,Mistral Nemo 支持 function calling,采用了 Tekken 作为 tokenizer,比 SentencePiece 更高效(压缩率更高,官方描述是~30% more efficient at compressing,不确定是哪个方面的 efficient)

NVIDIA 在这个博客[7]中提到:Mistral Nemo 采用这样的设计,是为了能够适配单个NVIDIA L40S、NVIDIA GeForce RTX 4090或NVIDIA RTX 4500 GPU。模型采用 Megatron-LM[8] 训练,用了 3,072 个 H100 80GB 。

但光采用 FP16 加载整个 Mistral Nemo 就需要花 23 GB 显存,要是要跑满整个 context window size,除了量化外,还是得需要采用 offload 或者其他方法来推理

不过 mistral 官方把 12 B 的模型和其他 8B 的模型对比,感觉好像不太公平:

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

Mistral Large 2

官方博客:https://mistral.ai/news/mistral-large-2407/
huggingface 模型权重:https://huggingface.co/mistralai/Mistral-Large-Instruct-2407

Mistral Large 2,参数量 123B,主打多语言以及 coding 能力。采用与 mistral 7B 一样的架构,huggingface 中同样使用 MistralForCausalLM;比较值得注意的是 context window size 为 131072,不用 sliding window。同样支持 function call。

Llama 3.1 刚出不久,就拿 Mistral Large 2 和别人来对比:

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

在代码能力上,Mistral large 2 比 llama 3.1 平均效果更好。

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

除了 coding 和数学外,在MT Bench 的评分也比 llama 3.1 高,平均生成的回复长度比 llama 3.1 要短

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

同时,中文能力相对上一代 mistral large 有大步幅提升:

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解
引用链接

[1] huggingface 上的权重: https://huggingface.co/mistralai/mathstral-7B-v0.1/blob/main/config.json
[2] mistral-7b-instruct-v0.2: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/config.json
[3] Multi-Query Attention: https://arxiv.org/pdf/1911.02150.pdf
[4] 为什么现在大家都在用 MQA 和 GQA?: https://zhuanlan.zhihu.com/p/647130255
[5] 这个 redis: https://www.reddit.com/r/LocalLLaMA/comments/18jslmf/tokens_per_second_mistral_8x7b_performance/?rdt=57036
[6] TensorRT-LLM: https://developer.nvidia.com/blog/achieving-high-mixtral-8x7b-performance-with-nvidia-h100-tensor-core-gpus-and-tensorrt-llm/?ncid=so-twit-928467/
[7] 这个博客: https://blogs.nvidia.com/blog/mistral-nvidia-ai-model/
[8] Megatron-LM: https://github.com/NVIDIA/Megatron-LM

www.zeeklog.com  - 从Mistral Nemo到Large2 核心技术详解

包包算法笔记

包大人的大模型、深度学习、机器学习笔记。

152篇原创内容

公众号

相关文章

Read more

深入理解 Proxy 和 Object.defineProperty

在JavaScript中,对象是一种核心的数据结构,而对对象的操作也是开发中经常遇到的任务。在这个过程中,我们经常会使用到两个重要的特性:Proxy和Object.defineProperty。这两者都允许我们在对象上进行拦截和自定义操作,但它们在实现方式、应用场景和灵活性等方面存在一些显著的区别。本文将深入比较Proxy和Object.defineProperty,包括它们的基本概念、使用示例以及适用场景,以帮助读者更好地理解和运用这两个特性。 1. Object.defineProperty 1.1 基本概念 Object.defineProperty 是 ECMAScript 5 引入的一个方法,用于直接在对象上定义新属性或修改已有属性。它的基本语法如下: javascript 代码解读复制代码Object.defineProperty(obj, prop, descriptor); 其中,obj是目标对象,prop是要定义或修改的属性名,descriptor是一个描述符对象,用于定义属性的特性。 1.2 使用示例 javascript 代码解读复制代码//

By Ne0inhk

Proxy 和 Object.defineProperty 的区别

Proxy 和 Object.defineProperty 是 JavaScript 中两个不同的特性,它们的作用也不完全相同。 Object.defineProperty 允许你在一个对象上定义一个新属性或者修改一个已有属性。通过这个方法你可以精确地定义属性的特征,比如它是否可写、可枚举、可配置等。该方法的使用场景通常是需要在一个对象上创建一个属性,然后控制这个属性的行为。 Proxy 也可以用来代理一个对象,但是相比于 Object.defineProperty,它提供了更加强大的功能。使用 Proxy 可以截获并重定义对象的基本操作,比如访问属性、赋值、函数调用等等。在这些操作被执行之前,可以通过拦截器函数对这些操作进行拦截和修改。因此,通过 Proxy,你可以完全重写一个对象的默认行为。该方法的使用场景通常是需要对一个对象的行为进行定制化,或者需要在对象上添加额外的功能。 对比 以下是 Proxy 和 Object.defineProperty 的一些区别对比: 方面ProxyObject.defineProperty语法使用 new Proxy(target,

By Ne0inhk