跳到主要内容 大模型 LLM 微调经验与总结 | 极客日志
Python AI 算法
大模型 LLM 微调经验与总结 分享了 ChatGLM-6B 大模型微调的实战经验,涵盖 Freeze、P-Tuning 和 LoRA 三种主流方法的技术原理与代码实现。通过汽车工业故障模式关系抽取任务进行实验,对比了各方法在显存占用、训练耗时及 F1 分数上的表现。结果显示 PT 方法效果最佳,LoRA 效率较高,且单指令微调未导致灾难性遗忘。文章还汇总了常用的中文开源大模型、指令数据集及项目资源,并提供了显存溢出处理、训练稳定性优化等常见问题解决方案,适合希望深入理解大模型微调流程的开发者参考。
Pythonist 发布于 2025/2/7 更新于 2026/4/20 2 浏览
前言 随着大型语言模型(LLM)技术的快速发展,开源社区涌现了大量优秀的微调项目。本文基于 ChatGLM-6B 模型的微调实践,分享 Freeze、P-Tuning 和 LoRA 三种主流方法的实战经验,并汇总了相关的开源资源。实验表明,在特定任务下采用单指令微调,模型并未出现明显的灾难性遗忘现象。
ChatGLM-6B 模型微调方法 模型参数量越大,对显存的要求越高。目前主流的轻量化微调方法包括 Freeze(参数冻结)、P-Tuning(软提示)和 LoRA(低秩适配)。以下以信息抽取任务为例,介绍这三种方法的具体实现。
1. Freeze 方法 Freeze 方法即参数冻结,通过固定原始模型的大部分参数,仅训练部分层或模块,从而实现在单卡或不进行张量并行(TP)的情况下进行训练。
核心逻辑:
遍历模型参数,根据名称匹配需要冻结的层。例如,保留后几层的可训练性,冻结前面的层。
for name, param in model.named_parameters():
if not any (nd in name for nd in ["layers.27" , "layers.26" , "layers.25" , "layers.24" , "layers.23" ]):
param.requires_grad = False
训练配置:
使用 DeepSpeed 进行加速训练。主要参数包括训练路径、模型目录、训练轮数、批次大小、梯度累积步数等。
CUDA_VISIBLE_DEVICES=0 deepspeed finetuning_freeze.py --num_train_epochs 5 --train_batch_size 2
推理代码:
参考 predict_freeze.py,根据具体任务的评价标准进行预测。
2. P-Tuning 方法 P-Tuning 是一种针对大模型的 soft-prompt 方法,通过在 Embedding 层或每一层前添加可训练的连续向量来引导模型。
P-Tuning: 仅对大模型的 Embedding 加入新的参数。
P-Tuning-V2: 将大模型的 Embedding 和每一层前都加上新的参数,效果通常更好。
config = ChatGLMConfig.from_pretrained(args.model_dir)
config.pre_seq_len = args.pre_seq_len
config.prefix_projection = args.prefix_projection
model = ChatGLMForConditionalGeneration.from_pretrained(args.model_dir, config=config)
for name, param in model.named_parameters():
if not any (nd in name for nd in ["prefix_encoder" ]):
param.requires_grad = False
当 prefix_projection 为 True 时,启用 P-Tuning-V2;为 False 时,为 P-Tuning。
CUDA_VISIBLE_DEVICES=0 deepspeed finetuning_pt.py --num_train_epochs 5 --train_batch_size 2 --pre_seq_len 16
3. LoRA 方法 LoRA (Low-Rank Adaptation) 通过在指定参数上增加额外的低秩矩阵,并在训练过程中仅更新这些新增参数。当秩值远小于原始参数维度时,可训练参数量极小,但能获取较好的效果。
model = ChatGLMForConditionalGeneration.from_pretrained(args.model_dir)
config = LoraConfig(r=args.lora_r,
lora_alpha=32 ,
target_modules=["query_key_value" ],
lora_dropout=0.1 ,
bias="none" ,
task_type="CAUSAL_LM" ,
inference_mode=False ,
)
model = get_peft_model(model, config)
注意事项:
对于需要保持结果一致性的任务(如关闭 dropout,解码时关闭 do_sample),需保存模型时修改 adapter_config.json 中的 inference_mode 为 false,并执行 model.eval()。这是因为 ChatGLM 模型代码中未采用 Conv1D 函数,需注意推理模式切换。
三元组抽取实验设置
最大序列长度:768
Batch Size: 2
训练轮数:5
精度:fp16
分布式策略:DeepSpeed Zero-1
数据集:
为防止数据泄露,采用领域比赛数据集'汽车工业故障模式关系抽取',随机抽取 50 条作为测试集。
prompt_text:你现在是一个信息抽取模型,请你帮我抽取出关系内容为"性能故障" , "部件故障" , "组成" 和 "检测工具" 的相关三元组,三元组内部用"_" 连接,三元组之间用\n分割。文本:
输入:故障现象:发动机水温高,风扇始终是低速转动,高速档不工作,开空调尤其如此。
输出:发动机_部件故障_水温高\n风扇_部件故障_低速转动
实验结果分析 实验均在 80G-A100 显卡上进行,对比不同微调方法的显存占用、参数量、训练耗时及 F1 分数。
微调方法 PT-Only-Embedding PT Freeze Lora 显卡占用 37G 56G 24G 39G 总参数 6.259B 7.211B 6.255B 6.259B 可训练参数占比 0.0586% 13.26% 16.10% 0.0586% 训练耗时 20min 52min 46min 25min 测试结果 F1 0.0 0.6283 0.5675 0.5359
效果排序: PT > Freeze > Lora > PT-Only-Embedding。
速度排序: PT-Only-Embedding > Lora > Freeze > PT。
PT-Only-Embedding 表现不佳: Loss 仅收敛到 2.x,而其他方法可收敛到 0.x。原因是输出形式与原语言模型任务差异大,仅增加额外 Embedding 不足以改变复杂下游任务。
显存占用: PT 方法因增加较多额外参数,显存占用较大。
推理耗时: 生成模型生成的长度会影响耗时,其他方法因增加额外参数,推理耗时略高于 Freeze。
灾难性遗忘: 模型在指定任务微调后,并未丧失原有能力(如写快排算法、翻译、问答)。这可能是因为大模型微调多采用大量 Instruction 训练,单一指令微调对原指令影响较小。
验证测试:
使用 test_forgetting.py 对翻译、代码、问答任务进行测试,结果显示模型保留了基础能力。
中文开源大模型与资源汇总 虽然开源大模型众多,但可直接微调且支持中文的并不多。以下是常用资源汇总。
中文开源大模型
中文开源指令数据 大多数指令集从 Alpaca 翻译而来,也可利用 GPT-4 等模型进行廉价标注。
热门开源项目
常见问题与最佳实践
1. 显存溢出 (OOM) 处理
开启 gradient_checkpointing:牺牲训练时间换取显存空间。
减小 batch_size 或增加 gradient_accumulation_steps。
使用更高效的量化技术(如 INT8/FP16)。
升级算力硬件,如使用 A100 或 H100 集群。
2. 训练稳定性
确保学习率设置合理,过大的学习率可能导致 Loss 震荡。
监控 Loss 曲线,若 Loss 无法下降,检查数据质量或 Prompt 格式。
定期保存 Checkpoint,防止意外中断导致进度丢失。
3. 推理优化
生产环境中建议使用 TensorRT 或 vLLM 等推理框架加速。
注意 inference_mode 的设置,避免 Dropout 影响生成一致性。
总结 当前大模型生态发展迅速,个人开发者可通过微调技术实现垂直领域的模型应用。本文总结了 ChatGLM-6B 的三种微调方法及实验结论,希望能为相关从业者提供参考。未来建议关注更多开源模型动态,持续探索更高效、更低成本的微调方案。
注:以上实验数据基于特定硬件环境,实际效果可能因硬件配置和数据集差异而有所不同。
相关免费在线工具 加密/解密文本 使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
RSA密钥对生成器 生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
Mermaid 预览与可视化编辑 基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
curl 转代码 解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
Base64 字符串编码/解码 将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
Base64 文件转换器 将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online