Qwen2 是通义千问团队开源的大语言模型。以 Qwen2 作为基座大模型,通过指令微调的方式做高精度文本分类,是学习 LLM 微调的入门任务。
在本文中,我们会使用 Qwen2-1.5B-Instruct 模型在复旦中文新闻数据集上做指令微调训练,同时使用 SwanLab 监控训练过程、评估模型效果。
显存要求不高,10GB 左右就可以跑。
知识点:什么是指令微调?
大模型指令微调(Instruction Tuning)是一种针对大型预训练语言模型的微调技术,其核心目的是增强模型理解和执行特定指令的能力,使模型能够根据用户提供的自然语言指令准确、恰当地生成相应的输出或执行相关任务。
指令微调特别关注于提升模型在遵循指令方面的一致性和准确性,从而拓宽模型在各种应用场景中的泛化能力和实用性。
在实际应用中,指令微调更多把 LLM 看作一个更智能、更强大的传统 NLP 模型(比如 Bert),来实现更高精度的文本预测任务。所以这类任务的应用场景覆盖了以往 NLP 模型的场景,甚至很多团队拿它来标注互联网数据。
下面是实战正片:
1. 环境安装
本案例基于 Python>=3.8,请在您的计算机上安装好 Python;
另外,您的计算机上至少要有一张英伟达显卡(显存要求并不高,大概 10GB 左右就可以跑)。
我们需要安装以下这几个 Python 库,在这之前,请确保你的环境内已安装了 pytorch 以及 CUDA:
swanlab modelscope transformers datasets peft accelerate pandas
一键安装命令:
pip install swanlab modelscope transformers datasets peft pandas accelerate
本案例测试于 modelscope 1.14.0、transformers 4.41.2、datasets 2.18.0、peft 0.11.1、accelerate 0.30.1、swanlab 0.3.9
2. 准备数据集
本案例使用的是 zh_cls_fudan-news 数据集,该数据集主要被用于训练文本分类模型。
该数据集由几千条数据组成,每条数据包含 text、category、output 三列:
text 是训练语料,内容是书籍或新闻的文本内容;
category 是 text 的多个备选类型组成的列表;
output 则是 text 唯一真实的类型。
将三者组合成数据集的例子如下:
"""
[PROMPT]
Text: 第四届全国大企业足球赛复赛结束新华社郑州5月3日电(实习生田兆运)上海大隆机器厂队昨天在洛阳进行的第四届牡丹杯全国大企业足球赛复赛中,以5:4力克成都冶金实验厂队,进入前四名。沪蓉之战,双方势均力敌,90分钟不分胜负。最后,双方互射点球,沪队才以一球优势取胜。复赛的其它3场比赛,青海山川机床铸造厂队3:0击败东道主洛阳矿山机器厂队,青岛铸造机械厂队3:1战胜石家庄第一印染厂队,武汉肉联厂队1:0险胜天津市第二冶金机械厂队。在今天进行的决定九至十二名的两场比赛中,包钢无缝钢管厂队和河南平顶山矿务局一矿队分别击败河南平顶山锦纶帘子布厂队和江苏盐城无线电总厂队。4日将进行两场半决赛,由青海山川机床铸造厂队和青岛铸造机械厂队分别与武汉肉联厂队和上海大隆机器厂队交锋。本届比赛将于6日结束。(完)
Category: Sports, Politics
Output:
[OUTPUT]
Sports
"""
我们的训练任务,便是希望微调后的大模型能够根据 Text 和 Category 组成的提示词(Prompt),预测出正确的 Output。
我们将数据集下载到本地目录下。下载方式是前往魔搭社区,将 train.jsonl 和 test.jsonl 下载到本地根目录下即可。
3. 加载模型
这里我们使用 modelscope 下载 Qwen2-1.5B-Instruct 模型,然后把它加载到 Transformers 中进行训练:
from modelscope import snapshot_download, AutoTokenizer
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
torch
model_dir = snapshot_download(, cache_dir=, revision=)
tokenizer = AutoTokenizer.from_pretrained(, use_fast=, trust_remote_code=)
model = AutoModelForCausalLM.from_pretrained(, device_map=, torch_dtype=torch.bfloat16)



