Sentence-BERT 句嵌入模型介绍与实践
Sentence-BERT 是一种基于 BERT 的句嵌入模型,通过孪生网络结构实现高效的语义相似度计算。 Embedding 技术原理,对比了传统 BERT 与 Sentence-BERT 的差异,并提供了基于 PyTorch 和 HuggingFace 库的训练与检索实践方案。内容包括环境配置、数据预处理、模型搭建、损失函数选择及余弦相似度检索流程,旨在帮助开发者在文本匹配、语义搜索等场景中应用句嵌入技术。

Sentence-BERT 是一种基于 BERT 的句嵌入模型,通过孪生网络结构实现高效的语义相似度计算。 Embedding 技术原理,对比了传统 BERT 与 Sentence-BERT 的差异,并提供了基于 PyTorch 和 HuggingFace 库的训练与检索实践方案。内容包括环境配置、数据预处理、模型搭建、损失函数选择及余弦相似度检索流程,旨在帮助开发者在文本匹配、语义搜索等场景中应用句嵌入技术。

Sentence-BERT 是一种句嵌入表征模型,常用于文本语义相似度的匹配。本文对 Sentence-BERT 做理论介绍,并结合领域文本数据进行实践,训练句嵌入实现语义检索。
Embedding 是将某个实体转换为由数字序列形成的向量,使得计算机能够对该实体进行理解,从而完成各种算法任务。Embedding 技术广泛应用于自然语言处理、图像识别、推荐系统等场景。在 NLP 和大模型领域,文本经过分词编码和 Embedding 处理成数值信息灌入语言模型,通过海量语料的训练使得模型具备类似人类一样的语义理解和生成能力。
自然语言通过 Embedding 进行语义表征。对文本中的每个分词进行 Embedding 称为词嵌入,对一整句或者一段文本进行 Embedding 称为句嵌入。句嵌入在文本推荐、查询改写、智能问答、知识库检索等领域有广泛的应用。这些嵌入向量作为模型的中间产物,如果对其本身进行向量聚类和相似度匹配,也可以挖掘出语义的关系远近。一般的,通过对文本做 Embedding 向量化配合余弦相似度来进行语义比对,余弦相似度越大,实体在语义空间的夹角越小,语义越相似。
在进行 Sentence-BERT 的实践之前,需要准备好相应的开发环境。推荐使用 Python 3.8+ 版本,并安装以下核心依赖库:
transformers: HuggingFace 提供的预训练模型库torch: PyTorch 深度学习框架numpy: 数值计算库scikit-learn: 用于评估指标计算可以通过 pip 命令安装:
pip install transformers torch numpy scikit-learn
此外,确保 GPU 驱动已正确安装以加速模型训练。若使用 CPU 运行,需调整 batch_size 以避免内存溢出。
Sentence-BERT 是一种句嵌入模型,输入一段文本,输出整段文本的向量表征。在 HuggingFace 仓库中下载预训练模型 sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 进行调用,快速开始使用 Sentence-BERT 输出句向量。
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
tokenizer = AutoTokenizer.from_pretrained('./sentence-transformers')
model = AutoModel.from_pretrained('./sentence-transformers')
sentences = ['中午我想吃清蒸鲈鱼', '天气预报说明天下雨', '食堂的餐饭不好吃', '我做了红烧鱼作为中午的饭菜']
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = model(**encoded_input)
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']).cpu().numpy()
以上代码对四句话进行句嵌入表征,分别是'中午我想吃清蒸鲈鱼','天气预报说明天下雨', '食堂的餐饭不好吃','我做了红烧鱼作为中午的饭菜'。Sentence-BERT 通过对最后一层输出的所有非 Padding 位置的词向量做均值池化获得句子向量,每个句子表征为 384 维。
print(sentence_embeddings.shape)
# Output: (4, 384)
print(sentence_embeddings)
进一步用余弦相似度计算第一句和其他三个句子之间的语义相关程度。
def compute_sim_score(v1, v2):
return v1.dot(v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
score_0_1 = compute_sim_score(sentence_embeddings[0], sentence_embeddings[1])
score_0_2 = compute_sim_score(sentence_embeddings[0], sentence_embeddings[2])
score_0_3 = compute_sim_score(sentence_embeddings[0], sentence_embeddings[3])
print(f"Score 0-1: {score_0_1}")
print(f"Score 0-2: {score_0_2}")
print(f"Score 0-3: {score_0_3}")
汇总相似度得分表格如下,'中午我想吃清蒸鲈鱼'和'我做了红烧鱼作为中午的饭菜'的语义相关程度最高,该结论也符合人类的感知,说明 Sentence-BERT 句嵌入具有一定的有效性。
| 目标句子 | 候选句子 | 相似度 |
|---|---|---|
| 中午我想吃清蒸鲈鱼 | 天气预报说明天下雨 | 0.3363 |
| 中午我想吃清蒸鲈鱼 | 食堂的餐饭不好吃 | 0.3904 |
| 中午我想吃清蒸鲈鱼 | 我做了红烧鱼作为中午的饭菜 | 0.7262 |
Sentence-BERT 是 2019 年由论文《Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks》提出的一种有监督的句嵌入算法,它本质上是基于BERT预训练模型的输出作为句嵌入。额外的,它引入孪生网络的思想将一对句子的表征和人工标注的相似度做比对,从而实现对 BERT 的微调,使得 BERT 输出的句嵌入更加契合语义匹配的场景。
给定一段文本输入给 BERT,BERT 输出为 [batch_size, seq_len, emb_size] 的矩阵,它由每个位置的 token Embedding 构成。在文本分类等下游任务中,一般将 [CLS] 位置的 Embedding 或者所有 token Embedding 的均值池化作为整段文本的信息表征。这种表征方式无法适配语义检索场景,因为 BERT 的预训练是基于自然文本,侧重于学习词和句子的上下文关联,而上下文关联并不代表语义相似。
另一方面,BERT 自身可以完成文本匹配的下游任务,输入一对句子,拼接为 [CLS]+Sentence 1+[SEP]+Sentence 2+[SEP],做二分类预测两个句子语义是否相近。这种方式端到端地预测两个句子的匹配程度,但是每次都需要将目标句子和候选所有句子输入到 BERT 中进行分类预测,推理成本极高,不适合大规模的语义检索场景。
综上所述作者提出 Sentence-BERT(SBERT),通过孪生 BERT 网络以及人工标注的语义相似三元组数据对 BERT 做微调,在部署阶段推理出句嵌入,后续再使用余弦相似度进行语义搜索。
Sentence-BERT 分为训练和预测两个阶段,训练阶段基于标注的三元组句子对有监督微调 BERT,而预测阶段直接基于微调后的 BERT 生成句嵌入。训练阶段的网络结构如下:
作者参考了孪生网络的思想引入了两个参数完全一样的 BERT,令 BERT 输出的 embedding 维度为 d,将句子 A 和句子 B 分别输入其中,通过池化得到两个句嵌入 u,v。进一步作者对 u,v 向量做逐位相减再取绝对值,最终生成三组向量将它们拼接为长度 3d,输入全连接层做二分类Softmax 交叉熵预测。
对于池化操作的选择作者分别尝试了均值池化,最大池化,和直接使用 [CLS] 位置三种方式,在实验中均值池化效果最好。在最后一层全连接之前的向量拼接方式选取上,作者尝试了 u,v 向量的多种 element-wise 组合方法,包括逐位置相乘,逐位相减等,在实验中 u,v 和两者逐位相减拼接的效果最好。
在损失函数的构造上,根据标注数据的类型不同,作者分别采用了二分类交叉熵,MSE 回归指标,以及对比学习中的 Triplet loss 三元组损失,不论采用哪种损失最终的目的都是使得句嵌入的距离能够和人工标注的语义相似关系对齐。样本类型包括:
Sentence-BERT 的预测阶段直接使用微调后的 BERT 根据相同的池化方式输出句嵌入即可,相当于把训练阶段的孪生网络中的 BERT 单独摘出来,Sentence-BERT 只负责输出句嵌入 u,v,后续的相似度检索交给下游的余弦相似度任务单独实现。
Sentence-BERT 的理论涉及到孪生神经网络和对比学习这两个概念,具体而言,Sentence-BERT 采用了孪生神经网络的结构,在损失优化环节借鉴了对比学习的策略。孪生网络和对比学习这两种方法常常一起使用。
孪生神经网络 (Siamese Networks),它由两个权重共享的任意神经网络拼接而成,两个样本分别输入,输出其嵌入高维空间的表征,从而比较两个样本的相似程度。孪生神经网络于 1994 年被首次提出,用于验证手写平板电脑签名,一对孪生网络分别提取两个签名的特征表示,从而量化两个特征向量之间的距离,若手写签名与以前存储签名间的距离小于预设阈值则被接受,否则将被视为伪造签名。同理,孪生神经网络也常用于人脸比对。
对比学习 (Contrastive Learning),对比学习的目标是学习一个编码器,使得相似实体的编码在特征空间中尽可能接近,不相似的实体编码结果在特征空间中尽可能的远。对比学习基于代理任务预先生成相似样本和不相似样本,从而提供了一个监督信号结合目标函数去训练模型。其中代理任务通常是人为设定的一些相似规则,数据增强是代理任务的实现常见手段,目标函数一般是基于向量距离的计算,比如 Triplet loss 中采用基准分别和正例负例的距离之差的最大值作为优化目标,公式中 Sa,Sp,Sn 分别代表基准向量,正例向量,负例向量,||…||代表两个向量的距离度量方式。
本节将基于预训练 bert-base-chinese 模型,在领域文本上从头训练一个 Sentence-BERT 模型,完成训练和预测两个流程,并且基于预测的向量结果完成文本相似检索。
采用公开的 ATEC 文本匹配数据集,内容包含 10 万多条客服问句匹配样本,格式为三元组形式(问句 1,问句 2,是否相似),数据样例如下:
打不开花呗 为什么花呗打不开 1
花呗收钱就是用支付宝帐号收嘛 我用手机花呗收钱 0
花呗买东西,商家不发货怎么退款 花呗已经分期的商品 退款怎么办 0
Sentence-BERT 网络结构比较简单,只需要将问句 1 和问句 2 分别经过 BERT 的分词编码,再输入给 BERT 拿到表征,拼接后输入全连接做二分类预测即可,BERT 表征的方式本例采用 [CLS] 位置,也可以使用其他方式,比如非 Padding 位置的均值池化等。
在数据处理环节,需要将单条样本的问句 1,问句 2 上下堆叠成两条样本,目的是统一输入同一个 BERT 网络。如果分开输入,在 train 状态下由于有 Dropout 的存在,就算是相同参数的 BERT 输出也不一样,此时就不满足孪生神经网络的要求。因此需要将原始一对问句进行堆叠,比如一个批次处理 32 条原始三元组样本,则实际灌入模型的是 64 条二元组,堆叠函数如下:
def collate_fn(data):
s, labels = [], [] # 二元组
for d in data: # 三元组 (s1, s2, label)
s.append(d[0])
s.append(d[1])
labels.append(d[2]) # y 值也需要复制一次
labels.append(d[2])
s_token = TOKENIZER.batch_encode_plus(s, truncation=True, max_length=PRE_TRAIN_CONFIG.max_position_embeddings,
return_tensors="pt", padding=True)
labels = torch.LongTensor(labels)
return s_token, labels
在网络层只有两个模块预训练 BERT 和 Linear,BERT 拿到最后一层的 [CLS] 位置表征,由于前面有堆叠操作,此处再取出所有偶数位还原出所有句子 1,取出所有奇数位还原出句子 2,两者拼接上相减绝对值之后一齐输入给 Linear。
class SentenceBert(nn.Module):
def __init__(self):
super(SentenceBert, self).__init__()
self.pre_train = PRE_TRAIN
self.linear = nn.Linear(PRE_TRAIN_CONFIG.hidden_size * 3, 2)
nn.init.xavier_normal_(self.linear.weight.data)
def get_cosine_score(self, s1, s2):
s1_norm = s1 / torch.norm(s1, dim=1, keepdim=True)
s2_norm = s2 / torch.norm(s2, dim=1, keepdim=True)
cosine_score = (s1_norm * s2_norm).sum(dim=1)
return cosine_score
def forward(self, s):
s_emb = self.pre_train(**s)['last_hidden_state'][:, 0, :]
s1_emb, s2_emb = s_emb[::2], s_emb[1::2]
cosine_score = self.get_cosine_score(s1_emb, s2_emb)
concat = torch.concat([s1_emb, s2_emb, torch.abs(s1_emb - s2_emb)], dim=1)
output = self.linear(concat)
return output, cosine_score
损失函数采用交叉熵,前向传播部分同时输出该批次样本的每对句子的余弦相似度 cosine_score,它和标签 y 值进行皮尔逊相关系数计算,定义 10 次验证集皮尔逊相关系数不上升作为早停条件。皮尔逊相关系数越大越好,说明余弦相似度和真实的是否相似的情况趋势越趋同。
model = SentenceBert().to(DEVICE)
criterion = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00003)
for step, (s, labels) in enumerate(train_loader):
s, labels = s.to(DEVICE), labels.to(DEVICE)[::2] # labels 需要折叠,取偶数位即可
model.train()
output, cosine_score = model(s)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
# 计算皮尔逊相关系数作为早停条件
corrcoef = compute_corrcoef(cosine_score.detach().cpu().numpy(), labels.detach().cpu().numpy())
print("epoch: {}, step: {}, loss: {}, corrcoef:{}".format(epoch + 1, step, loss.item(), corrcoef))
if step % 200 == 0 or step == len(train_loader):
# 验证集早停逻辑
loss_val, corrcoef_val = eval_metrics(model, val_loader)
...
在测试阶段,依旧采用测试集预测每对样本的余弦相似度和真实 y 值的皮尔逊相关系数作为评价指标,最终相关系数为0.4592。
# TODO 测试
model2 = SentenceBert().to(DEVICE)
model2.load_state_dict(torch.load("./model/sbert_{}.bin".format(data)))
loss_test, corrcoef_test = eval_metrics(model2, test_loader)
# 0.41255660838512037 0.4555726951427768
print(loss_test, corrcoef_test)
在实际训练中,为了提升模型性能,建议关注以下超参数:
在预测流程中,只需要将微调之后 BERT 从 Sentence-BERT 网络中摘出来即可,后续的向量预测都仅仅需要该 BERT 模型,首先只对整个网络中的 pre_train BERT 进行保存。
s_bert = model2.pre_train
torch.save(s_bert.state_dict(), "./model/sbert_ATEC/pytorch_model.bin")
预测的时候通过 HuggingFace 的 BERT 模型 API 进行导入。
from transformers import BertModel, BertTokenizer, BertConfig
PRE_TRAIN_PATH = "model/sbert_ATEC"
TOKENIZER = BertTokenizer.from_pretrained(PRE_TRAIN_PATH)
PRE_TRAIN_CONFIG = BertConfig.from_pretrained(PRE_TRAIN_PATH)
PRE_TRAIN = BertModel.from_pretrained(PRE_TRAIN_PATH)
将样本中所有句子 1 和句子 2 全部按照批次灌入 BERT 中进行 [CLS] 位置的向量预测,代码如下:
cut = list(range(0, len(total), batch_size))
for i in range(len(cut)):
start, end = cut[i], len(total) if i == len(cut) - 1 else cut[i + 1]
batch_text = total[start:end]
text_token = TOKENIZER.batch_encode_plus(batch_text, truncation=True, padding=True,
max_length=PRE_TRAIN_CONFIG.max_position_embeddings,
return_tensors="pt")
embs = PRE_TRAIN(**text_token)[0][:, 0, :]
embs_norm = (embs / torch.norm(embs, dim=1, keepdim=True)).detach().cpu().numpy().tolist()
total_emb.extend(embs_norm)
pickle.dump((total, total_emb), open("./model/sbert_ATEC/emb.bin", "wb"))
每个句子都会生成一个 768 维度的向量,预览其中 1 条如下:
句子:蚂蚁借呗用了了多久能恢复
向量:[-0.017775828018784523, 0.06854370981454849, -0.00908558805286884, 0.007142649497836828,...]
最终我们想输入任意文本,在候选的所有句子中找到和它最相似的文本,本例采用 Numpy 直接计算余弦相似度,整个过程包含对输入文本的分词编码,输入文本的 BERT 向量输出,输入文本向量和所有候选向量比对三个过程,我们取最相似的 top3 句子以及相似度得分。
def search_top_n(input_text, candidate_text, candidate_emb, top_n=3):
text_token = TOKENIZER.batch_encode_plus([input_text], truncation=True, padding=True,
max_length=PRE_TRAIN_CONFIG.max_position_embeddings,
return_tensors="pt")
embs = PRE_TRAIN(**text_token)[0][:, 0, :].detach().cpu().numpy()
# 输入文本向量标准化
embs = embs / np.linalg.norm(embs, axis=1)
# 计算余弦相似度
scores = np.dot(embs, np.array(candidate_emb).T)
scores[np.isneginf(scores)] = 0
top_score = np.sort(scores, axis=1)[:, -3:]
top_index = np.argsort(scores, axis=1)[:, -3:]
res = []
for s, i in zip(top_score, top_index):
one = []
for n in range(top_n):
one.append({"text": candidate_text[i[n]], "score": s[n]})
res.append(one)
return res
测试 1:输入文本为'如何关闭支付宝免密支付',运行输出如下,候选的三个句子和输入语义完全相同。
>>> input_text = "没网的时候支付宝能够支付吗"
>>> search_top_n(input_text, total, total_emb, top_n=3)
[[{'text': '怎样去消花呗的免密支付', 'score': 0.9739149930056332},
{'text': '怎么关闭花呗免密支付', 'score': 0.9834292467296013},
{'text': '怎样关闭花呗的免密支付', 'score': 0.9878402150756829}]]
测试 2:输入文本为'没网的时候支付宝能够支付吗',运行输出如下,候选的三个句子和输入语义相似,但是主体存在略微差异和不明确。
>>> input_text = "没网的时候支付宝能够支付吗"
>>> search_top_n(input_text, total, total_emb, top_n=3)
[[{'text': '手机没网,花呗会自动扣款吗', 'score': 0.874686905948586},
{'text': '不用手机支付宝,花呗能自动还款吧', 'score': 0.8775933520615644},
{'text': '我没有手机支付宝 是不是就没办法给花呗还款了', 'score': 0.8905988043804666}]]
测试 3:输入文本为'支付宝能炒股吗',运行输出如下,效果可以语义基本相同。
>>> input_text = "支付宝能炒股吗"
>>> search_top_n(input_text, total, total_emb, top_n=3)
[[{'text': '借呗可以用来买股票吗', 'score': 0.900523830153816},
{'text': '蚂蚁借呗能拿来买股票吗', 'score': 0.9067184565541515},
{'text': '借呗可以炒股吗', 'score': 0.9342895273812792}]]
从实践结果来看,Sentence-BERT 输出的句嵌入能够很好的完成文本向量化和文本相似匹配任务。

微信公众号「极客日志」,在微信中扫描左侧二维码关注。展示文案:极客日志 zeeklog
使用加密算法(如AES、TripleDES、Rabbit或RC4)加密和解密文本明文。 在线工具,加密/解密文本在线工具,online
生成新的随机RSA私钥和公钥pem证书。 在线工具,RSA密钥对生成器在线工具,online
基于 Mermaid.js 实时预览流程图、时序图等图表,支持源码编辑与即时渲染。 在线工具,Mermaid 预览与可视化编辑在线工具,online
解析常见 curl 参数并生成 fetch、axios、PHP curl 或 Python requests 示例代码。 在线工具,curl 转代码在线工具,online
将字符串编码和解码为其 Base64 格式表示形式即可。 在线工具,Base64 字符串编码/解码在线工具,online
将字符串、文件或图像转换为其 Base64 表示形式。 在线工具,Base64 文件转换器在线工具,online