跳到主要内容从零训练文本嵌入模型:Arctic Embed 方法解析 | 极客日志PythonAI算法
从零训练文本嵌入模型:Arctic Embed 方法解析
文本嵌入模型在搜索和检索增强生成(RAG)中至关重要。基于 Arctic Embed 项目,详细阐述了从零训练嵌入模型的全流程。核心策略包括两轮训练法:大规模预训练结合批内负样本,以及微调阶段引入困难负样本挖掘。文章深入分析了数据过滤、合成数据生成、池化层选择([CLS] vs 平均池化)及序列长度对性能的影响。实验表明,数据采样质量与困难负样本策略比单纯扩大规模更有效。通过消融研究验证了源分层、更长截断长度及课程学习的重要性,为构建高性能检索模型提供了实践指导。
人间失格0 浏览 从零训练文本嵌入模型:Arctic Embed 方法解析
嵌入模型(Embedding models)在无需额外调优的情况下,能够提供准确的检索性能,这使得它们在搜索和检索增强生成(RAG)工作负载中备受青睐。与传统的关键词搜索不同,嵌入模型超越了词汇重叠的限制进行信息编码,将文本映射到高维向量空间,使得语义相似的文本在空间中距离更近。
由于这些模型的实用性及其广泛采用,开放源代码和研究社区不断推出越来越强大的文本嵌入模型,如 E5、GTE 和 Jina。这些工作的快速实验和改进,部分归功于大型开放评估基准的支持,如 MSMARCO、BEIR 和 MTEB。这些基准测试平台为研究人员提供了可靠的评估环境,促进了嵌入模型的持续发展和优化。
本文基于 Arctic Embed 项目的实践,提出了一组消融实验,表明在训练过程中数据采样和负采样方法比扩大数据规模和批量大小更能显著改善检索质量,而此前的工作主要集中在后者。此外,本文还介绍了一种基于挖掘的困难负样本的创新查询生成技术,发现其比同时生成查询和负样本的直接生成方法更有效。
Arctic 嵌入模型架构
在 Arctic-embed 的开发中,我们旨在从文献中公认的最佳实践出发,从头开始训练一个嵌入模型。与 E5、BGE、GTE、Jina 和 Nomic 等先前工作一致,我们进行了两轮训练,使用了两种不同类型的数据集。
网络结构
我们基于不同规模的 BERT 类模型进行了训练。我们的 m 和 l 模型采用标准的 BERT 架构(分别为 BERT base 和 BERT large)。对于较小的模型(xs 和 s),我们选择了 MiniLMv2 架构的变体,采用了 Nomic BERT 架构。这种分层设计允许我们在计算资源受限的场景下也能获得良好的性能表现。
池化层
在架构上,我们没有对任何基础模型进行修改,甚至没有进行常见的添加池化层的操作。此外,我们使用 [CLS] 标记的最终隐藏状态作为嵌入向量,而不是 E5、GTE 和 Nomic 中使用的平均池化策略。这一选择与 BGE 架构一致,实验表明 [CLS] token 往往能更好地捕捉句子的整体语义表示。
训练数据集构建
在创建我们的训练数据集时,我们从大型语言模型(LLMs)领域汲取灵感,并利用了受到 RefinedWeb、C4、Gopher 和 TogetherAI 启发的过滤方法。
首先,对于嘈杂的原始数据源,如网络搜索,我们使用 trafilatura 解析结构化的网页文档。在解析过程中,我们计算用于质量过滤的自定义信号。具体来说,对于正数据对清理,我们需要确保:
- 每对文本都具有良好的质量(语言过滤器、文本质量过滤器)。
- 文本对(查询,文档)在意义上相似(一致性过滤器)。
对于质量过滤,我们利用了一系列与 Snowflake 的 Arctic 模型训练手册中详细介绍的类似的过滤器。对于一致性过滤,我们采用了低保真、高吞吐量的成对相似性一致性过滤器 — 使用 fastText 的 word2vec 模型进行句子相似性计算(可以在 CPU 上便宜地运行)。我们不将这些嵌入信号视为明确的质量标签,而是采用保守的阈值(最低允许相似度为 0.3),并用它们来过滤掉无关的示例。
此外,在此步骤中,我们将长序列截断为 512 个单词。正如我们观察到的那样,网络语料库中的查询通常在文档的开头得到回答。这不仅在计算上是浪费的,而且即使是 word2vec 嵌入中捕获的含义也会因为从后面的不相关单词中平均向量而被稀释。
数据集混合与采样
由于不同数据集的大小、一致性、难度和学习动态的不同,简单地将所有可用数据集连接在一起被证明是一种次优策略,特别是在微调阶段。相反,我们进行了孤立实验,以了解每个数据集对微调性能的影响。然后,我们根据这些实验中它们的相对性能选择并组合数据集。
预训练数据集
我们的大规模预训练数据集包括了 3.08 亿个查询 - 文档对(从大约 20 亿个文档中筛选),其中 71% 是与查询或标题配对的网络搜索文档。除了网络搜索数据外,文本对集还包括了来自常见抓取源的 PAQ、StackExchange 标题 - 正文和标题 - 正文网络文档对,以及 S2ORC 标题 - 摘要对。
微调数据集
我们的微调数据集由我们的网络搜索数据与几个公共数据集(HotpotQA、NQ、Fever 和 StackExchange 标题 - 正文)组合而成,然后通过以下章节详细描述的合成挖掘策略进行进一步扩展。这种混合明显省略了其他嵌入模型使用的几个流行公共数据集,因为我们观察到正样本一致性和负样本难度水平。这些发现不太有用的数据集包括 NLI、MEDI、WikiAnswers 和 SQuAD。
经验上,我们观察到,在微调阶段,质量比数量更重要,过量的低质量数据可能会导致模型质量降低。
合成数据
与预训练中使用的大量网络规模数据相比,适用于微调的高质量示例更为稀缺。然而与这些先前方法不同,我们发现向我们的 LLM 输入中添加负面文档对于确定查询生成至关重要。此外,我们选择只生成合成查询,而不是合成负样本,因为我们发现 LLM 不容易生成与从现有文档语料库中挖掘的同样高质量的相关负样本。
图 4 展示了这一方法的实施方式——由算法 2 的变体生成的两个数据集导致了接近原始 HotpotQA 所提供的得分增加。
困难样本挖掘
微调数据集通常包括精心选择的'困难'负样本示例和一个与之正相关的查询 - 文档对。在微调阶段,这些负样本应该有多难才能实现最大效果的学习?我们对这个问题的答案最终是一个可调节的困难负样本挖掘策略,其中我们利用一个现有的文本嵌入模型来识别和评分每个训练示例中的最难负样本。然后,我们应用一个分数阈值来丢弃以上集合中的困难负样本。
我们发现,使用一个上限阈值而不是一个特定的排名有助于考虑到一些查询接受的前 k 个最难负样本比其他查询更难。除了调整到单个难度阈值级别外,我们假设按照负样本的困难程度对数据进行排序(即课程学习)可能会导致更好的结果。
训练细节
批内负样本对比预训练
在对比训练的第一轮中,我们旨在实现大规模,无论是在批处理还是总数据集大小上。我们使用我们的预训练数据集,采用 InfoNCE 对比损失,使用批内负样本(对于每个查询,批处理中与不同查询相关的所有文档都被视为负样本)。GPU 并行化、激活检查点和截断的序列长度在实现大批量大小方面起到了重要作用。
我们使用 AdamW 优化器进行一次纪元的训练,仅调整学习率,而将所有其他参数保留为 PyTorch 默认值。我们进行线性学习率预热数百步,然后在训练的剩余部分线性衰减到原始学习率的 10%。
更长的截断长度
我们的训练数据中包含了许多长度超过 128 个标记的文档。在大规模对比训练中,我们使用了 256 的文档序列长度,与 GTE 和 BGE 使用的 128 截断长度形成对比。我们将查询序列长度截断为 32,与 BGE 的源代码保持一致。我们在第 7 节的消融研究中发现,这种更长的截断长度导致了检索性能的显著提高。
源分层
在预训练期间,我们将每个批次填充了来自单个来源的数据,这是之前工作中准确度提升的一个来源。这有助于模型在不同分布的数据上保持稳定学习。
对比训练与精选负样本
在大规模训练之后,我们进行第二轮训练,利用我们的微调数据集,其中包含明确标记的负样本示例。我们不使用学习率预热,但应用与预训练阶段相同的线性学习率衰减计划。对于所有模型,包括长上下文变体 m-long,我们将序列长度截断为 512,用于查询和文档。对于批处理中的每个查询,我们包含一个正文档和十个难负样本文档。
禁用批内负样本损失
根据一些早期的微调运行结果,我们发现禁用批内负样本损失并没有明显降低性能。我们在微调期间停止使用批内负样本,专注于硬负样本的挖掘。
实验对比
为了评估我们的检索质量,我们在 MTEB 数据集的检索部分上评估模型性能。MTEB 实验的摘要结果显示,我们的模型在多个基准测试中取得了优异的表现。为了评估我们的长上下文模型的性能,我们利用了 LoCo 评测。
预训练消融
我们在一系列消融研究中探索了批量大小、序列长度、基础模型和训练数据的影响。消融研究结果支持了我们关于数据采集、更长的序列长度和源分层改善模型性能的假设。相比之下,从预训练的检索模型开始初始化在预训练后并没有显著影响 MTEB 检索分数。我们还注意到了源分层在训练后期比批量大小等其他因素更为重要的类似课程学习的模式。
微调消融研究
我们的可调节负样本挖掘方法使用阈值来过滤过难的负样本。我们对几个阈值数值进行了消融研究,以展示阈值参数的重要性。所示的结果表明,太低和太高的最大相关性阈值(过难和过简单的负样本)会导致性能显著下降。
端到端消融研究
为了彻底研究训练数据对最终得分的影响,我们通过微调步骤扩展了我们的部分预训练消融研究。尽管在预训练中,使用 Snowflake 和 Nomic 数据预训练的模型之间的性能差距相对较小,但随着微调,尽管微调配方相同,差距显著扩大。我们还看到了使用 e5-unsupervised-base 的配置最终得分略有改善。
总结
本文详细阐述了从零训练高性能文本嵌入模型的方法论。通过 Arctic Embed 的实践,我们验证了以下关键结论:
- 数据质量优于数量:在微调阶段,精心筛选的高质量数据比海量低质数据更能提升模型性能。
- 困难负样本挖掘:引入可调节阈值的困难负样本挖掘策略,结合课程学习,能显著提升检索精度。
- 序列长度优化:适当增加文档序列长度(如 256 或 512)能有效捕捉更长上下文的语义信息。
- 架构选择:使用 [CLS] 标记作为嵌入向量在某些场景下优于平均池化,且需根据具体任务调整。
遵循上述最佳实践,开发者可以构建出在搜索和 RAG 应用中表现卓越的嵌入模型。
相关免费在线工具
- 加密/解密文本
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
- RSA密钥对生成器
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
- Mermaid 预览与可视化编辑
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
- curl 转代码
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
- Base64 字符串编码/解码
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
- Base64 文件转换器
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online