跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

GLM-4-9B 开源模型微调 Loss 计算逻辑解析

综述由AI生成GLM-4-9B 开源模型微调过程中,Loss 计算逻辑直接影响模型收敛效果。深入解析了基于对话格式的微调数据构建方法,重点阐述了在 process_batch 函数中如何通过角色掩码区分系统提示、用户输入与助手回复的权重。通过设置特定角色的 loss_mask 为 False 并将对应标签设为 -100,确保模型仅对助手生成的内容进行预测训练。同时对比了与其他框架如 InternLM、XTuner 的实现差异,提供了关于截断策略和 EOS 标记处理的实践建议,帮助开发者理解底层计算机制以避免训练偏差。

魔尊发布于 2025/2/6更新于 2026/6/221 浏览
GLM-4-9B 开源模型微调 Loss 计算逻辑解析
概述

在大型语言模型(LLM)的微调实践中,大多数开发者关注的是如何调用 API 或配置训练脚本,而对于底层 Loss 计算的具体逻辑往往缺乏深入了解。Loss 计算的准确性直接决定了模型能否正确收敛以及生成质量的高低。本文重点剖析 GLM-4-9B 开源模型在微调时的 Loss 计算机制,特别是多轮对话场景下的 Mask 策略与标签处理逻辑。

数据格式规范

GLM-4 系列模型采用标准的 ChatML 风格进行对话交互。在微调数据集中,每条样本通常包含一个 messages 列表,其中每个元素代表一次交互的角色和内容。支持的角色包括 system(系统提示)、user(用户提问)、assistant(助手回复)以及 observation(工具调用观察结果)。

标准的数据结构示例如下:

[
  {
    "messages": [
      {
        "role": "system",
        "content": "你是一个有用的助手。"
      },
      {
        "role": "user",
        "content": "你好,请介绍一下大模型。"
      },
      {
        "role": "assistant",
        "content": "大模型是指参数量巨大的深度学习模型..."
      }
    ]
  }
]

在实际训练中,这些文本会被 Tokenizer 转换为整数 ID 序列,以便模型进行概率预测。

Loss 计算核心逻辑

在 PyTorch 等深度学习框架中,CrossEntropyLoss 默认会将 ignore_index 设为 -100。这意味着在计算 Loss 时,所有标签为 -100 的位置不会参与梯度更新。GLM-4 的 Loss 计算正是基于这一机制,通过构建特定的 Label Mask 来决定哪些 Token 需要被预测。

核心处理函数 process_batch 负责将原始对话数据转换为模型可接受的输入 ID 和 Label ID。以下是关键代码逻辑分析:

def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_labels = []
    
    # 遍历批次中的每一条对话历史
    for conv in batched_conv:
        # 初始化特殊 token,如开始符等
        input_ids = [151331, 151333]
        loss_masks = [False, False]
        
        # 遍历对话中的每一个角色消息
        for message in conv:
            message = process_message(message)
            # 设置 mask 掩码:只有 system、user、observation 不参与 Loss 计算
            # assistant 角色的内容才需要计算 Loss,即让模型学习如何回复
            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
            
            # 获取当前消息的 Token ID 表示
            # apply_chat_template 会自动处理模板拼接,[2:] 用于去除前缀特殊 token
            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
            
            # 生成对应的 Loss Mask
            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            
            # 拼接 Input 和 Mask
            input_ids += new_input_ids
            loss_masks += new_loss_masks
        
        # 追加结束符 EOS
        input_ids.append(tokenizer.eos_token_id)
        # 调整 Mask 长度以匹配 Input,开头两个特殊 token 不计算 Loss
        loss_masks = [False, *loss_masks]
        
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                # 如果该位置需要计算 Loss,Label 等于 Input ID
                labels.append(input_id)
            else:
                # 如果不需要计算 Loss,Label 设为 -100 (ignore_index)
                labels.append(-100)
        
        # 截断策略:限制最大长度
        max_length = max_input_length + max_output_length + 1
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])
    
    return {'input_ids': batched_input_ids, 'labels': batched_labels}

关键点解析:

  1. 角色掩码策略:代码中明确判断 message['role']。对于 system、user 和 observation,loss_mask_val 设为 False,最终生成的 Label 为 -100。这确保了模型在训练时只关注 assistant 的回复部分,而忽略用户的指令和系统的预设规则。这是监督微调(SFT)的标准做法,防止模型学习到错误的因果方向。
  2. Token 对齐:Input IDs 和 Labels 必须严格一一对应。当 mask 为 True 时,Label 使用实际的 Token ID;当 mask 为 False 时,Label 使用 -100。这种对齐保证了 CrossEntropyLoss 能正确跳过无效位置。
  3. EOS 标记处理:在对话末尾添加 tokenizer.eos_token_id,并在 Mask 中标记其是否参与计算。通常 EOS 标记本身也作为预测目标的一部分,但在某些实现中可能根据具体需求调整。
  4. 长度截断:为了防止显存溢出,代码实现了 max_input_length + max_output_length 的截断逻辑。超出部分将被丢弃,这在长上下文训练中尤为重要。
