从数据集构建到 LoRA 微调:使用 LlamaFactory 实现高效文本分类
背景介绍
本文详细介绍如何使用 LLaMA-Factory 框架利用开源大语言模型完成文本分类任务。以 LoRA 微调 qwen/Qwen2.5-7B-Instruct 为例,展示从数据准备、配置训练到推理评估的完整流程。
文本分类数据集构建
为了适配 LLaMA-Factory 的训练格式,我们需要按照 Alpaca 样式构建数据集。将自定义数据集添加到 LLaMA-Factory/data/dataset_info.json 文件中,以便后续直接根据自定义数据集名称加载数据。
数据集示例结构如下:
[
{
"instruction": "请将以下文本分类到一个最符合的类别中。以下是类别及其定义:",
"input": "改革创新发展、行政区划调整、行政管理体制等方面的内容,涉及到体制机制的改革与完善,旨在推动高质量发展和提升生活品质。",
"output": "reason: 该文本主要讨论的是 xxx。因此,该文本最符合'社会管理'这一类别。\n\nlabel: 社会管理"
}
]
在构建数据集时,建议明确定义每个类别的含义,并在 instruction 中提供清晰的指令。输入部分(input)包含待分类的原始文本,输出部分(output)则包含模型的推理理由(reason)和最终标签(label)。这种结构化输出有助于后续自动化评估。
LoRA 微调配置
LLaMA-Factory 支持网页端训练,但生产环境通常推荐使用命令行进行更灵活的控制。我们将训练参数存储在 YAML 配置文件中,例如 qwen_train_cls.yaml。
配置文件详解:
model_name_or_path: qwen/Qwen2.5-7B-Instruct
stage: sft
finetuning_type: lora
lora_target: all
dataset_dir: LLaMA-Factory/data/
dataset: 数据集名
template: qwen
cutoff_len: 2048
overwrite_cache: true
preprocessing_num_workers: 16
output_dir: output/qwen2.5-7B/cls_epoch2
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
关键参数说明:
stage: sft: 指定监督微调阶段。
finetuning_type: lora: 启用 LoRA 低秩适配器,节省显存。
cutoff_len: 截断长度,需根据显存大小和数据平均长度调整。
bf16: 启用 bfloat16 混合精度训练,加速计算并减少显存占用。
ddp_timeout: 分布式训练超时时间,防止因长时间运行被系统杀死。
eval: 开启验证集评估,监控过拟合情况。
启动模型训练
使用以下命令启动后台训练任务:
nohup llamafactory-cli train qwen_train_cls.yaml > qwen_train_cls.log 2>&1 &
命令分解:
nohup: 确保进程在终端关闭后继续运行。
llamafactory-cli train: 调用 CLI 工具执行训练子命令。
> qwen_train_cls.log 2>&1: 将标准输出和错误输出重定向到日志文件。
&: 将命令放入后台运行。
训练过程中,请定期检查 qwen_train_cls.log 中的 Loss 变化曲线,确保模型收敛正常。若出现 OOM(Out Of Memory)错误,可尝试减小 per_device_train_batch_size 或增加 gradient_accumulation_steps。
模型部署与推理
训练完成后,LoRA 权重保存在 output_dir 指定的路径下。虽然 LLaMA-Factory 原生支持推理,但为了获得更高的吞吐量,可以结合 vLLM 进行批量推理。
本地推理示例:
llamafactory-cli chat \
--model_name_or_path qwen/Qwen2.5-7B-Instruct \
--adapter_name output/qwen2.5-7B/cls_epoch2 \
--template qwen \
--infer_backend vllm
此命令加载基础模型及 LoRA 适配器,并使用 vLLM 后端加速生成。推理结果应遵循训练时的输出格式(包含 reason 和 label),以便于统一评估。
文本分类评估代码
为了量化模型效果,需要编写评估脚本解析预测结果并与真实标签对比。以下是一个基于 Python 和 sklearn 的评估示例。
import os
import re
import json
from sklearn.metrics import classification_report, confusion_matrix
CLASS_NAME = [
"产业相关",
"法律法规与行政事务",
"其他",
]
def load_jsonl(file_path):
"""加载 JSONL 文件"""
data = []
try:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
tmp = json.loads(line)
data.append(tmp)
except FileNotFoundError as e:
print(f"文件未找到:{file_path}")
raise e
return data
def parser_label(text: str):
"""从模型输出中提取 label"""
pattern = r"label[::\s\.\d\*]*([^\s^\*]+)"
matches = re.findall(pattern, text, re.DOTALL)
if len(matches) == 1:
return matches[0]
return None
def trans2num(item):
"""将类别名称转换为索引"""
predict = parser_label(item["predict"])
label = parser_label(item["label"])
predict_idx = -
label_idx = -
idx, cls_name (CLASS_NAME):
predict == cls_name:
predict_idx = idx
label == cls_name:
label_idx = idx
predict_idx, label_idx
():
data = load_jsonl(file_path=input_file)
predicts = []
labels = []
item data:
predict, label = trans2num(item)
label == -:
predicts.append(predict)
labels.append(label)
report = classification_report(predicts, labels, output_dict=)
(report)
report
__name__ == :
cls_eval()
注意事项:
parser_label 函数需根据实际模型输出格式调整正则表达式。
- 若模型未输出有效标签(返回 -1),应在评估前过滤掉该样本。
- 建议使用 F1-score 作为核心指标,特别是在类别不平衡的情况下。
常见问题与优化建议
- 显存不足:如果训练过程中显存溢出,可以尝试开启
flash_attention_2(如果硬件支持),或者进一步降低 batch size 并增加梯度累积步数。
- 过拟合:观察验证集 Loss 是否上升。如果验证集效果下降而训练集持续下降,说明过拟合,可增加
weight_decay 或提前停止训练。
- 推理速度:在生产环境中,务必使用 vLLM 等推理引擎,相比原生 transformers 推理速度可提升数倍。
- 数据质量:文本分类的效果很大程度上取决于标注数据的质量。建议对训练数据进行去重和清洗,确保类别分布均衡。
通过上述步骤,您可以利用 LLaMA-Factory 快速构建高效的文本分类模型,并根据业务需求灵活调整微调策略。