前言
在阅读本文前,建议熟悉 PyTorch 常用算子。以下列举核心算子及其功能:
torch.where(condition, x, y):根据条件选择张量元素。
condition:条件掩码。
x:条件为 True 时选择的值。
y:条件为 False 时选择的值。
Tensor.scatter_(dim, index, src):将源张量按索引写入输出张量。
torch.gather(input, dim, index):scatter 的反向操作,用于提取特定索引的值。
torch.sort(input, dim=-1, descending=False):对张量进行排序。
torch.softmax(input, dim=None):计算 Softmax 概率分布。
cumsum(input, dim, dtype=None):累加求和。
Tensor.masked_fill(mask, value):根据掩码填充指定值。
torch.topk(input, k, dim=None):返回 k 个最大值及其索引。
torch.multinomial(input, num_samples, replacement=False):根据概率分布抽取样本。
torch.div(input, other, rounding_mode=None):除法运算。
模型生成策略概述
大模型通常继承自 PreTrainedModel,预测时调用 GenerationMixin 的 generate 方法。模型生成回答主要涉及以下几种搜索与采样方法:
Contrastive Search
Contrastive Search 是一种改进的解码策略,旨在平衡生成的流畅性与多样性。它通过对比当前 token 与历史上下文的差异来避免重复,具体实现略。
Multinomial Sampling(多项式采样)
与总是选择概率最高的标记作为下一个标记的贪婪搜索不同,Multinomial Sampling 根据模型给出的整个词汇表的概率分布随机选择下一个标记。这增加了生成的随机性,有助于打破死循环。
只需将 do_sample 设为 True 即可启用:
outputs = model.generate(**inputs, do_sample=True, num_beams=1, max_new_tokens=100)
源码逻辑如下:
while True:
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
next_token_logits = outputs.logits[:, -1, :]
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
Beam Search 的实现
Beam Search 是对贪心策略的改进。思路是稍微放宽考察范围,在每个时间步保留 num_beams 个最优序列,而非仅保留分数最高的 1 个。当 num_beams=1 时,集束搜索退化为贪心搜索。
要实现 Beam Search,设置参数如下:
outputs = model.generate(**inputs, num_beams=5, max_new_tokens=50)
源码核心流程解读:
- 输入扩张:将 input_ids 扩张成
num_beams 份。
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
- 模型预测:获取下一层 logits。
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
- 获取最佳候选:使用
topk 选取每个 beam 下最好的 token。
next_token_logits = outputs.logits[:, -1, :]
next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
n_eos_tokens = len(eos_token_id) if eos_token_id else 0
next_token_scores, next_tokens = torch.topk(
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
)
- 状态更新与保存:处理 beam scorer 逻辑,保存中间结果。
beam_outputs = beam_scorer.process(
input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
- 拼接结果:将新的 token 拼接到输入序列中。
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
- 最终化:当预测结束,从所有 beams 中选择最优路径。
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
针对多 batch 场景,BeamSearchScorer 类负责管理多个假设路径。若数量超过设置的 num_beams,评分最差的路径会被剔除。
解码参数与重复惩罚机制
为了解决'复读机'问题(即模型陷入循环重复生成相同内容),通常需要调整以下解码参数:
Temperature(温度)
温度控制输出的随机性。温度为 0 时,模型倾向于确定性输出(类似贪婪搜索);温度越高,概率分布越平滑,随机性越大。
class TemperatureLogitsWarper:
def __call__(self, input_ids, scores):
scores = scores / self.temperature
return scores
Top-P (Nucleus Sampling)
动态设置 tokens 候选列表的大小。将可能性之和不超过特定值的 top tokens 列入候选名单。Top-p 通常设置为较高值(如 0.9),目的是限制可能被采样的低概率 token 的长度,同时保留高概率区域的多样性。
class TopPLogitsWarper:
def __call__(self, input_ids, scores):
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
Top-K
允许其他高分 tokens 有机会被选中。这种采样引入的随机性有助于在很多情况下提升生成质量。Top-k 参数设置为 3 意味着只选择前三个 tokens 的概率分布进行采样。如果 k 和 p 都启用,则 p 在 k 之后起作用。
class TopKLogitsWarper:
def __call__(self, input_ids, scores):
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
Repetition Penalty(重复性惩罚)
重复性惩罚方法通过在模型推理过程中加入重复惩罚因子,降低已生成 token 的后续出现概率,从而有效抑制循环重复。
class RepetitionPenaltyLogitsProcessor:
def __call__(self, input_ids, scores):
score = torch.gather(scores, 1, input_ids)
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores.scatter_(1, input_ids, score)
return scores
总结
大模型生成中的'复读机'现象通常源于贪婪搜索策略或概率分布过于集中。通过上述技术组合可以有效缓解:
- Temperature 增加随机性,避免模型锁定在局部最优解。
- Top-K/Top-P 限制采样空间,过滤掉极低概率的噪声,同时保留合理多样性。
- Repetition Penalty 直接对历史已出现的 token 施加惩罚,从数学上降低其再次被选中的概率。
- Beam Search 虽然主要用于提升连贯性,但配合适当的早停策略也能减少无效重复。
在实际应用中,建议结合任务需求调整这些参数。例如,创意写作可提高 Temperature 并开启 Top-P,而代码生成或事实性问答则应降低 Temperature 并启用 Repetition Penalty 以确保准确性与唯一性。