零样本模型的稳健微调
零样本模型的稳健微调
该存储库包含本文的代码作者:Mitchell Wortsman*、Gabriel Ilharco*、Jong Wook Kim、Mike Li、Simon Kornblith、Rebecca Roelofs、Raphael Gontijo-Lopes、Hannah Hajishirzi、Ali Farhadi、Hongseok Namkoong、Ludwig Schmidt。
TLDR:我们对零样本模型进行微调,同时在微调或推理期间保持或提高 OOD 准确度,且无需额外的计算成本。
抽象的
大型预训练模型(例如 CLIP 或 ALIGN)在执行零样本推理(即无需对特定数据集进行微调)时,可在一系列数据分布中提供一致的准确率。尽管现有的微调方法可以大幅提高分布内的准确率,但它们往往会降低分布外的稳健性。我们通过引入一种简单有效的方法来提高稳健性,以解决这一矛盾:集成零样本和微调模型的权重(WiSE-FT)。与标准微调相比,WiSE-FT 可以在分布外提供较大的准确率改进,同时保持较高的分布内准确率。在 ImageNet(分布内)和五个派生分布偏移上,WiSE-FT 将分布外准确率提高了 4 到 6 个百分点(pp),而分布内准确率提高了 1.6 pp。WiSE-FT 在另外六个分布偏移的多样化集合上实现了同样大的稳健性改进(2 到 23 pp),与七个常用迁移学习数据集上的标准微调相比,分布内准确率提高了 0.8 到 3.3 pp。这些改进在微调或推理期间不会产生额外的计算成本。
概要图
代码
概述
除了标准微调之外,WiSE-FT 还可以用几行代码实现,如下所示。参见了解更多详情。 # Load models zeroshot = ImageClassifier.load(zeroshot_checkpoint) finetuned = ImageClassifier.load(finetuned_checkpoint) theta_0 = zeroshot.state_dict() theta_1 = finetuned.state_dict() # make sure checkpoints are compatible assert set(theta_0.keys()) == set(theta_1.keys()) # interpolate between checkpoints with mixing coefficient alpha theta = { key: (1-alpha) * theta_0[key] + alpha * theta_1[key] for key in theta_0.keys() } # update the model acccording to the new weights finetuned.load_state_dict(theta) # evaluate evaluate(finetuned, args)
安装依赖项 conda env create conda activate wiseft
将目录添加到 PYTHONPATH: cd wise-ft export PYTHONPATH="$PYTHONPATH:$PWD"
下载数据
有需要时请参考有关如何下载数据集的说明。
运行 WiSE-FT
零样本和微调模型可用时的示例命令:
<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>python src/wise_ft.py \
--eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch \
--load=models/zeroshot.pt,models/finetuned.pt \
--results-db=results.jsonl \
--save=models/wiseft \
--data-location=~/data \
--alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
</code></span></span></span></span>
使用 ViT-B/32 从头开始运行 WiSE-FT 的示例命令:
<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>python src/wise_ft.py \
--train-dataset=ImageNet \
--epochs=10 \
--lr=0.00003 \
--batch-size=512 \
--cache-dir=cache \
--model=ViT-B/32 \
--eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch \
--template=openai_imagenet_template \
--results-db=results.jsonl \
--save=models/wiseft/ViTB32 \
--data-location=~/data \
--alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
</code></span></span></span></span>
注意:该标志--freeze-encoder
控制是否仅对线性分类器进行微调,或者是否对所有权重进行微调(端到端)。
绘制结果
生成散点图的示例命令:
<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>python src/scatter_plot.py \
--eval-datasets=ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch \
--results-db=results.jsonl \
--save plots
</code></span></span></span></span>
我们展示了使用 ViT-B/16 运行上述命令时的预期行为示例(模型可以下载 ):