大语言模型词表裁剪方法与实践
如何裁剪大语言模型(LLM)的词表以降低参数量。以 Bloom 模型为例,阐述了词表裁剪的核心原理,包括 Embedding 层和 LM Head 权重的映射与复制。提供了完整的 Python 代码实现,展示了如何验证新词表是否为原词表子集、如何更新模型配置以及如何进行一致性检查。实验结果显示,裁剪后模型参数量可显著减少,且生成结果保持一致。文章还补充了关于特殊 Token 保留、分词策略一致性及 OOV 处理等最佳实践建议。

如何裁剪大语言模型(LLM)的词表以降低参数量。以 Bloom 模型为例,阐述了词表裁剪的核心原理,包括 Embedding 层和 LM Head 权重的映射与复制。提供了完整的 Python 代码实现,展示了如何验证新词表是否为原词表子集、如何更新模型配置以及如何进行一致性检查。实验结果显示,裁剪后模型参数量可显著减少,且生成结果保持一致。文章还补充了关于特殊 Token 保留、分词策略一致性及 OOV 处理等最佳实践建议。

在多语言大语言模型(LLM)的应用场景中,原始模型的词表(Vocabulary)往往非常庞大,包含了全球多种语言的字符和子词。然而在实际的下游任务中,我们可能只需要支持特定的语言,例如仅中文和英文。此时,对词表进行裁剪是一个有效的优化手段。
通过裁剪词表,我们可以显著减少模型的参数量,降低显存占用,同时保留模型在目标语言上的性能表现。本文将基于 Bloom 模型为例,详细介绍如何进行词表裁剪的操作流程、代码实现及注意事项。
大语言模型的词表大小直接决定了嵌入层(Embedding Layer)和语言模型头(LM Head)的维度。假设原词表大小为 $V_{old}$,隐藏层维度为 $H$,则嵌入层的参数量为 $V_{old} \times H$,LM Head 的参数量为 $H \times V_{new}$(通常 $V_{new} = V_{old}$)。
当我们将词表裁剪至 $V_{new}$ 时,我们需要从原模型中提取对应于新词表中 token 的权重向量,并构建新的模型结构。核心逻辑如下:
vocab_size 字段。首先,我们需要准备两个关键组件:
bigscience/bloom-560m)。tokenizer.json 或 vocab.txt)。注意:新词表必须是原词表的子集,即所有新词表中的 Token 都必须存在于原词表中,否则会导致索引越界错误。
以下是一个完整的词表裁剪类实现,基于 Hugging Face Transformers 库。
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
class VocabularyPruner(object):
def check(self, old_model_name_or_path, new_model_name_or_path, text):
"""
检查模型裁剪后,生成结果是否一致
"""
max_length = 20
# 使用老模型对文本编码
old_model = AutoModelForCausalLM.from_pretrained(old_model_name_or_path)
old_tokenizer = AutoTokenizer.from_pretrained(old_model_name_or_path)
old_input_ids = old_tokenizer(text, return_tensors='pt').input_ids
old_output = old_model.generate(old_input_ids, max_length=max_length)
old_output_text = old_tokenizer.batch_decode(old_output)
print('old_output:{}'.format(old_output_text))
# 使用新模型对文本编码
new_model = AutoModelForCausalLM.from_pretrained(new_model_name_or_path)
new_tokenizer = AutoTokenizer.from_pretrained(new_model_name_or_path)
new_input_ids = new_tokenizer(text, return_tensors='pt').input_ids
new_output = new_model.generate(new_input_ids, max_length=max_length)
new_output_text = new_tokenizer.batch_decode(new_output)
print('new_output:{}'.format(new_output_text))
if old_output_text == new_output_text:
print('output is same, succeed to prune.')
else:
print('output is not same, fail to prune.')
def update_embeddings(self, model, new2old_token_id, new_embeds, new_lm_head):
raise NotImplementedError
def prune(self, model_name_or_path, new_tokenizer_name_or_path, save_path, new_name_or_path=None):
# 创建输出目录
if not os.path.exists(save_path):
os.makedirs(save_path)
# 加载新词表。如果是中文,就是中文的词表
new_tokenizer = AutoTokenizer.from_pretrained(new_tokenizer_name_or_path)
# 加载原词表。一般为多语言模型的词表
old_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# 检查新词表是否为原词表的子集
old_vocab = old_tokenizer.vocab
new_vocab = new_tokenizer.vocab
for token in tqdm(new_vocab.keys()):
if token not in old_vocab:
raise Exception('{} not exist'.format(token))
print('new_tokenizer is subset of old_tokenizer')
# 获得新词表中每个 token_id 到原词表的 token_id 的映射
new2old_token_id = {}
for token, token_id in tqdm(new_vocab.items()):
old_token_id = old_vocab[token]
new2old_token_id[token_id] = old_token_id
# 加载多语言模型
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype='auto')
# 计算原模型的参数量
old_params = sum(p.numel() for p in model.parameters())
print("Total params of original model: %.2fM" % (old_params / 1e6))
# 对于新词表中的每个 token,取出其对应的权重,复制到新模型中
vocab_size = len(new_tokenizer)
hidden_size = model.config.hidden_size
new_embeds = torch.nn.Embedding(vocab_size, hidden_size, dtype=model.dtype)
new_lm_head = torch.nn.Linear(in_features=hidden_size, out_features=vocab_size, bias=False, dtype=model.dtype)
# 更新词表权重
self.update_embeddings(model, new2old_token_id, new_embeds, new_lm_head)
model.config.__dict__['vocab_size'] = vocab_size
if new_name_or_path is not None:
model.config.__dict__['_name_or_path'] = new_name_or_path
# 计算新模型的参数量
new_params = sum(p.numel() for p in model.parameters())
print("Total params of new model : %.2fM" % (new_params / 1e6))
print('词表缩小为原来的:{}%'.format(round(len(new_tokenizer) / len(old_tokenizer), 4)*100))
print('模型参数量缩小为原来的:{}%'.format(round(new_params / old_params, 4)*100))
model.save_pretrained(save_path)
new_tokenizer.save_pretrained(save_path)
针对 Bloom 模型架构,需要特殊处理 Embedding 层的位置。
class BloomVocabularyPruner(VocabularyPruner):
def update_embeddings(self, model, new2old_token_id, new_embeds, new_lm_head):
for token_id, old_token_id in tqdm(new2old_token_id.items()):
# 复制 Embedding 权重
new_embeds.weight.data[token_id] = model.transformer.word_embeddings.weight.data[old_token_id]
# 复制 LM Head 权重
new_lm_head.weight.data[token_id] = model.lm_head.weight.data[old_token_id]
# 替换模型中的权重对象
model.transformer.word_embeddings.weight = new_embeds.weight
model.lm_head.weight = new_lm_head.weight
# 需要进行裁剪的模型路径
model_name_or_path = 'bigscience/bloom-560m'
# 自己制作的词表的路径
new_tokenizer_name_or_path = 'YeungNLP/bloom-396m-zh'
save_path = './pruned_bloom_zh'
pruner = BloomVocabularyPruner()
# 裁剪
pruner.prune(model_name_or_path, new_tokenizer_name_or_path, save_path)
# 检查裁剪的模型与原模型是否一致
pruner.check(model_name_or_path, save_path, text='长风破浪会有时')
执行上述脚本后,控制台将输出参数变化及一致性检查结果:
100%|██████████| 46145/46145 [00:00<00:00, 1309531.65it/s]
new_tokenizer is subset of old_tokenizer
100%|██████████| 46145/46145 [00:00<00:00, 1120687.88it/s]
Total params of original model: 559.21M
100%|██████████| 46145/46145 [00:01<00:00, 41641.55it/s]
Total params of new model : 396.82M
词表缩小为原来的:18.41%
模型参数量缩小为原来的:70.96%
old_output:['长风破浪会有时,直挂云帆济沧海。愿你,在人生的旅途中,能遇见最美的风景,遇见最美的自己。</s>']
new_output:['长风破浪会有时,直挂云帆济沧海。愿你,在人生的旅途中,能遇见最美的风景,遇见最美的自己。</s>']
output is same, succeed to prune.
可以看到,模型参数量从 559.21M 降至 396.82M,缩减了约 29%,而生成结果完全一致。
在进行词表裁剪时,除了代码实现外,还需注意以下几点以确保模型效果稳定:
<pad>, <eos>, <unk>, <bos> 等特殊符号的索引在原模型和新模型中保持一致。如果这些符号被移除或索引变更,可能导致推理时的解码错误或序列截断异常。<unk> 或类似兜底机制,以防遇到新词表之外的罕见词汇导致 OOV(Out Of Vocabulary)问题。通过上述步骤,您可以有效地针对特定语言环境优化大语言模型,在保持性能的同时提升部署效率。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online