CoSENT 句嵌入模型理论介绍与语义检索实践
引言
CoSENT(Cosine Sentence)是一种基于余弦相似度排序损失的句嵌入模型。相较于传统的 Sentence-BERT,CoSENT 在训练目标上进行了改进,使其更契合文本匹配的实际应用场景。本文将对 CoSENT 的理论基础进行简述,并结合领域文本训练句嵌入以实现语义检索,最终对比 CoSENT 和 Sentence-BERT 的效果差异。
有监督句嵌入模型概述
句嵌入是将句子表征为向量的过程,基于句向量可以进一步完成文本匹配、文本聚类、语义搜索等下游场景任务。句嵌入主要分为无监督和监督两大类。
Sentence-BERT 是一种典型的有监督句嵌入方案,它通过人工标注的三元组数据(句子 1,句子 2,是否相似),微调 BERT 使得相似语义的文本表征距离更小。而无监督的方案不需要人工标注,它依据文本的上下文关系来构造出预测任务,句嵌入是该任务的中间产物。这类方法包括 Word2Vec 词嵌入池化、Doc2Vec、Sentence2Vec、Skip-Thought Vectors 等。
本篇重点介绍另一种有监督句嵌入模型CoSENT。它将cosine 余弦相似度的排序损失引入到 Sentence-BERT 的训练环节,使得训练过程更加契合应用场景,同时加快模型在训练阶段的收敛。在众多数据集上,CoSENT 的表现优于 Sentence-BERT。
快速开始:使用 CoSENT 生成句嵌入
模型加载
在 HuggingFace 模型仓库中下载 shibing624/text2vec-base-chinese 预训练模型。该模型以 macbert 作为基座,通过 CoSENT 损失函数策略微调得到,可以实现对输入文本做 Embedding 表征。
CoSENT 也是 BERT 模型微调的结果,因此使用 BERT 的模型 API 导入 CoSENT 模型和词表。
from transformers import BertTokenizer, BertModel
import torch
embedding_model_name = "./text2vec-base-chinese"
embedding_model_length = 512
tokenizer = BertTokenizer.from_pretrained(embedding_model_name)
model = BertModel.from_pretrained(embedding_model_name)
预处理与编码
输入样例句子,对它们进行分词编码预处理。注意设置 padding=True 和 truncation=True 以适应批量处理。
sentences = [
'我不知道过年火车票能不能抢到',
'过年假期你准备去哪里玩',
'我准备春节请假两天提前回家,但是好没有抢到票',
'这个假期太短了,我作业还没有做完'
]
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
均值池化
输出层需要使用 BERT 最后一层 block 的非 Padding 位置所有词 Embedding 的均值池化作为句嵌入。定义 mean_pooling 函数来实现该操作。
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
推理与验证
最后我们使用 CoSENT 对句子进行推理,生成 [batch_size, 768] 的矩阵,代表每个句子表征为 768 维的向量。
with torch.no_grad():
model_output = model(**encoded_input)
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']).cpu().numpy()
print(sentence_embeddings.shape)
我们以第一句'我不知道过年火车票能不能抢到'为目标,分别计算它和其他三个句子的余弦相似度,来初步验证 CoSENT 做文本匹配的有效性。
import numpy as np
def compute_sim_score(v1, v2):
return v1.dot(v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
score_1 = compute_sim_score(sentence_embeddings[0], sentence_embeddings[1])
score_2 = compute_sim_score(sentence_embeddings[0], sentence_embeddings[2])
score_3 = compute_sim_score(sentence_embeddings[0], sentence_embeddings[3])
print(f"与句 1 相似度:{score_1:.4f}")
print(f"与句 2 相似度:{score_2:.4f}")
print(f"与句 3 相似度:{score_3:.4f}")
结果汇总如下,句嵌入的相似结果和实际的语义情况相符,说明 CoSENT 和 text2vec 预训练模型有一定的效果。
| 目标句子 | 候选句子 | 余弦相似度 |
|---|
| 我不知道过年火车票能不能抢到 | 过年假期你准备去哪里玩 | 0.4940 |
| 我不知道过年火车票能不能抢到 | 我准备春节请假两天提前回家,但是好没有抢到票 | 0.6530 |
| 我不知道过年火车票能不能抢到 | 这个假期太短了,我作业还没有做完 | 0.4531 |
CoSENT 的目标函数详解
CoSENT 是 Sentence-BERT 的改进版本,两者的模型基座相同,但在损失函数部分,CoSENT 使用余弦相似度排序损失,替换了 Sentence-BERT 的分类交叉熵损失。
训练与预测的不一致性
一般有监督句嵌入的样本为三元组(句子 1,句子 2,是否相似)。Sentence-BERT 的训练过程以是否相似作为分类任务来微调 BERT,而预测阶段将 BERT 单独从模型中剥离出来,拿到 BERT 的表征作为句嵌入,在应用层使用余弦相似度作为文本匹配的依据。很明显,Sentence-BERT 的训练阶段和预测阶段目标不一致,可能出现训练中交叉熵损失还在下降,但是实际余弦相似度却没有提升的情况。
余弦相似度的局限性
由于应用层采用余弦相似度,作者想让训练和预测统一。但是 Sentence-BERT 直接使用余弦相似度或者它的变体作为损失函数效果并不好,原因是余弦相似度的值映射到是否相似的标签上不合适,导致样本矛盾,模型难以收敛。
在样本中出现的正样本标记为 1,它们都是语义相同的样本;负样本标记为 0,它们都是语义有差异,但是字面很相似的样本。举例如下:
正样本:什么时候可以降低花呗额度 花呗怎么降低额度 1
负样本:花呗里面没有看到 花呗也没有看到钱 0
其中负样本的句子 1 和句子 2 存在字和词的高度重叠,这种样本对模型来说属于'有一定难度的样本'。它们确实语义不同,但是由于本身重叠很高,余弦相似度自然也比较高(比如等于 0.7)。因此直接计算它们的余弦相似度并且往标签 0 去学对模型来说过头了,负样本还远达不到 0 的水平。真正 0 水平的样本应该是句子 1 和句子 2 完全驴头不对马嘴。
给到样本聚焦在一个比较小和比较难的样本空间,隔绝了大量的容易样本。本质上的原因是语义不相似,不代表余弦相似度就应该低。如果不改变这种小样本空间,相应的好坏分割点的阈值应该被提高,比如以 0.8 作为阈值,因为给到的不论正样本还是负样本大体都是相似的。
排序思想的引入
作者的创新点在于直接抛弃阈值,以排序的思想解决小样本空间问题。既然不能用绝对的是否来衡量,那么相似度的大小排序总应该还是存在的吧,即所有正样本对的余弦相似度应该尽可能的比所有负样本对的余弦相似度更高。
令一个批次下有三对样本,分别是正样本 V1,V2,负样本 V3,V4,V5,V6,<.,.>代表向量的余弦相似度,则期望<V1,V2>比其他两个都要大。列入笛卡尔积的所有正负样本的比较如下,其中?位置代表同类样本的比较,可以忽视,在损失函数中不起作用。
| 相似度对比 | <V1,V2> | <V3,V4> | V5,V6 |
|---|
| <V1,V2> | 等于 | 大于 | 大于 |
| <V3,V4> | 小于 | 等于 | ? |
| <V5,V6> | 小于 | ? | 等于 |
损失函数推导
为了能实现所有正样本都能比负样本余弦相似度更高的目的,作者提出 CoSENT 的目标函数如下:
$$ L = -\log \frac{\exp(s_{pos} / \tau)}{\sum_{i,j} \exp(s_{ij} / \tau)} $$
这个式子是一个 LogSumExp 形式。LogSumExp 可以看成是 max 的光滑近似。我们把 log 括号后面的 1 替换为 e 的 0 次幂,实际上就是某样本减去它自身,因此上式可以转化为包含所有正负例差值的组合。
公式右侧是一个笛卡尔积组合,罗列了所有负例和正例相减的情形。若所有正样本的余弦相似度都比负样本要高,则公式右侧的值为 0,此时损失近似为 0。而如果存在有正样本的余弦相似度比负样本要低,则取负样本和正样本得分差距最大的那个作为最终的损失。通过这种方式期望模型给任意正样本的余弦相似度得分都能够比负样本来的大,实现所有正样本都排在负样本前面的效果,这种方式绕过了阈值,直接从排序角度来对目标进行约束。
CoSENT 模型搭建和语义检索实践
本例参考前文相关技术文章,采用同样的数据集和模型基座 bert-base-chinese,实现最终模型效果的对比。
模型结构
为了使得在 train 状态下 BERT 的输出结果一致,将一对样本的句子 1 和句子 2 进行上下堆叠,在损失计算之前再分别取偶数位和奇数位分别拿到句子 1 和句子 2 的表征。
import torch.nn as nn
def get_cosine_score(s1: torch.Tensor, s2: torch.Tensor):
s1_norm = s1 / torch.norm(s1, dim=1, keepdim=True)
s2_norm = s2 / torch.norm(s2, dim=1, keepdim=True)
cosine_score = (s1_norm * s2_norm).sum(dim=1)
return cosine_score
class SentenceBert(nn.Module):
def __init__(self):
super(SentenceBert, self).__init__()
self.pre_train = PRE_TRAIN
self.linear = nn.Linear(PRE_TRAIN_CONFIG.hidden_size * 3, 2)
nn.init.xavier_normal_(self.linear.weight.data)
def forward(self, s):
s_emb = self.pre_train(**s)['last_hidden_state'][:, 0, :]
s1_emb, s2_emb = s_emb[::2], s_emb[1::2]
cosine_score = get_cosine_score(s1_emb, s2_emb)
return s1_emb, s2_emb, cosine_score
其中在前向传播中计算两向量的余弦相似度 cosine_score 用于和真实标签计算 Spearman 相关系数,而 Spearman 相关系数作为早停条件,如果连续 10 次验证集不上升则停止训练。
核心损失函数实现
CoSENT 的核心在于目标函数,它针对一个批次下的所有样本对,句子 1 和句子 2 的余弦相似度进行两两交叉组合相减,通过标签 y 挑选保留下所有负例 - 正例的情况,其他全部改为负无穷大,使得 e 的次幂接近为 0 对求和结果无效,在 logsumexp 中加入一项 0,作为负例 - 正例的天花板。
def cosent_loss(s1_emb, s2_emb, labels):
labels = (labels[:, None] < labels[None, :]).to(float)
cosine_score = get_cosine_score(s1_emb, s2_emb) * 20
cosine_diff = cosine_score[:, None] - cosine_score[None, :]
cosine_diff = (cosine_diff - (1 - labels) * 1e12).reshape(-1)
cosine_diff = torch.concat([torch.tensor([0.0]).to(DEVICE), cosine_diff], dim=0)
return torch.logsumexp(cosine_diff, dim=0)
模型训练过程将一个批次下所有句子 1 的表征,和句子 2 的表征传入 cosent_loss 即可进行损失迭代。
for step, (s, labels) in enumerate(train_loader):
s, labels = s.to(DEVICE), labels.to(DEVICE)[::2]
model.train()
optimizer.zero_grad()
s1_emb, s2_emb, cosine_score = model(s)
loss = cosent_loss(s1_emb, s2_emb, labels)
loss.backward()
optimizer.step()
训练评估与结果
训练集早停,以及测试集测评日志如下,在 ATEC 文本匹配数据集上测试集的 Spearman 相关系数有0.4973。
epoch: 6, step: 622, loss: 5.033726978888658, corrcoef:0.6161128974200909
epoch: 6, step: 623, loss: 5.7895638147161375, corrcoef:0.7360637834284756
100%|██████████| 313/313 [00:31<00:00, 9.78it/s]
[evaluation] loss: 6.5288932967316 corrcoef: 0.4973212744783885
本轮 Spearman 相关系数比之前最大 Spearman 相关系数下降:0.006715430800930733, 当前最大 Spearman 相关系数:0.5040367052793192
early stop...
[test] loss: 2113085378242445, corrcoef: 0.4973765080838994
笔者分别在蚂蚁金服 ATEC 和微众银行 BQ 两个问句数据上做了 Sentence-BERT 和 CoSENT 的测试,对比结果如下:
| 算法/数据集合 | ATEC 数据集 | BQ 数据集 |
|---|
| Sentence-BERT | 0.4592 | 0.7006 |
| CoSENT-BERT | 0.4974 | 0.7129 |
CoSENT 在两个数据上相比于 Sentence-BERT 都有明显的提升,说明 CoSENT 训练得到的句嵌入在文本匹配场景表现地更好,这种以余弦相似度的排序作为训练目标的策略更加有效。
工业界应用与优化建议
语义检索流程
在实际的语义检索系统中,通常采用以下流程:
- 索引构建:将文档库中的文本通过 CoSENT 模型编码为向量,存入向量数据库(如 Faiss, Milvus)。
- 查询编码:用户输入查询语句,同样通过 CoSENT 模型编码为向量。
- 相似度计算:计算查询向量与库中向量的余弦相似度。
- 召回与重排:根据相似度分数召回 Top-K 文档,可结合规则或重排序模型进一步优化。
超参数调优指南
- Batch Size:较大的 Batch Size 有助于估计更准确的余弦相似度分布,但受限于显存。建议从 32 或 64 开始尝试。
- Learning Rate:CoSENT 对学习率较为敏感,建议使用较小的学习率(如 1e-5 至 5e-5)配合 Warmup 策略。
- Temperature Scaling:损失函数中的温度系数(如代码中的 *20)影响梯度的平滑度,可根据具体任务调整。
常见问题排查
- 模型不收敛:检查数据标签是否正确,确保正负样本对区分明显。检查 Loss 函数实现是否与论文一致。
- 推理速度慢:CoSENT 基于 BERT,推理速度较慢。生产环境建议使用 ONNX 导出或量化加速。
- 长文本截断:BERT 类模型通常限制最大长度(如 512)。对于长文本,需考虑分段编码或摘要后再编码的策略。
总结
CoSENT 通过引入余弦相似度排序损失,解决了传统 Sentence-BERT 训练与预测目标不一致的问题。实验表明,在多个文本匹配数据集上,CoSENT 均取得了优于 Sentence-BERT 的效果。对于需要高精度语义检索的场景,CoSENT 是一个值得优先考虑的基线模型。