大模型指令微调中的 Prompt 设计与数据集构建指南
1. 指令微调数据集形式与质量策略
在大型语言模型(LLM)的微调过程中,Prompt 的设计对模型的训练效果及推理表现有着至关重要的影响。许多开发者在推理阶段发现,若不使用特定的 Prompt 格式直接输入,模型性能会显著下降。这引发了一个核心问题:如果在训练阶段未包含 Prompt,测试时是否可以直接输入?此外,多轮对话与单轮对话的构造方式也直接影响最终模型的能力。
目前市面上的指令微调数据格式繁多,导致选择困难。针对这一问题,我们提出以下核心观点:
- 质量优先:单次实验微调所用的指令微调数据集应选取'高质量、高多样性'的数据。低质量的噪声数据会严重干扰模型收敛。
- 资源利用:在训练资源充足的情况下,可以加入数量更多、长度更大的数据集,以增强模型的泛化能力。
- 统一格式:建议基于多个高质量数据源,制作一份格式统一的多样性数据用于 SFT(Supervised Fine-Tuning)。一次性微调通常优于多次微调,后者可能导致灾难性遗忘或效果折扣。
- 增量微调方案:如果必须进行多次微调,建议采用 LoRA 或 QLoRA 等参数高效微调方法。将训练好的 LoRA 权重合并到原始底座模型中,可以有效减轻多次微调对模型原有能力的负面影响。
2. 常见指令微调模板分析
通过观测 Hugging Face 排行榜靠前和主流开源项目的指令微调数据集,我们可以总结出几种常见的 Prompt 模板结构。不同的模型架构往往对应特定的模板格式,混用会导致生成效果不佳。
2.1 Stanford Alpaca 模板
这是最经典的指令微调模板之一,适用于大多数基础指令跟随任务。
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
2.2 Llama2 模板
Llama2 引入了 System Prompt 机制,增强了模型的安全性和角色设定能力。
instruction = """[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n{} [/INST]"""
2.3 Linly-AI 模板
这是一种简洁的中文指令模板,常用于中文场景。
2.4 NousResearch (OpenLLM Leaderboard Top)
该模板结构与 Alpaca 类似,但更加强调响应前的换行符处理。
<prompt>
<leave a newline blank for model to respond>
当包含额外上下文时:
<prompt>
<additional context>
<leave a newline blank for model to respond>
2.5 Yayi 模板
Yayi 模型使用了特殊的 Token 标记来区分系统、用户和助手。
prompt = "你是谁?"
formatted_prompt = f"""<|System|>:
You are a helpful, respectful and honest assistant named YaYi developed by Beijing Wenge Technology Co.,Ltd. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<|Human|>:
{prompt}
<|YaYi|>:
"""
2.6 StableBeluga2 模板
Stable Beluga 2 采用了清晰的 System/User/Assistant 三段式结构。
This is a system prompt, please behave and help the user.
Your prompt here
The output of Stable Beluga 2
具体实现示例:
system_prompt = "### System:\nYou are Stable Beluga, an AI that follows instructions extremely well. Help as much as can. Remember, be safe, and don't do anything illegal.\n\n"
message = "Write me a poem please"
prompt = f"{system_prompt}### User: {message}\n\n### Assistant:\n"
2.7 Guanaco 数据集常用模板
Guanaco 项目广泛使用的 ChatML 风格变体。
或者更完整的描述:
prompt = "Introduce yourself"
formatted_prompt = (
f"A chat between a curious human and an artificial intelligence assistant."
f"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
f"### Human: {prompt} ### Assistant:"
)
3. 多轮对话输入和输出构造
在多轮对话场景中,如何正确计算 Loss 是提升训练效率的关键。参考 Firefly 项目和 Chinese-Llama-2-7b 项目的实现,一般采用的方式是:在计算 Loss 时,通过 Mask 的方式,Input 部分的 Loss 不参与参数更新,只有 Target(回答)部分的 Loss 参与参数更新。这种方式充分利用了模型的优势,训练更加高效,且多轮对话中的每个 Target 部分都参与了训练。
3.1 损失掩码原理
如果不进行 Mask,而是将 n 轮对话拆分成 n 条数据,且只计算最后一个 Target 的 Loss,会大大降低训练效率。正确的做法是在 Tokenizer 编码后,对 Input 对应的 Label 位置填充 IGNORE_TOKEN_ID(通常为 -100),这样 PyTorch 的 CrossEntropyLoss 会自动忽略这些位置的梯度计算。
3.2 具体实现代码
以下是基于 LinkSoul-AI 项目的 Tokenize 函数示例,展示了如何处理 System Prompt 和多轮对话的角色分配。
def tokenize(item, tokenizer):
roles = {"human": "user", "gpt": "assistant"}
input_ids = []
labels = []
if "instruction" in item and len(item["instruction"]) > 0:
system = item["instruction"]
else:
system = dummy_message["system"]
system = B_SYS + system + E_SYS
item["conversations"][0]['value'] = system + item["conversations"][0]['value']
for i, turn in enumerate(item["conversations"]):
role = turn['from']
content = turn['value']
content = content.strip()
if role == 'human':
content = f"{B_INST} {content} {E_INST} "
content_ids = tokenizer.encode(content)
labels += [IGNORE_TOKEN_ID] * (len(content_ids))
else:
content = f"{content} "
content_ids = tokenizer.encode(content, add_special_tokens=False) + [tokenizer.eos_token_id]
labels += content_ids
input_ids += content_ids
input_ids = input_ids[:tokenizer.model_max_length]
labels = labels[:tokenizer.model_max_length]
trunc_id = last_index(labels, IGNORE_TOKEN_ID) +
input_ids = input_ids[:trunc_id]
labels = labels[:trunc_id]
(labels) == :
tokenize(dummy_message, tokenizer)
input_ids = safe_ids(input_ids, tokenizer.vocab_size, tokenizer.pad_token_id)
labels = safe_ids(labels, tokenizer.vocab_size, IGNORE_TOKEN_ID)
input_ids, labels
另一种实现方式参考 Firefly 项目,使用 target_mask 来标记哪些位置需要计算 Loss。
class SFTDataset(Dataset):
def __init__(self, file, tokenizer, max_seq_length):
self.tokenizer = tokenizer
self.bos_token_id = tokenizer.bos_token_id
self.eos_token_id = tokenizer.eos_token_id
self.eos_token = tokenizer.eos_token
self.bos_token = tokenizer.bos_token
self.max_seq_length = max_seq_length
logger.info('Loading data: {}'.format(file))
with open(file, 'r', encoding='utf8') as f:
data_list = f.readlines()
logger.info("there are {} data in dataset".format(len(data_list)))
self.data_list = data_list
def __len__(self):
return len(self.data_list)
def __getitem__(self, index):
data = self.data_list[index]
data = json.loads(data)
conversation = data['conversation']
utterances = []
for x in conversation:
utterances.append(x['human'])
utterances.append(x['assistant'])
utterances_ids = self.tokenizer(utterances, add_special_tokens=).input_ids
input_ids = [.bos_token_id]
target_mask = []
i, utterances_id (utterances_ids):
input_ids += (utterances_id + [.eos_token_id])
i % == :
target_mask += [] * ((utterances_id) + )
:
target_mask += [] * ((utterances_id) + )
(input_ids) == (target_mask)
input_ids = input_ids[:.max_seq_length]
target_mask = target_mask[:.max_seq_length]
attention_mask = [] * (input_ids)
(input_ids) == (target_mask) == (attention_mask)
inputs = {
: input_ids,
: attention_mask,
: target_mask
}
inputs
核心逻辑在于通过 IGNORE_INDEX(-100) 遮蔽掉 Input 对应的目标输出,确保反向传播仅针对模型预测正确的部分进行优化。
4. 高效率微调大模型的最佳实践
如何在短时间、高效率地训练出实际效果不错、综合能力较强的大模型?从指令微调数据集处理工作上,建议遵循以下流程:
4.1 数据集准备策略
- 多样化来源:事先准备多种高质量的指令微调数据集,每个数据集尽量保持差异性。高质量数据的定义可以参考当前效果不错的开源模型(如 Llama-2, ChatGLM 等)所公开的训练数据分布。
- 多轮对话增强:实验表明,加入多轮对话的数据有助于提升模型的上下文理解能力和生成长度。如果仅用单轮对话或单轮指令训练,模型生成的文本往往偏短,缺乏连贯性。
- 模板一致性:微调时使用某种模板,推理时也必须严格使用相同的模板。否则会导致效果大幅下降,表现为生成内容短小、逻辑混乱甚至中英文混杂。例如,训练使用了英文模板,推理时未使用提示模板,模型可能无法识别指令边界。
4.2 超参数调整建议
- Learning Rate:对于 SFT 任务,学习率通常设置在 1e-5 到 5e-5 之间。过高的学习率可能导致模型破坏预训练知识,过低则收敛缓慢。
- Batch Size:根据显存大小调整有效 Batch Size。较大的 Batch Size 有助于稳定梯度,但需配合 Gradient Accumulation 使用。
- Epochs:SFT 通常不需要过多的 Epoch,1-3 个 Epoch 即可达到饱和。过多的 Epoch 容易导致过拟合。
4.3 评估与验证
- Perplexity (PPL):监控验证集的困惑度,作为模型收敛的参考指标。
- 人工评估:随机抽取生成结果,由人工判断指令遵循程度、事实准确性和流畅度。
- 自动化基准:使用 MMLU、C-Eval 等标准基准进行测试,量化模型能力提升情况。
5. 常见问题与解决方案
5.1 灾难性遗忘
如果在 SFT 过程中发现模型丢失了原有的通用知识,可能是由于学习率过大或数据分布过于单一。解决方法包括混合少量预训练语料(Continual Pre-training)或使用正则化约束。
5.2 显存溢出
对于大模型微调,显存不足是常见问题。建议使用 DeepSpeed ZeRO-3 优化器状态分片,或者使用 QLoRA 技术将模型量化至 4bit,可大幅降低显存需求。
5.3 生成重复
如果模型输出出现大量重复内容,可以尝试调整 Temperature 参数,或在训练数据中加入去重步骤。此外,检查解码策略(如 Top-P, Top-K)的设置。
6. 总结
大模型指令微调是一个系统工程,涉及数据清洗、Prompt 设计、训练策略和评估等多个环节。选择合适的模板格式、保证数据质量、合理配置训练参数是成功的关键。随着技术的演进,未来可能会出现更多高效的微调框架和自动化工具,但理解底层的 Prompt 机制和数据构造原理依然是开发者的核心竞争力。