跳到主要内容
极客日志极客日志面向AI+效率的开发者社区
首页博客GitHub 精选镜像工具UI配色美学隐私政策关于联系
搜索内容 / 工具 / 仓库 / 镜像...⌘K搜索
注册
博客列表
PythonAI算法

大模型微调技术详解与实战代码实现

综述由AI生成大模型微调技术旨在利用预训练模型知识适应特定任务,降低训练成本。全量微调、前缀微调、提示微调、P-Tuning 及 LoRA 等参数高效微调方法及其原理。通过 CoLA 数据集的 BERT 模型实战代码,展示了环境配置、数据预处理、模型训练及性能测试流程。实验分析表明,LoRA 在保持接近全量微调性能的同时显著降低了计算资源消耗,适合资源受限场景。掌握这些技术有助于开发者高效部署垂直领域大模型应用。

Kubernet发布于 2025/2/7更新于 2026/6/323 浏览
大模型微调技术详解与实战代码实现

课题背景

近年来,全球迎来了一股大模型的热潮,众多大型预训练模型如 GPT-4、BERT 相继问世。这些模型通过在海量文本数据上进行训练,能够掌握丰富的语言模式和广泛的一般知识,在各种自然语言处理任务中表现出色。然而,这些模型的训练成本高昂,需要庞大的计算资源和大量的数据。此外,尽管大型预训练模型具备较好的通用性知识,但每个具体任务或应用(如情感分析、文本摘要和对话生成)都有其独特的需求和模式。

为了应对上述问题,许多研究人员开始探索大模型微调技术,即在预训练模型的基础上,针对特定任务进行额外的训练,以满足特定需求,提高预训练模型在新任务上的性能,同时也减轻了大型预训练模型的训练成本。这种方式即使在计算资源受限的情况下,也能迅速利用预训练模型的知识来适应新任务,实现高效的迁移学习。因此,大模型微调技术不仅提升了模型性能,同时大大缩短了训练时间和计算成本,使更多人能够参与深度学习研究。

微调的概念已经存在很多年,并在很多领域得到了广泛应用。微调技术已知最早的应用是机器翻译,研究人员使用预训练的神经网络来初始化一个更小的网络的权重,然后针对特定翻译任务对其进行微调。经典的大模型微调方法,即全量微调(Full Fine-Tuning)会将预训练模型与少量特定任务数据一起继续训练,在这个过程中,预训练模型的权重被更新,以更好地适应任务。

但是,随着模型变得越来越大,在消费级硬件上对模型进行全部参数的微调变得不可行。此外,为每个下游任务独立存储和部署微调模型变得非常昂贵,因为微调后的模型(即调整了所有参数的模型)和原始预训练模型的规模是相同的。为了解决这个难题,研究人员开始探索参数高效微调技术(Parameter-Efficient Fine-Tuning,简称 PEFT),该技术的目标是在尽可能减少所需参数和计算资源的同时,有效地微调预训练语言模型。相较于传统的全量微调方法,高效微调技术所需的参数和计算资源更少,这一技术通过只训练一小部分参数来解决传统微调技术所需的大量资源问题,这些参数可以是已有模型参数的子集,或者是新增的一组参数。这些方法在参数效率、内存效率、训练速度以及模型的最终性能等方面都有所不同,接下来我们将详细探讨一些常用的微调技术和其主要实现方式。

常用微调方法

(1)全量微调

全量微调(Full Fine-Tuning)是一种在预训练模型的基础上进行微调的方法,其基本思想是:首先在大量的未标注数据上预训练一个大型模型,然后在具体的任务数据上对整个模型进行微调。在微调阶段,所有的模型参数(包括预训练阶段学习到的参数)都会被更新。微调的目标是根据具体任务的标签优化模型的性能。

全量微调的优势在于能够借助模型在预训练阶段积累的通用知识,然后通过微调,将这些知识应用到特定任务中。而且,全量微调已经历了相当长的时间的验证,被广泛应用于各种场景,并且其性能也得到了众多行业专家的认可。然而,全量微调也面临着一些挑战。首先,由于需要更新所有模型参数,所以它需要大量的计算资源和时间。其次,由于微调可能导致模型过度适应微调数据(即过拟合),因此,我们需要谨慎调整学习率和正则化参数。最后,对于大型模型来说,全面微调所有参数可能会导致模型性能下降,因为有些参数可能在预训练阶段已经被优化到了理想的状态。

(2)前缀微调

