简介
LLM(Large Language Model)通用模型在各种任务上表现良好,我们可以将它们用作对目标任务进行微调的基础。微调允许我们使模型适应目标域和目标任务,使其可以更好地完成我们所需要的特定任务。
本文介绍了基于 Firefly 框架对 Qwen 大模型进行 QLoRA 微调的完整流程。内容包括环境配置、训练参数详解、数据格式准备与转换、训练命令执行以及推理测试代码。重点讲解了因果语言模型的 Attention Mask 机制在多轮对话训练中的应用,并提供了显存优化、Loss 不下降等常见问题的解决方案,帮助开发者在有限算力下高效完成垂直领域模型的适配与部署。

LLM(Large Language Model)通用模型在各种任务上表现良好,我们可以将它们用作对目标任务进行微调的基础。微调允许我们使模型适应目标域和目标任务,使其可以更好地完成我们所需要的特定任务。
目前模型微调方法主要有 Full(全参微调)、Freeze、P-tuning、LoRA、QLoRA。这些方法各有优势,关于它们的介绍也有很多。本篇主要讲解代码实现,原理方面不赘述。考虑到不是所有读者都有足够的算力,因此使用占用资源最少的 QLoRA 对模型进行微调。
这里推荐使用 Firefly 项目来实现模型微调。这个项目主要是为了微调多轮对话数据集,不过单轮对话也同样适用。
Firefly 项目训练多轮对话模型时,采取了一种更加充分高效的方法。将一条多轮对话数据拼接之后,输入模型,并行计算每个位置的 loss,只有 Assistant 部分的 loss 参与权重更新。
为什么这种做法是可行的?答案在于因果语言模型的 attention mask。以 GPT 为代表的 Causal Language Model(因果语言模型),这种模型的 attention mask 是一个对角掩码矩阵,每个 token 在编码的时候,只能看到它之前的 token,看不到它之后的 token。所以 User1 部分的编码输出,只能感知到 User1 的内容,无法感知到它之后的文本,可以用来预测 Assistant1 的内容。而 User2 部分的编码输出,只能看到 User1、Assistant1、User2 的内容,可以用来预测 Assistant2 的内容,依此类推。对于整个序列,只需要输入模型一次,便可并行获得每个位置的 logits,从而用来计算 loss。
首先 pull 项目并配置环境:
git clone https://github.com/yangjianxin/Firefly.git
cd Firefly
pip install -r requirements.txt
确保安装以下核心依赖库:
然后找到 train_args/sft/qlora 路径下的配置文件(例如 qwen-7b-sft-qlora.json)。虽然文件名可能包含 7b,但通常 qwen 系列模型都可以通用。需要对里面的内容进行修改,主要修改模型路径和训练文件,其余参数可以根据显存情况调整。
{
"output_dir": "output/firefly-qwen-1_8b-sft-qlora",
"model_name_or_path": "Qwen/Qwen-1_8B-Chat",
"train_file": "./data/dummy_data.jsonl",
"template_name": "qwen",
"num_train_epochs": 1,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 16,
"learning_rate": 2e-4,
"max_seq_length": 1024,
"logging_steps": 100,
"save_steps": 100,
"save_total_limit": 1,
"lr_scheduler_type": "constant_with_warmup",
"warmup_steps": 100,
"lora_rank": 64,
"lora_alpha": 128,
"lora_dropout": 0.05,
"gradient_checkpointing": true,
"disable_tqdm": false,
"optim": "paged_adamw_32bit",
"seed": 42,
"fp16": true,
"report_to": "tensorboard",
"dataloader_num_workers": 0,
"save_strategy": "steps",
"weight_decay": 0,
"max_grad_norm": 0.3,
"remove_unused_columns": false
}
关键参数说明:
model_name_or_path: 指定预训练模型的路径或 HuggingFace ID。train_file: 训练数据文件路径,支持 jsonl 格式。per_device_train_batch_size: 单设备批次大小,显存不足时可减小。gradient_accumulation_steps: 梯度累积步数,用于模拟更大的 batch size。lora_rank: LoRA 的秩,越大拟合能力越强但显存占用越高。lora_alpha: LoRA 缩放系数,通常设为 rank 的 2 倍。gradient_checkpointing: 开启后节省显存,但会略微增加训练时间。optim: 优化器选择,paged_adamw_32bit 适合显存受限场景。这里的数据集需要符合特定的 jsonl 格式。官方提供的一个测试数据集可用于跑通流程。
每条数据应包含 conversation_id 和 conversation 两个字段。conversation_id 表示对话的序号,conversation 对应的是一个列表,元素是字典,每个字典中有 human 和 assistant 两个键,分别表示用户和模型的说话内容。
如果要用自己的数据集,也要按照这种格式进行修改。以下是一个 Python 脚本示例,用于将普通 CSV 数据转换为 Firefly 所需的格式:
import json
import csv
# 假设原始数据为 csv 格式,包含 question, answer 列
with open('raw_data.csv', 'r', encoding='utf-8') as f_in, \
open('./data/custom_data.jsonl', 'w', encoding='utf-8') as f_out:
reader = csv.DictReader(f_in)
for i, row in enumerate(reader):
data = {
"conversation_id": i,
"conversation": [
{"human": row['question'], "assistant": row['answer']}
]
}
f_out.write(json.dumps(data, ensure_ascii=False) + '\n')
数据清洗建议:
max_seq_length。利用以下代码开始训练:
python train.py --train_args_file train_args/sft/qlora/qwen-7b-sft-qlora.json
训练完成后会在配置文件中设置的 output_dir 生成对应的 QLoRA 文件。训练过程中可以通过 TensorBoard 监控 loss 变化曲线,观察模型收敛情况。
训练完成后,可利用 Firefly/script/chat 路径下的脚本进行调用,测试模型微调的效果。以下是简单的推理代码示例:
from firefly.trainer.sft_trainer import SFTTrainer
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# 加载基础模型
base_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen-1_8B-Chat",
load_in_8bit=True,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-1_8B-Chat", use_fast=False)
# 加载微调后的 LoRA 权重
model = PeftModel.from_pretrained(base_model, "output/firefly-qwen-1_8b-sft-qlora")
# 构造输入
messages = [
{"role": "user", "content": "你好,请介绍一下你自己。"}
]
input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
# 生成回答
outputs = model.generate(**inputs, max_new_tokens=512)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
如果遇到 OOM 错误,可以尝试以下措施:
per_device_train_batch_size。gradient_accumulation_steps 以保持有效 batch size。gradient_checkpointing。max_seq_length。bf16 或 fp16。learning_rate。human 和 assistant 的标签位置。num_train_epochs。lora_rank 和 lora_alpha 参数。template_name) 以匹配模型特性。max_new_tokens。通过上述步骤,您可以完成 Qwen 大模型的 QLoRA 微调,并根据实际需求调整参数以获得最佳效果。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online