数据集加载与训练流程

上述 process_batch 函数通常在 DataLoader 的 collate_fn 或 Dataset 的 transform 中被调用。典型的加载流程如下:

from functools import partial

# 加载模型和分词器
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
data_manager = DataManager(data_dir, ft_config.data_config)

# 定义数据处理函数
process_func = partial(
    process_batch,
    tokenizer=tokenizer,
    max_input_length=ft_config.max_input_length,
    max_output_length=ft_config.max_output_length,
)

# 创建训练数据集
train_dataset = data_manager.get_dataset(
    Split.TRAIN,
    process_func,
    batched=True,
)
print('train_dataset:', train_dataset)

在此流程中,functools.partial 用于预填充参数,确保每个 Batch 都能正确应用 Loss Mask 逻辑。batched=True 表明返回的数据已经是按 Batch 组织的张量或列表。

常见问题与最佳实践
  1. Loss 不下降的原因:如果训练初期 Loss 没有明显下降,首先检查 Label 中是否错误地将所有位置都设为了 -100,或者 Input/Label 长度不一致导致维度错误。此外,确认 apply_chat_template 是否正确处理了特殊 Token。
  2. 多轮对话的累积效应:在多轮对话中,前面的 User 和 Assistant 消息都会成为后续预测的上下文。虽然前面的 Assistant 回复在后续步骤中会被视为 Input,但其对应的 Label 依然遵循'仅预测 Assistant 回复'的原则。这意味着模型需要记住之前的对话历史来生成下一轮回复。
  3. 与其他框架的对比:相比于早期的 ChatGLM 版本,GLM-4 的多轮对话 Loss 计算更加标准化。InternLM、XTuner、Firefly 等主流微调框架均已原生支持此类 Mask 策略。在使用 LoRA 或其他 PEFT 技术时,需确保 Mask 逻辑与全量微调保持一致。
  4. 显存优化:在处理长序列时,建议开启 Gradient Checkpointing 或使用 Flash Attention。同时,合理设置 max_input_length 和 max_output_length 的比例,避免过长的 Prompt 占用过多显存。
总结

GLM-4-9B 开源模型的微调关键在于理解其对话模板与 Loss 掩码的配合机制。通过精确控制 loss_mask,我们可以确保模型仅在助手回复的 Token 上进行梯度更新,从而高效地学习人类指令遵循能力。掌握这一底层逻辑,有助于开发者在遇到训练异常时快速定位问题,并根据实际需求调整数据处理策略,提升微调效果。

目录

  1. 概述
  2. 数据格式规范
  3. Loss 计算核心逻辑
  4. 数据集加载与训练流程
  5. 加载模型和分词器
  6. 定义数据处理函数
  7. 创建训练数据集
  8. 常见问题与最佳实践
  9. 总结
  • 💰 8折买阿里云服务器限时8折了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • 基于 openJiuwen 记忆库构建 AI 职业匹配智能体
  • Python 数据思维:元组基础
  • OpenClaw安全AI助理从零搭建实战教程
  • Go Map 底层原理深度解析
  • Kiro 工具实测:前端代码生成验证与调整
  • Motrix WebExtension 浏览器扩展配置指南
  • 基于 SpringBoot 的烟草商品在线采购与供应链管理系统设计
  • 本地 Docker 部署 Appsmith 及远程访问配置
  • 大模型未来技术演进与应用趋势深度解析
  • C++26 契约编程新特性:利用静态与动态检查提升代码健壮性
  • AndroidGen-Llama-3-70B:零标注自主操控安卓应用的大模型实践
  • OpenClaw AI Agent 架构原理与实战应用
  • 基于 Python Django/Flask 的体育户外服装商城设计与实现
  • 比迪丽 AI 绘画多设备协同:PC 生成、手机审核与平板标注
  • Qwen3.5-4B 微调实战:基于 LLaMA-Factory 构建医疗 AI 助手
  • Stable Diffusion 3.5 FP8 多卡并行实测:双 GPU 扩展性分析
  • 基于 Vue 3 的情侣双人飞行棋网页版实现
  • Clang/Clang++ 编译器架构与 C/C++ 编译指南
  • 链表应用实战:从内存管理到缓存淘汰
  • Python Pandas Index 常用方法用法精讲

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online