前缀微调(Prefix-Tuning)是一种高效微调预训练语言模型的技术。这种技术的基本思想是在模型的输入端(前缀)增加一些可学习的参数,然后在训练过程中优化这些参数,而保持模型的主体部分固定不变。具体来说,我们首先定义一些额外的参数,这些参数可以被视为一个序列,我们把它添加到我们的输入序列的前面,然后一起输入到模型中。这些参数通常被初始化为零,但在训练过程中会被优化以改进模型在特定任务上的性能。该方法其实和构造提示(Prompt)类似,只是提示是人为构造的'显式'的提示,并且无法更新参数,而前缀则是可以学习的'隐式'的提示。同时,为了防止直接更新前缀的参数导致训练不稳定和性能下降,论文提出在前缀层前面加 MLP 结构,训练完成后,只保留前缀部分的参数。除此之外,通过消融实验证实,只调整输入层的表现力不够,因此,前缀微调在每层都加了前缀参数,改动较大。前缀微调有以下几个优点:

① 减少计算资源的需求:因为只需要更新一部分参数,所以前缀微调需要的计算资源比全模型微调少。

② 减少过拟合的风险:因为只更新一部分参数,所以前缀微调降低了过拟合的风险。

③ 更好的迁移性能:因为每个任务都有一个特定的前缀,所以前缀微调可以更好地将模型在预训练阶段学习到的一般知识迁移到具体的任务上。

然而,前缀微调也面临着一些挑战。首先,前缀的效果会影响模型的性能,所以如何设计合适的前缀是一项挑战。其次,虽然前缀微调减少了参数的数量,但如果前缀的参数数量很大,那么计算资源的需求仍然可能很大。最后,对于一些复杂的任务,只通过前缀微调可能无法充分利用模型的能力。

(3)提示微调

提示微调(Prompt Tuning)可以看作是前缀微调的简化版本,只在输入层加入 prompt tokens,并不需要加入 MLP 进行调整来解决难训练的问题,提示微调只关注模型的输入 prompt,它不直接改变模型的参数,而是找到一组优化的 prompt,这些 prompt 可以引导模型在特定任务上生成更好的输出。例如,假设我们有一个预训练模型,我们想使用它来生成关于电影的积极评论。我们可以使用提示微调来找到一个优化的 prompt,如'这部电影真的很棒,因为…',然后让模型在这个 prompt 的引导下生成评论。

提示微调的主要优点是其效率和灵活性。因为我们只调整输入 prompt,而不是模型的所有参数,所以提示微调可以节省大量的计算资源。此外,同一模型可以通过叠加不同的 prompt 来调整适应不同的任务,因此提示微调非常灵活。

(4)P-Tuning

P-Tuning 提出将 Prompt 转换为可以学习的 Embedding 层,这个任务是让模型来预测一个国家的首都。左边是全 token 的 prompt,文献里称为'离散的 prompt'。右边是 token+vector 形式的 prompt,其保留了原 token prompt 里面的关键信息 (capital, Britain),它们 (capital, Britain) 是和任务、输出结果最相关的信息,其他不关键的词汇 (the, of ,is) 留给模型来学习。

考虑到直接对 Embedding 参数进行优化会存在这样两个挑战:

① Discretenes:对输入正常语料的 Embedding 层已经经过预训练,而如果直接对输入的 prompt embedding 进行随机初始化训练,容易陷入局部最优。

② Association:无法捕捉到 prompt embedding 之间的相关关系。

所以作者提出用 MLP+LSTM 的方式来对 prompt embedding 进行一层处理。P-Tuning 的优点是它可以极大地降低微调大型模型的计算和内存开销,因为它只需优化少量的参数。此外,因为这些参数是独立于任务的,我们可以将同一模型用于不同的任务,只需更换 soft prompt 即可。

(5)LoRA

在机器学习的实践中,虽然许多模型具有大量的参数(被称为过度参数化),但实际有效的参数或被学习到的特性可能只存在于这些参数空间的一个较低维度的子空间中。

举个例子,考虑一个深度神经网络,它可能有数百万或数十亿的参数。然而,通过训练,网络可能找到了一个有效的解决方案,这个解决方案在参数空间中可能只占据一个较小的区域,即只需要改变一部分参数就可以在不同任务间进行有效的迁移。这就说明这个模型可能在一个低内在维度的子空间上工作。低秩自适应(LoRA)方法假设在模型自适应过程中权重的变化也具有较低的'内在秩',它的核心思想很简单:

