导读
本文介绍微调(Fine-tuning)的基本概念,以及如何对语言模型进行微调。从 GPT-3 到 ChatGPT、从 GPT-4 到 GitHub Copilot 的过程中,微调扮演了重要角色。什么是微调?微调能解决什么问题?什么是 LoRA?如何进行微调?本文将解答以上问题,并通过代码实例展示如何使用 LoRA 进行微调。
本文详解大语言模型微调概念,对比 SFT 与 RLHF,深入解析 LoRA 低秩适配原理及数学基础。通过 Hugging Face Transformers 库实战演示电影评论分类任务,涵盖数据准备、模型加载、PEFT 配置及训练评估全流程。文章补充了模型合并与量化部署方案,并提供最佳实践建议,帮助开发者低成本实现模型定制化。

本文介绍微调(Fine-tuning)的基本概念,以及如何对语言模型进行微调。从 GPT-3 到 ChatGPT、从 GPT-4 到 GitHub Copilot 的过程中,微调扮演了重要角色。什么是微调?微调能解决什么问题?什么是 LoRA?如何进行微调?本文将解答以上问题,并通过代码实例展示如何使用 LoRA 进行微调。
微调是利用已经训练好的模型(通常是大型的预训练模型)作为起点,在新的数据集进一步训练模型,从而使其更适合特定的应用场景。即使非专业算法同学,只要硬件成本可控,也可动手尝试微调自己的模型。
GPT-3 使用大量互联网上的语料训练完成后,并不完全适合对话场景。例如输入'中国的首都是哪里?',基于训练后的模型参数推理,结果可能不准确。这是因为训练数据中相关句子的概率分布与特定对话需求存在差异。需要多阶段的优化过程使模型更擅长处理对话,并更好地理解和回应用户需求。
GPT-3 模型的微调过程包括几个关键步骤:
目前 OpenAI 公开信息显示,ChatGPT 的主要改进是通过微调和 RLHF 实现。流程大致为:预训练 → 微调(SFT) → 强化学习(RLHF) → 模型修剪与优化。
在生产实践中,虽然 RLHF 可提升表现,但对特定任务采用 SFT 往往效果更好。RLHF 成本高,依赖大量人工标注数据,相对 SFT 使用较少。
两者可结合:通用预训练 → 继续预训练(行业/部门) → 微调(具体业务小组)。
微调是基于已训练好的神经网络模型,通过对其参数进行细微调整,使其更好适应特定任务。根据微调范围分为全模型微调和部分微调。
生产中常用参数高效微调(PEFT),通过引入低秩矩阵(如 LoRA)或适配层,减少资源需求。
LoRA(Low-Rank Adaptation)通过引入低秩矩阵来减少微调过程中需要更新的参数数量,显著降低计算资源需求。
可重用性:LoRA 不改变原模型参数,不同任务的低秩矩阵可分别存储加载,灵活应用于不同任务。例如在手机终端跑大模型,针对不同任务动态加载 LoRA 参数,相比一个任务一个模型,大幅节省空间。
研究表明,模型适应特定任务时并不需要用到所有复杂能力。原始矩阵维度较高,假设为 $d_k$ 维矩阵 $W_0$,调整方式为矩阵加法增加 $ riangle W$。若 $ riangle W$ 仍为 $d \times k$ 维,参数量多。LoRA 将其表示为低秩分解:
$$ h = W_0 + \triangle W = W_0 + BA $$
其中 $B \in \mathbb{R}^{d \times r}, A \in \mathbb{R}^{r \times k}$,且秩 $r \ll \min(d, k)$。
举例计算:$d=1000, k=1000$,全量调整需 100w 参数。若 $r=4$,仅需 $1000 \times 4 + 4 \times 1000 = 8000$ 个参数。
论文实验表明,在调整 Transformer 权重矩阵时,$r=1$ 时对特定任务就有非常好效果。通常 $r$ 设置为 1~8,经验值常为 4。
数据要求:
本节使用 LoRA 微调 distilbert/distilbert-base-uncased 模型,实现对电影评论的情感分类(正面/负面)。数据集为 stanfordnlp/imdb。使用 Colab 免费 T4 GPU,1000 条数据,10 个 Epoch,约 6 分钟完成。微调前准确率 50%,微调后达 87%。
!pip install datasets
!pip install transformers
!pip install evaluate
!pip install torch
!pip install peft
from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModelForSequenceClassification,
DataCollatorWithPadding,
TrainingArguments,
Trainer)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
import evaluate
import torch
import numpy as np
# 加载 IMDB 数据
imdb_dataset = load_dataset("stanfordnlp/imdb")
# 定义子采样大小
N = 1000
rand_idx = np.random.randint(24999, size=N)
# 提取训练和测试数据
x_train = imdb_dataset['train'][rand_idx]['text']
y_train = imdb_dataset['train'][rand_idx]['label']
x_test = imdb_dataset['test'][rand_idx]['text']
y_test = imdb_dataset['test'][rand_idx]['label']
# 创建新数据集
dataset = DatasetDict({
'train': Dataset.from_dict({'label': y_train, 'text': x_train}),
'validation': Dataset.from_dict({'label': y_test, 'text': x_test})
})
print(np.array(dataset['train']['label']).sum() / len(dataset['train']['label'])) # 约 0.508
IMDB 数据格式示例:
{
"label": 0,
"text": "Not a fan, don't recommend."
}
from transformers import AutoModelForSequenceClassification
model_checkpoint = 'distilbert-base-uncased'
# 定义标签映射
id2label = {0: "Negative", 1: "Positive"}
label2id = {"Negative": 0, "Positive": 1}
# 生成分类模型
model = AutoModelForSequenceClassification.from_pretrained(
model_checkpoint, num_labels=2, id2label=id2label, label2id=label2id)
模型架构为 6 层 Transformer,LoRA 影响 q_lin 层的权重(768*768 矩阵)。
from transformers import AutoTokenizer
# 创建分词器
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
# 添加 pad token
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))
# 定义 tokenize 函数
def tokenize_function(examples):
text = examples["text"]
tokenizer.truncation_side = "left"
tokenized_inputs = tokenizer(
text,
return_tensors="np",
truncation=True,
max_length=512,
padding='max_length'
)
return tokenized_inputs
# 处理数据集
tokenized_dataset = dataset.map(tokenize_function, batched=True)
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
说明:
import torch
from peft import LoraConfig, get_peft_model
import evaluate
# 评估指标
accuracy = evaluate.load("accuracy")
def compute_metrics(p):
predictions, labels = p
predictions = np.argmax(predictions, axis=1)
return {"accuracy": accuracy.compute(predictions=predictions, references=labels)}
# PEFT 配置
peft_config = LoraConfig(
task_type="SEQ_CLS",
r=1,
lora_alpha=32,
lora_dropout=0.01,
target_modules=['q_lin']
)
# 获取 PEFT 模型
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
输出显示可训练参数占比不到 1%,参数量越大,比例越小。
# 超参数
lr = 1e-3
batch_size = 4
num_epochs = 10
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir=model_checkpoint + "-lora-text-classification",
learning_rate=lr,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=num_epochs,
weight_decay=0.01,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
)
from transformers import Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["validation"],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
# 开始训练
trainer.train()
训练后再次测试,分类正确率显著提升。
微调完成后,LoRA 权重可以合并到基座模型中,也可以单独加载。合并模型可减少推理时的内存占用,因为不再需要实时计算低秩矩阵加法。
合并权重示例:
from peft import PeftModel
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./merged_model")
量化加速:
对于生产环境,建议结合量化技术(如 INT8 或 FP4)进一步降低显存需求。可使用 bitsandbytes 库加载量化模型。
r 值通常从 4 开始尝试,alpha 设为 r 的 2 倍左右。fp16 或 bf16 加速训练并节省显存。本文介绍了微调的基本概念,以及如何对语言模型进行微调。微调虽成本低于大模型的预训练,但对于大量参数的模型微调成本仍非常之高。好在随着算力增长,微调的成本门槛会越来越低,应用场景也会越来越多。
高质量的输入非常重要,类似于人去学习技能。阅读经典,反复阅读,才能掌握精髓。希望本文能帮助开发者快速上手大模型微调技术。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online