大模型微调方法总结
LoRA, Adapter, Prefix-tuning, P-tuning, Prompt-tuning。
1、LoRA
Paper: LoRA: Low-Rank Adaptation of Large Language Models
简介
自然语言处理目前存在一个重要范式:一般领域数据的大规模预训练,对特定任务或领域的适应(finetune)。
但是随着预训练语言模型越来越大,这个范式存在以下问题:
- 当我们 finetune 大模型时,由于训练成本太高,不太可能重新训练所有模型参数。
- 以前的方法(论文发表于 2021 年)都或多或少有其它性能问题,如 adapter 增加了模型层数,引入了额外的推理延迟;prefix-tuning 比较难训练,效果不如直接 finetune。
基于上述背景,论文作者得益于前人的一些关于内在维度(intrinsic dimension)的发现:模型是过参数化的,它们有更小的内在维度,模型主要依赖于这个低的内在维度(low intrinsic dimension)去做任务适配。假设模型在任务适配过程中权重的改变量是低秩(low rank)的,由此提出低秩自适应(LoRA)方法,LoRA 允许我们通过优化适应过程中密集层变化的秩分解矩阵来间接训练神经网络中的一些密集层,同时保持预先训练的权重不变。
方法
LoRA 的实现思想很简单,就是冻结一个预训练模型的矩阵参数,并选择用 A 和 B 矩阵来替代,在下游任务时只更新 A 和 B。
结合流程来看,LoRA 的实现流程如下:
- 在原始预训练语言模型(PLM)旁边增加一个旁路,做一个降维再升维的操作,来模拟所谓的内在秩。
- 训练的时候固定 PLM 的参数,只训练降维矩阵 A 与升维矩阵 B。
- 模型的输入输出维度不变,输出时将 BA 与 PLM 的参数叠加。
- 用随机高斯分布初始化 A,用 0 矩阵初始化 B,保证训练的开始此旁路矩阵依然是 0 矩阵。
实现
接下来我们从公式上解释 LoRA 的实现。
假设要在下游任务微调一个预训练语言模型(如 GPT3),则需要更新预训练模型参数,公式表示如下:
W0 是预训练模型初始化的参数,ΔW 就是需要更新的参数。如果是全参数微调,则它的参数量=W0 参数量(如果是 GPT3,则 ΔW≈175B)。从这可以看出要全参数微调大语言模型,资源消耗巨大。
由于前人的工作发现预训练的语言模型具有较低的'内部维度(intrinsic dimension)',在任务适配过程中,即使随机投影到较小的子空间,仍然可以有效地学习。因此,LoRA 做的就是增加小参数模块去学习改变量 ΔW。
在训练过程中,W0 是固定不变的,只有 A 和 B 包含训练参数,是变化的。
而在推理的过程中,只需要把改变量放回原模型,就不会有任何延迟。
如果想切换任务,只需要在切换任务的过程中,减去 BA,然后换上用其它任务训练好的 B'A' 就可以了。
总结
总的来说,基于大模型的内在低秩特性,增加旁路矩阵来模拟 full finetuning,LoRA 是一个能达成 lightweight finetuning 的简单有效的方案。目前该技术已经广泛应用于大模型的微调,如 Alpaca,stable diffusion+LoRA,而且能和其它参数高效微调方法有效结合,例如 State-of-the-art Parameter-Efficient Fine-Tuning (PEFT)。
2、Adapter
Paper: Parameter-Efficient Transfer Learning for NLP
MAD-X: An Adapter-Based Framework for Multi-Task Cross-Lingual Transfer
简介
2019 年,Houlsby N 等人将 Adapter 引入 NLP 领域,作为全模型微调的一种替代方案。Adapter 主体架构如下图所示。
在预训练模型每一层 (或某些层) 中添加 Adapter 模块,微调时冻结预训练模型主体,由 Adapter 模块学习特定下游任务的知识。每个 Adapter 模块由两个前馈子层组成,第一个前馈子层将 Transformer 块的输出作为输入,将原始输入维度 d 投影到 m,通过控制 m 的大小来限制 Adapter 模块的参数量,通常情况下 m<<d。在输出阶段,通过第二个前馈子层还原输入维度,将 m 重新投影到 d,作为 Adapter 模块的输出。通过添加 Adapter 模块来产生一个易于扩展的下游模型,每当出现新的下游任务,通过添加 Adapter 模块来避免全模型微调与灾难性遗忘的问题。Adapter 方法不需要微调预训练模型的全部参数,通过引入少量针对特定任务的参数,来存储有关该任务的知识,降低对模型微调的算力要求。