在原始 PLM (Pre-trained Language Model) 旁边增加一个旁路,做一个降维再升维的操作,来模拟所谓的 intrinsic rank。

训练的时候固定 PLM 的参数,只训练降维矩阵 A 与升维矩阵 B。而模型的输入输出维度不变,输出时将 AB 与 PLM 的参数叠加。

用随机高斯分布初始化 A,用 0 矩阵初始化 B,保证训练的开始此旁路矩阵依然是 0 矩阵。

LoRA 的这种思想有点类似于残差连接,同时使用这个旁路的更新来模拟 Full Fine-Tuning 的过程。并且,Full Fine-Tuning 可以被看做是 LoRA 的特例(当 r 等于 k 时)。

到目前为止,LoRA 方法已在多个任务和预训练模型上显示出了良好的性能,包括在自然语言处理、计算机视觉和强化学习任务中。

代码实现

(1)配置环境

实验环境:GeForce GTX 3090 with 22 GB memory,CUDA 11.6,Python3.8

依赖包:Pytorch, transformers, PEFT

import os
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from peft import PeftType, PrefixTuningConfig, PromptTuningConfig, PromptEncoderConfig, LoraConfig

(2)数据预处理

数据集:CoLA 数据集,做单句语法分类任务

下载数据集:

url = 'https://nyu-mll.github.io/CoLA/cola_public_1.1.zip'
os.system(f'wget {url}')
if not os.path.exists('./cola_public/'):
    os.system('unzip cola_public_1.1.zip')
df = pd.read_csv("./cola_public/raw/in_domain_train.tsv", delimiter='\t', header=None, names=['sentence_source', 'label', 'label_notes', 'sentence'])

分词(我们选取的预训练大模型为 bert-base-uncased):

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

添加特殊符号:

sentences = df.sentence.values
labels = df.label.values
max_len = 0
for sent in sentences:
    input_ids = tokenizer.encode(sent, add_special_tokens=True)
    max_len = max(max_len, len(input_ids))

input_ids_list = []
attention_masks = []
for sent in sentences:
    encoded_dict = tokenizer.encode_plus(
        sent,
        add_special_tokens=True, # 添加 '[CLS]' 和 '[SEP]'
        max_length=64,           # padding / 截断
        pad_to_max_length=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    input_ids_list.append(encoded_dict['input_ids'])
    attention_masks.append(encoded_dict['attention_mask'])

input_ids = torch.cat(input_ids_list, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(labels)

划分数据集:

dataset = TensorDataset(input_ids, attention_masks, labels)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

batch_size = 32

# 训练数据
train_dataloader = DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    batch_size=batch_size
)

# 验证数据
validation_dataloader = DataLoader(
    val_dataset,
    sampler=SequentialSampler(val_dataset),
    batch_size=batch_size
)

(3)训练模型

加载模型,如果选择全量微调,则直接加载预训练好的模型:

model_name_or_path = "bert-base-uncased"
model = BertForSequenceClassification.from_pretrained(
    model_name_or_path,
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=False
)

若选择高效微调方法,则需要先指定一种微调技术再加载模型:

p_type = "lora"
if p_type == "prefix-tuning":
    peft_type = PeftType.PREFIX_TUNING
    peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=20)
elif p_type == "prompt-tuning":
    peft_type = PeftType.PROMPT_TUNING
    peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=20)
elif p_type == "p-tuning":
    peft_type = PeftType.P_TUNING
    peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=20, encoder_hidden_size=128)
elif p_type == "lora":
    peft_type = PeftType.LORA
    peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)

model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, num_labels=2)

加载优化器:

optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

开始微调模型:

epochs = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch_i in range(0, epochs):
    total_train_loss = 0
    model.train()
    avg_train_loss, training_time = 0, 0
    
    for step, batch in enumerate(train_dataloader):
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
        
        model.zero_grad()
        loss, logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        
        total_train_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
    
    avg_train_loss = total_train_loss / len(train_dataloader)

(4)测试性能

total_test_accuracy = 0
predictions = []
true_labels = []

for batch in prediction_dataloader:
    batch = tuple(t.to(device) for t in batch)
    b_input_ids, b_input_mask, b_labels = batch
    
    with torch.no_grad():
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
    
    logits = outputs[0]
    logits = logits.detach().cpu().numpy()
    label_ids = b_labels.to('cpu').numpy()
    
    predictions.append(logits)
    true_labels.append(label_ids)
    
    # 假设 flat_accuracy 函数已定义
    total_test_accuracy += flat_accuracy(logits, label_ids)

