跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

如何在 Llama-Factory 中自定义损失函数

综述由AI生成介绍在 Llama-Factory 框架中自定义损失函数的高级用法。针对标准交叉熵损失无法处理业务优先级或样本不平衡的问题,通过重写 Trainer 的 compute_loss 方法实现灵活定制。示例包括标签平滑和基于类别权重的损失调整。同时强调了梯度稳定性、分布式训练兼容性及内存效率等注意事项,帮助开发者将业务目标编码进模型训练过程。

鲜活发布于 2026/4/6更新于 2026/5/2339 浏览

如何在 Llama-Factory 中自定义损失函数

在大模型微调日益普及的今天,越来越多的实际任务开始暴露出标准训练流程的局限性。比如,你在训练一个金融客服机器人时发现,尽管整体准确率不错,但模型总是'忽略'那些关键却少见的问题——像'账户被冻结怎么办'这类高风险咨询,出现频率低、样本少,结果在交叉熵损失主导下被梯度淹没。这时候,你真正需要的不是更多数据,而是一种能表达业务优先级的损失函数。

这正是 Llama-Factory 作为现代微调框架的价值所在:它不仅让你'跑得起来',更允许你深入到底层训练逻辑,把领域知识、工程经验甚至产品目标,编码进模型的学习过程中。其中最关键的入口之一,就是自定义损失函数。


Llama-Factory 基于 Hugging Face Transformers 构建,底层使用 PyTorch,其训练流程遵循典型的因果语言建模范式。默认情况下,Trainer 类会调用内置的 CrossEntropyLoss 来计算 token 级别的预测误差。这个过程看似固定,实则留出了清晰的扩展点——只要你重写 compute_loss 方法,就能完全接管损失计算逻辑。

这种设计并非偶然。它的核心思想是:训练引擎负责调度和优化,而损失函数定义'什么是对的'。 换句话说,框架管'怎么学',你来决定'学成什么样'。

举个例子,标签平滑(Label Smoothing)是一种常见的正则化技术,用于防止模型对训练标签过度自信。虽然 Hugging Face 的 Trainer 支持通过参数启用,但在某些场景下你需要更细粒度的控制,比如动态调整平滑强度或结合其他监督信号。这时,直接定制 compute_loss 就成了最灵活的选择。

import torch
import torch.nn as nn
from transformers import Trainer

class CustomTrainer(Trainer):
    def __init__(self, label_smoothing=0.0, **kwargs):
        super().__init__(**kwargs)
        self.label_smoothing = label_smoothing
        self.ce_loss = nn.CrossEntropyLoss(reduction="none")

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Shift for causal language modeling
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        flat_logits = shift_logits.view(-, shift_logits.size(-))
        flat_labels = shift_labels.view(-)

         .label_smoothing > :
            vocab_size = flat_logits.shape[-]
             torch.no_grad():
                true_probs = torch.full_like(flat_logits, .label_smoothing / (vocab_size - ))
                true_probs.scatter_(, flat_labels.unsqueeze(),  - .label_smoothing)
                log_probs = torch.log_softmax(flat_logits, dim=-)
                loss = -(true_probs * log_probs).(dim=-).mean()
        :
            loss = .ce_loss(flat_logits, flat_labels).mean()

         (loss, outputs)  return_outputs  loss
1
1
1
if
self
0
1
with
self
1
1
1
1
self
1
sum
1
else
self
return
if
else

这段代码的关键在于,它没有改动任何训练流程,只是替换了损失计算部分。你可以把它看作一个'插槽'——只要返回的是标量 loss,PyTorch 就能自动完成反向传播。这意味着你的自定义逻辑可以非常复杂,比如引入对比学习项、KL 散度约束,甚至是基于外部奖励的强化学习目标。

更重要的是,Llama-Factory 提供了配置驱动的加载机制。你不需要修改主程序,只需将上述类保存为 trainers/custom_trainer.py,然后在 YAML 配置中声明:

trainer_type: custom
custom_trainer_path: ./trainers/custom_trainer.py
label_smoothing: 0.1

框架会在初始化时动态导入并实例化你的 CustomTrainer,自动注入所有配置参数。这种插件式架构让实验迭代变得极其高效:换损失就像换电池一样简单。

但别忘了,灵活性也意味着责任。当你跳出默认路径时,有几个坑必须警惕。

首先是梯度稳定性。如果你在损失中加入了复杂的数学运算,比如除法、对数或指数操作,稍不注意就会导致 NaN 或梯度爆炸。建议始终用 torch.clamp 对输入做裁剪,并在调试阶段开启 torch.autograd.set_detect_anomaly(True) 来捕捉异常源头。

其次是分布式训练兼容性。在多 GPU 场景下,每个设备只看到一部分 batch。如果你在损失中做了全局归一化或统计量计算(如均值、方差),必须确保这些值是在所有设备上同步聚合过的。否则,梯度更新会不一致。好在 Llama-Factory 默认使用 DistributedDataParallel,你可以借助 torch.distributed.all_reduce 手动同步张量,或者干脆避免跨设备依赖。

再来看内存效率。长序列任务中,一次性展开所有 token 的 logits 和 labels 可能占用巨大显存。例如,一个 batch size 为 8、序列长度为 8192 的输入,展平后的形状是 (8*8192, vocab_size),对于 32K 词表来说就是近 2GB 的中间张量。解决办法是分块计算或使用 reduction='none' 后按需降维,而不是盲目 .mean()。

