LLM 解码方式详解:贪心、束搜索与采样策略
在训练完大语言模型(LLM)后,推理阶段的解码策略直接决定了生成文本的质量、多样性和效率。理解不同的解码机制对于优化模型输出至关重要。
本文详细解析了大语言模型(LLM)推理阶段的解码策略。主要涵盖贪心搜索、集束搜索及多种采样方法(温度采样、Top-k、Top-p)。文章解释了各方法的数学原理、优缺点及适用场景,并提供了基于 HuggingFace Transformers 的实战代码示例。通过对比分析,指导开发者如何根据任务需求(如准确性或多样性)选择合适的解码参数,以优化模型生成效果。

在训练完大语言模型(LLM)后,推理阶段的解码策略直接决定了生成文本的质量、多样性和效率。理解不同的解码机制对于优化模型输出至关重要。
所谓解码,就是 LLM 的模型预测下一个 token 的过程。这是一个自回归(Autoregressive)的过程,即模型根据已生成的序列预测下一个词,并将该词加入序列继续预测。
以 PyTorch 为例,获取模型输出 outputs,其形状通常为 (batch_size, sequence_length, vocab_size):
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=torch.ones_like(decoder_input_ids).to(device)
)
# outputs[1] 通常是词汇表 logits (predictions)
lm_logits = outputs[1]
# lm_logits shape: (batch_size, sequence_length, vocab_size)
注意:LLM 每次输出的 outputs 其实包含了全部词汇表的预测概率。例如输入为'千问',第一次输出可能是'笑'的概率分布,第二次基于'千问笑'预测'料'的概率分布,以此类推。
无论哪种方式,核心逻辑都是基于当前状态下的概率分布来选择下一个词。以下是主流解码策略的详细解析。
特点:
代码示例:
def sample_greedy(p):
# p shape (vocab_size,)
return np.argmax(p)
优点:
缺点:
特点:
原理说明: 假设输入为"Once upon a time",设置 beam size k=2。第一个词可能选出"a"和"the"。继续推理时,"a cat"的概率为 0.50.4=0.2,而"the people"的概率为 0.30.7=0.21。因此"the people"因概率更高被保留。
优点:
缺点:
特点:
通过调整温度参数 $T$ 来改变概率分布的形状。公式为: $$P_{new}(i) = \frac{\exp(logits_i / T)}{\sum_j \exp(logits_j / T)}$$
优缺点:
仅从概率最高的 k 个词中采样,忽略其余低概率词。可结合温度采样使用。
优缺点:
根据累积概率动态选择候选词的集合 P。保留候选词的总概率至少达到 p(如 p=0.9)。
比 Top-k 更智能,能根据分布动态调整候选词的数量。当概率分布集中时(确定性任务),候选词少;分布分散时(多样性任务),候选词多。
优缺点:
以下是一个使用 Hugging Face Transformers 库实现不同解码策略的示例:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
model_name = "Qwen/Qwen-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
input_text = "人工智能的未来是"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
# 1. Greedy Search
output_greedy = model.generate(**inputs, do_sample=False, max_new_tokens=50)
print("Greedy:", tokenizer.decode(output_greedy[0]))
# 2. Temperature Sampling
output_temp = model.generate(
**inputs,
do_sample=True,
temperature=0.7,
max_new_tokens=50
)
print("Temp 0.7:", tokenizer.decode(output_temp[0]))
# 3. Top-p Sampling
output_top_p = model.generate(
**inputs,
do_sample=True,
top_p=0.9,
max_new_tokens=50
)
print("Top-p 0.9:", tokenizer.decode(output_top_p[0]))
选择合适的解码策略取决于具体应用场景:
理解这些底层机制有助于开发者更好地调优模型参数,从而获得符合业务需求的生成效果。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online