avg_test_accuracy = total_test_accuracy / len(prediction_dataloader)
print("Accuracy: {:.4f}".format(avg_test_accuracy))

实验分析

(1)性能比较

上图展示了不同微调技术性能对比图,横坐标为训练次数,纵坐标为准确率。通过观察可以发现,在所有微调方法中,全量微调性能最好;LoRA 是几种高效微调技术中表现最好的,与全量微调相差不大。提示微调在模型不够大时表现欠佳,这也与提示微调原文中得到的结论相吻合。

(2)综合比较

下表为不同微调技术性能、算力对比表格。通过该表可以看出,全量微调所需训练时间最长,内存占用最多。高效微调训练参数量、训练时长、内存占用都远小于全量微调。综合比较下,LoRA 表现最好,性能和全量微调不相上下,计算效率也很高。

总结

大规模预训练模型已经在各种任务中表现出惊人的效果,但为了在特定的应用场景中最大化其效果,微调技术已成为一个必不可少的工具。微调允许我们在保持模型的大部分权重不变的同时,对模型进行细粒度的调整,以更好地适应特定的任务或数据集。然而,尽管微调在许多情况下都非常有效,但它也有一些挑战,如灾难性遗忘和稳定性问题。新兴的技术,如 Prompt Tuning,P-Tuning,以及低秩自适应方法(LoRA)等,为这些挑战提供了有前景的解决方案。无论我们在哪个行业,都有必要理解和掌握微调技术,以便充分利用大规模预训练模型的潜力。

目录

  1. 课题背景
  2. 常用微调方法
  3. (1)全量微调
  4. (2)前缀微调
  5. (3)提示微调
  6. (4)P-Tuning
  7. (5)LoRA
  8. 代码实现
  9. (1)配置环境
  10. (2)数据预处理
  11. 训练数据
  12. 验证数据
  13. (3)训练模型
  14. (4)测试性能
  15. 实验分析
  16. (1)性能比较
  17. (2)综合比较
  18. 总结
  • 💰 8折买阿里云服务器限时8折了解详情
  • Magick API 一键接入全球大模型注册送1000万token查看
  • 🤖 一键搭建Deepseek满血版了解详情
  • 一键打造专属AI 智能体了解详情
极客日志微信公众号二维码

微信扫一扫,关注极客日志

微信公众号「极客日志V2」,在微信中扫描左侧二维码关注。展示文案:极客日志V2 zeeklog

更多推荐文章

查看全部
  • SpringBoot 源码解析:AnnotationConfigServletWebServerApplicationContext 构造方法
  • OpenClaw 配置飞书机器人完整指南
  • Python 爬取财富中国 500 强数据示例
  • 基于 Docker 和 Ollama 部署 DeepSeek 本地大模型
  • Face Analysis WebUI 体验报告:106 点关键点检测实测
  • DEIM 实时目标检测算法与 Visdrone2019 数据集实战
  • MySQL 内置函数实战指南:日期、字符串与数学运算
  • 大模型落地:在“卷模型”与“卷应用”间的行业化关卡
  • 设计支持万人并发抢购的秒杀系统架构方案
  • 前端动画新范式:CSS animation-timeline 动画时间线
  • Python 中的 == 与 is:本质区别与最佳实践
  • 基于 Rust 与 DeepSeek 大模型的智能 API Mock 生成器构建
  • Ubuntu 及 WSL 环境安装 Node.js、npm 和 Yarn 指南
  • SpringBoot 结合 Redis+Caffeine 多级缓存架构实践
  • 基于Python+OpenCV实现自动扫雷
  • Python 零基础系统学习指南与核心技能图谱
  • 66 个机器人项目合集:科研、教育、工业与医疗方向资源整理
  • C语言标准库与工具链:string.h、stdio.h、stdlib.h及CMake构建
  • llama.cpp CUDA 编译问题排查与性能优化指南
  • 基于 Llama-Factory 引擎重塑游戏 NPC 对话逻辑实战

相关免费在线工具

  • 加密/解密文本

    使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online

  • RSA密钥对生成器

    生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online

  • Mermaid 预览与可视化编辑

    基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online

  • 随机西班牙地址生成器

    随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online

  • Gemini 图片去水印

    基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online

  • curl 转代码

    解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online