还有一个常被忽视的点是日志可解释性。当你加了权重、平滑或多个损失项时,最终的 loss 值已经不能直接和原始交叉熵比较了。建议在训练日志中同时输出原始 loss 和加权后的 total loss,方便分析收敛行为。Llama-Factory 支持 TensorBoard,你可以轻松记录这些辅助指标:

if self.args.local_rank == 0:  # 主进程记录
    self.log({"base_loss": base_loss.item(), "weighted_loss": weighted_loss.item()})

说到实际应用,我们再回到那个金融客服的例子。假设你有一组标注好的问题类别,其中'退款政策'、'账户安全'等属于高优先级。与其靠数据过采样来提升曝光,不如直接在损失层面赋予它们更高权重:

def compute_loss(self, model, inputs):
    labels = inputs["labels"]
    category_ids = inputs.get("category_id", None)
    outputs = model(**inputs)
    logits = outputs["logits"]
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_per_token = self.ce_loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    
    # Reshape to [batch_size, seq_len] and average over sequence
    loss_per_sample = loss_per_token.view(labels.size(0), -1).mean(dim=1)
    
    if category_ids is not None:
        class_weights = {
            0: 1.0,   # login_issue
            1: 5.0,   # refund_policy
            2: 8.0,   # account_frozen
            3: 1.5    # feature_request
        }
        weights = torch.tensor([class_weights[cid.item()] for cid in category_ids], device=loss_per_sample.device)
        loss_per_sample = loss_per_sample * weights
    
    return loss_per_sample.mean()

这种方法的优势在于,它不改变数据分布,避免了因重复采样带来的噪声放大;同时又能精准地将业务意图转化为可优化目标。而且,权重参数完全可以从配置文件读取,做到代码与策略解耦。

类似的思路还能拓展到更多高级场景:

  • Focal Loss:抑制易分类样本的贡献,聚焦难例;
  • Contrastive Loss:在检索增强问答中拉近 query 与 positive passage 的表示距离;
  • KL Div Loss:在蒸馏任务中对齐教师模型与学生模型的输出分布;
  • Multi-task Learning:联合优化生成任务和分类任务,共享主干网络。

这些都不是理论设想,而是已经在推荐系统、医疗诊断、法律文书生成等领域落地的技术实践。关键在于,你是否拥有一个足够开放的框架来承载这些创新。

Llama-Factory 的真正优势,不只是支持 LoRA、QLoRA 这些热门技术,而是它把整个微调链条打开给你看,并告诉你:'这里也可以改。' 它的设计哲学很明确:通用性解决共性问题,可扩展性应对个性需求。

这也解释了为什么它能在众多微调工具中脱颖而出。相比 Alpaca-LoRA 这类脚本型项目,它提供了 WebUI 和模块化 API;相比纯 CLI 工具,它又保留了深度定制的空间。无论是想快速验证想法的研究者,还是需要稳定交付的企业开发者,都能找到自己的位置。

未来的大模型训练,不会停留在'喂数据、调 learning rate'的层面。随着应用场景越来越复杂,我们需要的是语义感知的优化目标、任务感知的损失结构,甚至是用户反馈驱动的动态调整机制。而这一切的起点,往往就是一个被重新定义的 compute_loss 方法。

当你能把'这个问题很重要'翻译成'这个样本的损失要翻倍',你就不再只是在训练模型,而是在塑造它的价值观。这才是高级微调的真正意义。

目录

  1. 如何在 Llama-Factory 中自定义损失函数
  • 💰 8折买阿里云服务器限时8折了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • Python Selenium 浏览器自动化入门与实战指南
  • ROS2 无人机全栈技术解析:从飞控集成到场景落地
  • 大语言模型高效参数微调(PEFT)技术详解
  • Stable Diffusion 入门:提示词编写与 ControlNet 实战指南
  • AI 人才薪资飙升:硕士年薪 50 万,博士超 200 万,缺口千万
  • 基于 Java 和 Leaflet 的湖南省道路长度 WebGIS 系统实现
  • 自学编程指南:职业方向、语言选择与学习路径
  • QWEN-AUDIO 语音合成支持 20+ 情感指令与多音色演绎
  • AI 元人文:自感概念与 DOS 模型深度解析
  • 多人人体解析失败原因与 M2FP 拼图算法解析
  • AIGC 技术全景解析:大语言模型、扩散模型与多模态应用指南
  • OpenHarmony Flutter 开发:使用 sanitize_html 净化 HTML 防止 XSS
  • NUC 迷你主机配合 OpenClaw 构建家庭 AI 助理
  • Python Selenium 浏览器自动化入门与实战
  • SQL Server 2016 及 Management Studio 安装指南
  • ThinkPad 笔记本安装 Ubuntu 系统完整教程
  • Shell 脚本基础:参数校验与退出状态码解析
  • 基于 nanobot 搭建轻量级 QQ AI 机器人及搜索功能优化
  • 智谱开源Open-AutoGLM模型本地部署与性能优化指南
  • Pixel Fashion Atelier 部署教程:Stable Diffusion 像素时装生成

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online