概述
在大型语言模型(LLM)的微调实践中,大多数开发者关注的是如何调用 API 或配置训练脚本,而对于底层 Loss 计算的具体逻辑往往缺乏深入了解。Loss 计算的准确性直接决定了模型能否正确收敛以及生成质量的高低。本文重点剖析 GLM-4-9B 开源模型在微调时的 Loss 计算机制,特别是多轮对话场景下的 Mask 策略与标签处理逻辑。
数据格式规范
GLM-4 系列模型采用标准的 ChatML 风格进行对话交互。在微调数据集中,每条样本通常包含一个 messages 列表,其中每个元素代表一次交互的角色和内容。支持的角色包括 system(系统提示)、user(用户提问)、assistant(助手回复)以及 observation(工具调用观察结果)。
标准的数据结构示例如下:
[
{
"messages": [
{
"role": "system",
"content": "你是一个有用的助手。"
},
{
"role": "user",
"content": "你好,请介绍一下大模型。"
},
{
"role": "assistant",
"content": "大模型是指参数量巨大的深度学习模型..."
}
]
}
]
在实际训练中,这些文本会被 Tokenizer 转换为整数 ID 序列,以便模型进行概率预测。
Loss 计算核心逻辑
在 PyTorch 等深度学习框架中,CrossEntropyLoss 默认会将 ignore_index 设为 -100。这意味着在计算 Loss 时,所有标签为 -100 的位置不会参与梯度更新。GLM-4 的 Loss 计算正是基于这一机制,通过构建特定的 Label Mask 来决定哪些 Token 需要被预测。
核心处理函数 process_batch 负责将原始对话数据转换为模型可接受的输入 ID 和 Label ID。以下是关键代码逻辑分析:
def process_batch(
batch: Mapping[str, Sequence],
tokenizer: PreTrainedTokenizer,
max_input_length: int,
max_output_length: int,
) -> dict[str, list]:
batched_conv = batch['messages']
batched_input_ids = []
batched_labels = []
# 遍历批次中的每一条对话历史
for conv in batched_conv:
# 初始化特殊 token,如开始符等
input_ids = [151331, 151333]
loss_masks = [False, False]
# 遍历对话中的每一个角色消息
for message in conv:
message = process_message(message)
# 设置 mask 掩码:只有 system、user、observation 不参与 Loss 计算
# assistant 角色的内容才需要计算 Loss,即让模型学习如何回复
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
# 获取当前消息的 Token ID 表示
# apply_chat_template 会自动处理模板拼接,[2:] 用于去除前缀特殊 token
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
# 生成对应的 Loss Mask
new_loss_masks = [loss_mask_val] * len(new_input_ids)
# 拼接 Input 和 Mask
input_ids += new_input_ids
loss_masks += new_loss_masks
# 追加结束符 EOS
input_ids.append(tokenizer.eos_token_id)
# 调整 Mask 长度以匹配 Input,开头两个特殊 token 不计算 Loss
loss_masks = [False, *loss_masks]
labels = []
for input_id, mask in zip(input_ids, loss_masks):
if mask:
# 如果该位置需要计算 Loss,Label 等于 Input ID
labels.append(input_id)
else:
# 如果不需要计算 Loss,Label 设为 -100 (ignore_index)
labels.append(-100)
# 截断策略:限制最大长度
max_length = max_input_length + max_output_length + 1
batched_input_ids.append(input_ids[:max_length])
batched_labels.append(labels[:max_length])
return {'input_ids': batched_input_ids, 'labels': batched_labels}
关键点解析:
- 角色掩码策略:代码中明确判断
message['role']。对于system、user和observation,loss_mask_val设为False,最终生成的 Label 为-100。这确保了模型在训练时只关注assistant的回复部分,而忽略用户的指令和系统的预设规则。这是监督微调(SFT)的标准做法,防止模型学习到错误的因果方向。 - Token 对齐:Input IDs 和 Labels 必须严格一一对应。当
mask为True时,Label 使用实际的 Token ID;当mask为False时,Label 使用-100。这种对齐保证了 CrossEntropyLoss 能正确跳过无效位置。 - EOS 标记处理:在对话末尾添加
tokenizer.eos_token_id,并在 Mask 中标记其是否参与计算。通常 EOS 标记本身也作为预测目标的一部分,但在某些实现中可能根据具体需求调整。 - 长度截断:为了防止显存溢出,代码实现了
max_input_length + max_output_length的截断逻辑。超出部分将被丢弃,这在长上下文训练中尤为重要。
数据集加载与训练流程
上述 process_batch 函数通常在 DataLoader 的 collate_fn 或 Dataset 的 transform 中被调用。典型的加载流程如下:
from functools import partial
# 加载模型和分词器
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)
# 定义数据处理函数
process_func = partial(
process_batch,
tokenizer=tokenizer,
max_input_length=ft_config.max_input_length,
max_output_length=ft_config.max_output_length,
)
# 创建训练数据集
train_dataset = data_manager.get_dataset(
Split.TRAIN,
process_func,
batched=True,
)
print('train_dataset:', train_dataset)
在此流程中,functools.partial 用于预填充参数,确保每个 Batch 都能正确应用 Loss Mask 逻辑。batched=True 表明返回的数据已经是按 Batch 组织的张量或列表。
常见问题与最佳实践
- Loss 不下降的原因:如果训练初期 Loss 没有明显下降,首先检查 Label 中是否错误地将所有位置都设为了 -100,或者 Input/Label 长度不一致导致维度错误。此外,确认
apply_chat_template是否正确处理了特殊 Token。 - 多轮对话的累积效应:在多轮对话中,前面的 User 和 Assistant 消息都会成为后续预测的上下文。虽然前面的 Assistant 回复在后续步骤中会被视为 Input,但其对应的 Label 依然遵循'仅预测 Assistant 回复'的原则。这意味着模型需要记住之前的对话历史来生成下一轮回复。
- 与其他框架的对比:相比于早期的 ChatGLM 版本,GLM-4 的多轮对话 Loss 计算更加标准化。InternLM、XTuner、Firefly 等主流微调框架均已原生支持此类 Mask 策略。在使用 LoRA 或其他 PEFT 技术时,需确保 Mask 逻辑与全量微调保持一致。
- 显存优化:在处理长序列时,建议开启 Gradient Checkpointing 或使用 Flash Attention。同时,合理设置
max_input_length和max_output_length的比例,避免过长的 Prompt 占用过多显存。
总结
GLM-4-9B 开源模型的微调关键在于理解其对话模板与 Loss 掩码的配合机制。通过精确控制 loss_mask,我们可以确保模型仅在助手回复的 Token 上进行梯度更新,从而高效地学习人类指令遵循能力。掌握这一底层逻辑,有助于开发者在遇到训练异常时快速定位问题,并根据实际需求调整数据处理策略,提升微调效果。


