跳到主要内容大模型算法二次开发:核心流程与关键技术详解 | 极客日志PythonAI算法
大模型算法二次开发:核心流程与关键技术详解
综述由AI生成详细解析了大模型二次开发的四个关键阶段:增量预训练、有监督微调、奖励模型对齐及人类反馈强化学习。重点阐述了 SFT、RLHF 与 DPO 的技术差异,并深入探讨了持续学习中的 Prompt-based、Representation-based 及 Model Mixture 三种主流方法及其优缺点。内容涵盖从领域知识注入到偏好对齐的完整技术路径,以及部署时的量化、服务化与安全过滤最佳实践,为垂直领域大模型构建提供理论支撑与实践思路。
无尘25 浏览 大模型算法二次开发:核心流程与关键技术详解
随着众多大模型相继问世,大模型二次开发、大模型微调成为一项热门技术。本文总结了大模型二次开发的基本方法与思路,涵盖从领域知识注入到偏好对齐的完整技术路径。
开发方法分类
- 领域知识注入:Continue PreTraining(增量预训练)。一般垂直大模型是基于通用大模型进行二次开发,需要用领域内的语料进行继续预训练。
- 知识召回(激发):SFT(Supervised Fine-tuning,有监督微调)。通过 SFT 可以激发大模型理解领域内的各种问题并进行回答的能力。
- 基础偏好对齐:奖励模型(RM)、强化学习(RL)。可以让大模型的回答对齐人们的偏好,比如行文的风格。
- 高阶偏好对齐:RLHF(人类反馈强化学习训练)、DPO(直接偏好优化)。
开发阶段分类
模型通常分为三个主要阶段:
- 第一阶段:增量预训练 (Continue PreTraining)。在海量领域文档数据上二次预训练 GPT 模型,以注入领域知识。
- 第二阶段:有监督微调 (SFT)。构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图。
- 第三阶段:偏好对齐。RLHF 和 DPO 二选一,用于优化模型输出的人类偏好匹配度。
各个阶段功能介绍
1. 增量预训练 (Post-pretraining)
在大模型整个流程中,增量预训练属于后期预训练(Post-pretraining)的范畴。这是一种在模型的初始预训练和最终微调之间进行的训练方法,通常用于进一步适应模型以处理特定类型的数据或任务。
- 定义:Post-pretraining 是在通用预训练模型的基础上,对模型进行额外训练的过程,通常是为了使模型更好地适应特定的领域或任务。
- 数据集:使用的数据集通常比预训练阶段的数据集更专注于某个领域或任务,但比微调阶段使用的数据集更大、更广泛。
- 训练方法:可以是监督学习,也可以是自监督学习,具体取决于可用数据的类型和训练目标。
- 目标:在不过度专化到某个特定任务的同时,提高模型对特定领域的理解和表现。
- 优势:允许模型在保持通用性的同时,增强对特定领域的理解,有助于模型在后续的微调阶段更快速地适应特定任务。与 SFT 相比,Post-pretraining 在微调之前提供了一个中间步骤,有助于模型更平滑地过渡到特定任务上。
2. 微调 (Fine-tuning) & SFT
在这个阶段,预训练模型(可能经过了 Post-pretraining)被进一步训练,以优化它在一个特定任务上的表现。微调通常在一个相对较小的、特定任务的数据集上进行,这个数据集包含了明确的标签,模型通过监督学习来进行优化。
SFT (Supervised Fine-Tuning) 是微调的一种形式,强调在有监督的环境下进行。在 SFT 阶段,我们使用特定领域的数据或私有化数据对预训练模型进行改良。这一阶段需要指令微调数据,数据集通常由输入(用户问题)和输出(标准答案)两个字段构成。标准答案通常由专家标注获得。
- SFT 是一种简单的微调方法,它使用带有正确答案的数据集来继续训练一个预训练的模型。
- 这种方法依赖于大量的标注数据,即每个输入都有一个预先定义的正确输出。
- 微调的目的是使模型更好地适应特定的任务或领域【垂直领域】,比如特定类型的语言理解或生成任务。
- SFT 通常不涉及复杂的策略或奖励函数,只是简单地最小化预测输出和真实输出之间的差异。
3. RLHF 人类反馈强化学习
RLHF 是一种利用人类反馈来训练强化学习模型的方法。在 RLHF 中,模型通过与人类交互获得反馈,这些反馈作为奖励信号来指导模型的行为。RLHF 通常用于训练能够生成更自然、更符合人类偏好的文本或其他输出的模型。这种方法特别适用于需要模型理解和适应人类偏好的场景。
- RLHF (Reinforcement Learning from Human Feedback) 是一种更复杂的训练方法,它结合了监督学习和强化学习。
- 在 RLHF 中,模型首先通过监督学习进行预训练,然后通过人类提供的反馈来进行强化学习。
人类反馈可以是对模型输出的评分,或者是在模型输出之间做出选择的偏好。强化学习部分涉及到定义一个奖励函数,该函数根据人类反馈来调整模型的行为,以优化长期的奖励。RLHF 的目标是训练出一个在没有明确标签的复杂任务中表现良好的模型,这些任务可能需要更细致的判断和调整。4. 模型对齐
对齐阶段目的是进一步优化模型,使其更符合实际应用需求。在这个阶段,我们收集用户反馈数据(如点赞或点踩),并基于这些数据进行模型的进一步训练。
对齐阶段的数据格式与 SFT 阶段不同:通常包含对同一问题的接受(accept)和拒绝(reject)两种答案。
问题解决策略及部署:在 SFT 阶段,模型被训练以识别'想要的答案',但未明确告知'不想要的答案'。为解决这一问题,我们通过收集用户反馈和日志数据,在对齐阶段告诉模型哪些答案是不可接受的。经过 SFT 和对齐阶段的训练,我们可以得到一个优化后的模型,这个模型可以部署上线。在对齐过程中,我们可以使用一些常见的方法,如 PPO(Proximal Policy Optimization)和 DPO(Distributional Proximal Optimization)。DPO 由于训练过程相对简单,已成为对齐阶段的主流算法。
总的来说,SFT 更侧重于直接从标注数据中学习,而 RLHF 则试图通过人类的反馈来引导模型学习更复杂和更细粒度的行为。RLHF 通常被认为是一种更接近人类学习方式的方法,因为它不仅仅依赖于标签数据,还依赖于人类对模型输出的评价和偏好。
5. RLHF 与模型对齐区别
总的来说,模型对齐阶段可以视为一个更广泛的概念,而 RLHF 是一种特定的实现方式,特别是在强化学习领域。两者在实践中可能会有交集,但它们侧重点和应用方式有所不同。
- 联系:两者都涉及到根据反馈来调整模型的行为,以提高模型的性能和适应性。
- 区别:
- 技术实现:对齐阶段可能不仅限于强化学习,还可以包括监督学习或其他类型的学习;而 RLHF 明确使用了强化学习框架。
- 反馈来源:对齐阶段的反馈可以来自用户的实际使用情况,而 RLHF 的反馈通常来自与模型交互的人类评估者。
- 目标:对齐阶段的目标是使模型的输出与用户期望对齐,而 RLHF 的目标是通过人类反馈来优化模型的决策过程。
技术创新与发展
RLHF
利用人类指导的力量有效地训练人工智能模型。传统的强化学习模型通过与环境交互产生的奖励来学习,而 RLHF 则不同,它引入了人类反馈作为宝贵的指导来源。这种反馈可以帮助人工智能系统导航复杂的决策空间,与人类价值观保持一致,并做出更明智和道德的选择。
DPO
RLHF 是一个复杂且经常不稳定的过程,那我们是否可以通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好呢?
DPO 就是通过利用奖励函数与最优策略之间的映射关系,证明这个受限的奖励最大化问题可以通过单阶段的策略训练来精确优化来达到这个目的的。
它通过直接优化模型以生成首选响应,将问题表述为使用人类偏好对数据集的分类任务,本质上是在人类偏好数据上解决一个分类问题。DPO 是稳定的、性能和计算成本轻量级的,无需拟合奖励模型,在微调期间从 LM 中采样,或执行显着的超参数调整。
DPO 是一种单阶段算法,可直接优化 LLM 以生成首选响应。具体的实现手段是通过增加偏好样本的对数概率与减小非偏好样本响应的对数概率。DPO 最大化生成首选完成的概率并最小化生成非首选完成的概率。它不涉及多轮训练。
持续学习方法详解
增量预训练(Continue Pretraining)
增量预训练也叫领域自适应预训练(domain-adapter pretraining),即在所属领域数据上继续预训练。自适应预训练的方法可以分为三类:Prompt-based 方法、representation-based 方法和 model mixed-based 方法。
1. Prompt-based 方法
在使用模型全局 tuning 的方式适应下游任务时,预训练模型的泛化性能会被严重削弱,因此 Prompt-based 方法在保持预训练模型参数权重不变的条件下,增加额外可学习的 Prompt tuning 模块来实现对下游任务的泛化,这样就能较好地保持原模型的泛化性能。
VPT 这种方式虽然可以较好地保留模型的泛化性,但是,在面对新的任务时,以往的 Prompt 模块的知识同样被覆盖,依旧遭遇了灾难性遗忘问题。为此,有学者提出了 Prompt Pool 的概念,设计了 Prompt 模块的集合,即 P={P1,P2,…,Pm}(m 表示该 Pool 的最大尺寸)。Prompt Pool 的思想有效避免了单一 Prompt 的问题,但是 Pool 的设计使得其需要进行 Prompt Selection 操作,也就是需要将特定任务与其对应的 Prompt 模块进行索引匹配。
L2P 算法是一种较为常用的 Prompt selection 算法,该算法设计了一种 Key-Query 的 Prompt 匹配方法,也就是为每一个 Prompt 提供一个可学习的索引键 k,即 P={(k1,P1),(k2,P2),…,(km,Pm)。L2P 利用预训练模型将输入特征编码到 Key 对应的嵌入空间中,然后利用余弦距离损失函数在已有的 Pool 中搜索最近似的 Key。接着,利用如交叉熵损失等方法对搜索到的 Key 对应的 Prompt 进行进行优化。
类似的 Prompt Selection 算法很多,如 DualPrompt 算法,该算法将 Prompt 进行解耦,分化为 General Prompt 和 Expert Prompt。General Prompt 面向所有任务,为所有任务中共享信息,而 Expert Prompt 针对独立任务,数量与任务量一致。其采用了和 L2P 相同的 key-query 匹配策略。
Prompt Selection 虽然可行,但仍是硬匹配,选项有限。基于注意力信息加权的 Prompt Combination 方法则有效缓解了该问题。如 CODA-Prompt 通过对 Prompt Pool 进行注意力机制嵌入,为每个注意力赋予自适应权重,进而求算全局 Key-Query 的加权和,实现可学习式 Prompt 组合。
从根本上来说 Prompt Combination 仍受制于 Prompt Pool 的范围。为此,许多学者则开展 Prompt Generation 有关的研究,如 DAP,其利用 MLP 进行特定任务提示信息的编码生成。
- Prompt 有助于弥合 domain gap,并可有效地对特定任务的知识进行编码。
- Prompt Design 属于 lightweight 模块,与 input feature 具有相同的维度,因此保存 Prompt 是 parameter-efficient,适用于边缘场景。
- Prompt Pool 作为预训练模型的外部存储器,其支持自适应知识的检索和特定实例的预测。
- 一些研究中发现 L2P 中的 prompt selection 过程收敛到一个单点,使得 prompt selection 只集中在特定子集上。
- 由于 key 和 query 在整个学习过程中不断变化,这些参数的更新将会消除先前任务的参数,导致 matching-level 和 prompt-level 的遗忘,使 prompt selection 成为 CL 的瓶颈。
- 固定大小的 Prompt Pool 会使得模型的表示能力受限。但是,若 Prompt Pool 随着数据的发展而增长,可能会为旧任务检索新的提示,导致训练和测试之间的不匹配。
- 最后,一些研究发现 prompt-based CL 的性能低于简单的 representation-based 的 baseline 性能。并且批量提示有损比较的公平性。
2. Representation-based 方法
representation-based 方法直接利用预训练模型强大的泛化性和通用性来实现持续学习。比如 Simple-CIL 方法,该算法是 ADAM 算法原文中提出的 Baseline,Simple-CIL 冻结预训练模型参数,并通过求算类别中心的方式来构建 Classifier。具体来说,在面对很多类别时,计算同类的 embedding 或 features 的平均值,并将该平均值作为该类别的标准(prototype),最后结合类别标准与余弦比较的方法替换模型的原始 Classifier。
虽然基于 prototype 的方法存在一定的作用,但是并未很好地适应下游任务。为此,一些研究在基于 prototype 方法的基础上结合了外置参数高效调节模块或者外置适配器来使得预训练模型更加适应下游任务,如 ADAM 等。
ADAM 等算法在进行类别标准设定时,类别标准之间的仍存在联系,导致任务效果降低。为此,RanPAC 算法则采用 online LDA classifier 来去除原始方法 prototype 计算结果之间的相关性,加大类别间的分布差异。此外,RanPAC 算法利用 Random Projection layer 将 features 映射到高维空间中,并在高维空间中进行 prototype 的计算,以使得特征分布符合高斯拟合。
相较于前面将预训练模型的通用语和适应性分离处理的方式,SLCA 算法采用了差异学习率调整和特征经验重播的方式进行持续学习研究。该算法使用较小的 learn rate 调整模型主体部分,而使用较大的 learn rate 调节模型的 classifier,以实现模型的逐步微调和 classifier 的快速适应。为了避免忘记以前的分类器,SLCA 还对分类特征分布进行建模,并重播它们以校准 classifier。
- 由于 class prototype 代表了对应类别最常见的标准格式,因此利用其构建模型具有直观和可解释性。
- Representation-based 方法主要是冻结 backbone 和更新 classifier 权重。lightweight 的更新成本增加了其现实应用的可行性。
- 将不同模型的特征连接起来形成 class prototype,容易造成模型信息冗余。例如,不同的 backbone 中存在重复提取共享特征。
- 当下游任务涉及多个领域时,在第一阶段调整模型不足以弥合数据集之间的领域差距。在这种情况下,不断调整 backbone 可能更适合提取特定于任务的特征。
3. Model Mixture-based 方法
Model Mixture-based 方法在持续学习工程中构建了一组模型,然后再推理阶段通过 Model Ensemble 和 Model Merge 来进行信息综合决策。
Model Ensemble 中,ESN 算法凭借预训练模型强大的通用性,构建多个 classifier,在面对新任务重新初始化和训练一个新的 classifier。在推理时,采用投票策略来整合多个模型的结果进行最终决策。
由于 Model Ensemble 的核心因素取决于模型的方差,一些研究通过增强模型之间的多样性来替代使用相同的预训练模型构建不同的 classifier。如 PromptFusion 利用预训练的 ViT 和 CLIP,并在推理过程中动态地对 logit 进行组合,即 f(x) = λ fvit (x) +(1−λ)fclip(x)。
与多个 backbone 的集成不同,PROOF 采用了仅使用单个 CLIP 的更全面的推理方法。由于 CLIP 支持视觉和文本特征的跨模态匹配,因此 PROOF 设计了一个三层集成,考虑 image-to-text、image-to-image prototype、image-to-adjusted text 的跨模态融合。
Model Merge 将多个不同的模型合并为一个统一的模型,无需要额外的训练。LAE 定义了 online 和 offline 学习协议,online 模型通过交叉熵损失进行更新,目的是在新的任务中获取新的知识。离线模型则通过 Model Merge 进行更新,例如指数移动平均 (EMA): θ offline←α·θ offline +(1−α)·θ Online,其中 α 为权衡参数。LAE 仅将 EMA 应用于参数高效调谐模块 (如 prompt),其利用 online 和 offline 模型的最大 logit 进行推断。
与 LAE 一样,ZSCL 将合并技术应用于 CLIP 模型,目的是在持续学习过程中保持其 zero-shot 性能。然而,随着 EMA 中权衡参数的改变,CLIP 性能不再具有鲁棒性。因此,ZSCL 建议每隔几次迭代合并参数,从而在模型训练期间创建平滑的损失轨迹。
此外,CoFiMA 注意到 EMA 在 Merge 过程中对每个参数的重要性是相等的,CoFiMA 在 Merge 过程中插入 Fisher information(费雪信息)作为每个参数的估计重要性。
- 学习多个模型可以做出不同的决策。因此,使用 Model Ensemble 和 Model Merge 自然会产生更健壮的结果。
- 由于直接合并模型进行统一预测,因此可以调整前模型和后模型的权重,以突出不同阶段之间知识共享的重要性。
- 由于模型集将在推理过程中合并,因此最终的推理成本不会随着模型集中添加更多模型而增加。
- Model Ensemble 需要保存所有的历史模型,并消耗大量的内存缓冲区。虽然基于 Model Merge 不需要这么大的成本,但合并大型 backbone 的权重也需要大量的额外计算。
- 决定 Merge 哪些参数仍然是问题。
部署与最佳实践
在完成上述训练流程后,模型的部署与优化同样关键。以下是几个重要的实践建议:
- 量化与压缩:为了降低推理成本,可以使用 INT8 或 FP16 量化技术。对于资源受限的边缘设备,可以考虑知识蒸馏,将大模型的能力迁移到小模型中。
- 服务化封装:使用 FastAPI 或 Flask 将模型封装为 RESTful API,配合 Nginx 进行负载均衡。对于高并发场景,建议使用 vLLM 或 TGI 等高性能推理框架。
- 监控与回滚:建立完善的监控体系,跟踪延迟、吞吐量及错误率。一旦发现模型表现下降,应具备快速回滚到上一版本的能力。
- 安全过滤:在输入和输出层增加内容安全过滤器,防止敏感信息泄露或生成有害内容,确保符合合规要求。
总结
大模型二次开发是一个系统工程,涉及从数据准备、模型选择、训练策略到部署优化的全流程。选择合适的微调方法(如 SFT、DPO)和持续学习策略(如 Prompt-based、Representation-based)是提升模型垂直领域表现的关键。随着技术的演进,DPO 等更高效的对齐方法正逐渐取代传统 RLHF,成为行业主流。开发者应根据业务场景的资源约束和性能需求,灵活组合上述技术,构建高质量的大模型应用。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- 随机西班牙地址生成器
随机生成西班牙地址(支持马德里、加泰罗尼亚、安达卢西亚、瓦伦西亚筛选),支持数量快捷选择、显示全部与下载。 在线工具,随机西班牙地址生成器在线工具,online
- Gemini 图片去水印
基于开源反向 Alpha 混合算法去除 Gemini/Nano Banana 图片水印,支持批量处理与下载。 在线工具,Gemini 图片去水印在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online