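Full model fine-tuning updates all of the model's parameters. It suits cases where the target task differs substantially from the pretraining task, or where maximum performance is required. Although this approach can deliver the best performance, it demands substantial compute and storage, and it is prone to overfitting when data is scarce. Partial fine-tuning, by contrast, updates only a subset of the parameters and keeps the rest frozen. This lowers compute and storage costs and reduces the risk of overfitting, which makes it a good fit for low-data tasks, but it may not realize the model's full potential on more complex tasks.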
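To make the cost difference concrete, here is a minimal sketch that compares the share of parameters updated under each strategy. The block names and parameter counts are hypothetical placeholders, not measurements of any real model. Partial fine-tuning is modeled here as training only the classification head, which is one common choice among several.

```python
# Hypothetical parameter counts per block (illustrative numbers only).
layer_params = {
    "embeddings": 24_000_000,
    "encoder": 42_000_000,
    "classifier_head": 600_000,
}


def trainable_fraction(layer_params, trainable_layers):
    """Return the fraction of all parameters that would receive updates."""
    total = sum(layer_params.values())
    trainable = sum(
        count for name, count in layer_params.items() if name in trainable_layers
    )
    return trainable / total


# Full fine-tuning: every block is trainable.
full = trainable_fraction(layer_params, set(layer_params))
# Partial fine-tuning: everything except the head stays frozen.
partial = trainable_fraction(layer_params, {"classifier_head"})

print(f"full fine-tuning:    {full:.2%} of parameters updated")
print(f"partial fine-tuning: {partial:.2%} of parameters updated")
```

The fewer parameters that receive gradients, the less optimizer state has to be stored, which is where most of the savings of partial fine-tuning come from.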
To fine-tune a model, you are required to provide at least 10 examples. We typically see clear improvements from fine-tuning on 50 to 100 training examples with gpt-3.5-turbo but the right number varies greatly based on the exact use case.
We recommend starting with 50 well-crafted demonstrations and seeing if the model shows signs of improvement after fine-tuning. In some cases that may be sufficient, but even if the model is not yet production quality, clear improvements are a good sign that providing more data will continue to improve the model. No improvement suggests that you may need to rethink how to set up the task for the model or restructure the data before scaling beyond a limited example set.
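As a rough illustration of what a small demonstration set might look like, the sketch below builds chat-style training records and serializes them as JSONL, assuming the `messages` layout used for gpt-3.5-turbo fine-tuning files. The helper names (`make_example`, `to_jsonl`), the system prompt, and the repeated placeholder reviews are all made up for this example; a real set should contain distinct, well-crafted demonstrations.

```python
import json

MIN_EXAMPLES = 10  # minimum number of examples mentioned in the text above


def make_example(review, sentiment):
    """Build one chat-format training record (hypothetical prompt wording)."""
    return {
        "messages": [
            {"role": "system", "content": "Classify the movie review as Positive or Negative."},
            {"role": "user", "content": review},
            {"role": "assistant", "content": sentiment},
        ]
    }


def to_jsonl(examples):
    """Serialize examples as JSONL, refusing sets below the minimum size."""
    if len(examples) < MIN_EXAMPLES:
        raise ValueError(f"need at least {MIN_EXAMPLES} examples, got {len(examples)}")
    return "\n".join(json.dumps(ex, ensure_ascii=False) for ex in examples)


# Placeholder data: two reviews repeated to reach the minimum count.
pairs = [
    ("It was good.", "Positive"),
    ("This is not worth watching even once.", "Negative"),
] * 5
examples = [make_example(review, sentiment) for review, sentiment in pairs]
jsonl_text = to_jsonl(examples)
print(jsonl_text.splitlines()[0])
```

Checking the example count before upload is a cheap guard; it does not say anything about whether the demonstrations are good enough to improve the model.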
from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
import evaluate
import torch
import numpy as np
2.4.2 Constructing the fine-tuning data
# load IMDB data
imdb_dataset = load_dataset("stanfordnlp/imdb")

# define subsample size
N = 1000
# generate indexes for random subsample
rand_idx = np.random.randint(24999, size=N)

# extract train and test data
x_train = imdb_dataset['train'][rand_idx]['text']
y_train = imdb_dataset['train'][rand_idx]['label']
x_test = imdb_dataset['test'][rand_idx]['text']
y_test = imdb_dataset['test'][rand_idx]['label']

# create new dataset
dataset = DatasetDict({
    'train': Dataset.from_dict({'label': y_train, 'text': x_train}),
    'validation': Dataset.from_dict({'label': y_test, 'text': x_test}),
})
# fraction of positive labels in the training subsample
np.array(dataset['train']['label']).sum() / len(dataset['train']['label'])  # 0.508
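Note that `np.random.randint` draws with replacement, so the subsample above can contain repeated rows, and the same index array is reused for both splits. If distinct rows and independent splits are preferred, one option is to sample without replacement. The sketch below uses only the standard library; `sample_indices` and the seed values are illustrative names and choices, and 25,000 is the size of each IMDB split.

```python
import random


def sample_indices(dataset_size, n, seed):
    """Draw n distinct row indexes from range(dataset_size), reproducibly."""
    rng = random.Random(seed)
    return rng.sample(range(dataset_size), n)


SPLIT_SIZE = 25_000  # rows in each of the IMDB train and test splits
N = 1000

# Different seeds give the two splits independent index sets.
train_idx = sample_indices(SPLIT_SIZE, N, seed=0)
test_idx = sample_indices(SPLIT_SIZE, N, seed=1)

# Every index is unique, so no review appears twice within a split.
print(len(set(train_idx)), len(set(test_idx)))
```

Index lists like these could then be passed to `Dataset.select` on each split in place of the `rand_idx` lookup.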
An example of the IMDB data format:
{"label":0,"text":"Not a fan, don't recommed."}
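We use 1,000 examples each for the fine-tuning data and the validation data. In the training data, positive and negative reviews each make up roughly 50%.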
2.4.3 Loading the base model
from transformers import AutoModelForSequenceClassification

model_checkpoint = 'distilbert-base-uncased'
# model_checkpoint = 'roberta-base'  # alternative: roberta-base is bigger, so training will take longer

# define label maps
id2label = {0: "Negative", 1: "Positive"}
label2id = {"Negative": 0, "Positive": 1}

# generate classification model from model_checkpoint
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, num_labels=2, id2label=id2label, label2id=label2id)

# display architecture
model
import torch

# load the tokenizer that matches the checkpoint (needed by the loop below)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# a second copy of the model, kept untouched by fine-tuning
model_untrained = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, num_labels=2, id2label=id2label, label2id=label2id)

# define list of examples
text_list = ["It was good.", "Not a fan, don't recommed.", "Better than the first one.", "This is not worth watching even once.", "This one is a pass."]

print("Untrained model predictions:")
print("----------------------------")
for text in text_list:
    # tokenize text
    inputs = tokenizer.encode(text, return_tensors="pt")
    # compute logits
    logits = model_untrained(inputs).logits
    # convert logits to label
    predictions = torch.argmax(logits)
    print(text + " - " + id2label[predictions.tolist()])
Output: the predictions are essentially random.
Untrained model predictions:
It was good. - Positive
Not a fan, don't recommed. - Positive
Better than the first one. - Positive
This is not worth watching even once. - Positive
This one is a pass. - Positive
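The label printed for each sentence comes from taking the argmax over the two logits. The sketch below reproduces that step in plain Python and also converts the logits to probabilities with a softmax, which shows how narrow the margin can be. The logit values are made up for illustration; when a classification head has not been trained, the two logits are often close together, so the winning class is barely ahead.

```python
import math

id2label = {0: "Negative", 1: "Positive"}


def softmax(logits):
    """Convert raw logits to probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits):
    """Return the argmax label together with its softmax probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return id2label[best], probs[best]


# Made-up logits that are nearly tied.
label, confidence = predict_label([0.02, 0.05])
print(label, f"{confidence:.3f}")
```

A label chosen with a probability this close to 0.5 carries little information, which is consistent with the same class being printed for every sentence.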
import evaluate

# import accuracy evaluation metric
accuracy = evaluate.load("accuracy")

# define an evaluation function to pass into trainer later
def compute_metrics(p):
    predictions, labels = p
    # pick the class with the highest logit for each example
    predictions = np.argmax(predictions, axis=1)
    # accuracy.compute already returns a dict of the form {"accuracy": value}
    return accuracy.compute(predictions=predictions, references=labels)
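For readers who want to see what the metric function computes, here is a dependency-free sketch of the same two steps: a row-wise argmax over the logits, followed by the fraction of predictions that match the labels. The function names and the toy logits are illustrative and are not part of the `evaluate` library.

```python
def argmax_rows(logit_rows):
    """Return the index of the largest logit in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in logit_rows]


def accuracy_score(predictions, references):
    """Fraction of predictions that equal the reference labels."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    correct = sum(p == r for p, r in zip(predictions, references))
    return correct / len(references)


# Toy batch: four examples, two classes.
logits = [[2.0, -1.0], [0.1, 0.4], [-0.3, 0.2], [1.5, 1.4]]
labels = [0, 1, 0, 0]

preds = argmax_rows(logits)
acc = accuracy_score(preds, labels)
print(preds, acc)
```

Three of the four toy predictions match their labels, so the accuracy here is 0.75.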
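from peft import LoraConfig, get_peft_model

# LoRA configuration for sequence classification
peft_config = LoraConfig(task_type="SEQ_CLS",
                         r=1,                       # rank of the low-rank update matrices
                         lora_alpha=32,             # scaling numerator (scale = lora_alpha / r)
                         lora_dropout=0.01,         # dropout applied on the LoRA path
                         target_modules=['q_lin'])  # apply LoRA to the query projection layers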
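To get a feel for how small this configuration is, the sketch below estimates the number of LoRA adapter parameters it adds. It assumes the standard LoRA factorization, where a frozen weight of shape d_out x d_in gains two matrices of shapes r x d_in and d_out x r, and it uses the distilbert-base-uncased dimensions (hidden size 768, 6 transformer layers, one `q_lin` per layer). The helper name `lora_param_count` is illustrative.

```python
def lora_param_count(d_in, d_out, r):
    """Parameters in one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


hidden_size = 768  # distilbert-base-uncased hidden dimension
num_layers = 6     # transformer layers, each with one q_lin projection
r = 1              # rank from the config above
lora_alpha = 32    # alpha from the config above

per_layer = lora_param_count(hidden_size, hidden_size, r)
total_lora = per_layer * num_layers

# Weights of the frozen q_lin matrices being adapted (biases ignored).
frozen_q_weights = hidden_size * hidden_size * num_layers

# The low-rank update is multiplied by lora_alpha / r before being added.
scaling = lora_alpha / r

print(f"LoRA parameters per layer: {per_layer}")
print(f"LoRA parameters in total:  {total_lora}")
print(f"share of adapted weights:  {total_lora / frozen_q_weights:.4%}")
print(f"update scaling factor:     {scaling}")
```

This counts only the adapter matrices. For a `SEQ_CLS` task PEFT typically also leaves the classification head trainable, so the trainable-parameter total it reports will be larger than this figure.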