当前位置: 首页 > article >正文

NLP: SBERT介绍及sentence-transformers库的使用

1. Sentence-BERT

  Sentence-BERT(简写SBERT)模型是BERT模型最有趣的变体之一,通过扩展预训练的BERT模型来获得固定长度的句子特征,主要用于句子对分类、计算两个句子之间的相似度任务。

1.1 计算句子特征

  SBERT模型同样是将句子标记送入预训练的BERT模型来获取句子特征的,但这里并不使用 R [ C L S ] R_{[CLS]} R[CLS]作为最终的句子特征。在SBERT中,通过汇聚所有标记的特征来计算整个句子的特征。具体的汇聚方法有两种:平均汇聚和最大汇聚。

  • 平均汇聚:使用平均汇聚来获取句子特征。这种方法得到的句子的特征将包含所有词语(Token)的意义。
  • 最大汇聚:使用最大汇聚来获取句子特征。这种方法得到的句子的特征将仅包含重要词语(Token)的意义。
    在这里插入图片描述

1.2 SBERT架构

  SBERT模型使用二元组网络架构来执行以一对句子作为输入的任务,并使用三元组网络架构来实现三元组损失函数。

1.2.1 使用二元组网络架构的SBERT模型

  SBERT通过二元组网络(两个共享同样权重的相同网络)架构对执行句子对任务的预训练的BERT模型进行微调。句子对任务具体包括以下两种:

  • 句子对分类任务: 判断句子对是否相似。相似则返回1,不相似则返回0。其SBERT模型架构为:
    在这里插入图片描述
  • 句子对回归任务:计算两个给定句子之间的语义相似度。其对应的SBERT架构为:在这里插入图片描述
1.2.2 使用三元组网络架构的SBERT模型

  三元组网络架构的SBERT模型的任务计算出一个特征,使锚定句和正向句之间的相似度高,锚定句和负向句之间的相似度低。其架构如下:
在这里插入图片描述

2. 计算文本相似度

2.1 bi-encoder VS cross-encoder

  bi-encoder和cross-encoder是语义匹配、文本相似度、信息检索场景下下常用的两种模型架构。这两者都基于深度学习模型(如BERT等)进行编码和比较文本之间的相似度,但它们在计算方式、效率和适用场景上有显著的区别。

2.1.1 bi-encoder

  bi-encoder是一种独立编码方式,即输入的两个文本会被分别编码为独立的向量,然后通过计算这两个向量的相似度来判断文本之间的关系。使用bi-encoder方式计算文本相似度的案例如下:

from sentence_transformers import SentenceTransformer
#加载预训练的sentence transformer模型
model = SentenceTransformer('all-MiniLM-L6-v2')
sentences=["这个商品挺好用的","这个商品一点也不好用"]
embeddings=model.encode(sentences)
similarity=model.similarity(embeddings[0],embeddings[1])
print(similarity) #0.5868
2.1.2 cross-encoder

  cross-encoder是一种联合编码方式,即将两个文本拼接在一起作为模型的输入,模型会通过对两个文本的联合表示来直接输出一个相似度分数。这种方式可以更好地捕捉两个文本之间的复杂交互信息,因此在诸如问答匹配、精确文本相似度计算等需要细粒度判断的任务上表现更好。具体使用方式如下:

from sentence_transformers.cross_encoder import CrossEncoder
model=CrossEncoder("cross-encoder/stsb-distilroberta-base")
query="这个产品挺好用的"
corpus=["这个产品很好",
        "这个产品的设计有很大问题",
        "这个产品不好用"]
ranks=model.rank(query,corpus)
for rank in ranks:
    print(f"{rank['score']:.2f}\t{corpus[rank['corpus_id']]}")

3 微调SBERT

  接下来我们使用STSB数据集对SBERT模型进行微调。具体代码如下

from datasets import load_dataset
from sentence_transformers import losses
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerTrainer,
)
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator,SimilarityFunction
from datasets import load_dataset

train = load_dataset("sentence-transformers/stsb",split='train')
dev = load_dataset("sentence-transformers/stsb",split='validation')
test= load_dataset("sentence-transformers/stsb",split='test')

model=SentenceTransformer('FacebookAI/xlm-roberta-base')

loss=losses.CoSENTLoss(model=model)

args=SentenceTransformerTrainingArguments(output_dir='models/model1',
                                          num_train_epochs=1,
                                          per_device_train_batch_size=16,
                                          per_device_eval_batch_size=16,
                                          warmup_ratio=0.1,
                                          eval_strategy='steps',
                                          eval_steps=100,
                                          save_strategy='steps',
                                          save_total_limit=2,
                                          bf16=False,)

dev_evaluator=EmbeddingSimilarityEvaluator(
    sentences1=dev['sentence1'],
    sentences2=dev['sentence2'],
    scores=dev['score'],
    main_similarity=SimilarityFunction.COSINE,
    name='dev-evaluator')

dev_evaluator(model)

trainer=SentenceTransformerTrainer(model=model,
                                   args=args,
                                   train_dataset=train,
                                   eval_dataset=dev,
                                   loss=loss,
                                   evaluator=dev_evaluator)   
trainer.train()                        

test_evaluator=EmbeddingSimilarityEvaluator(
    sentences1=test['sentence1'],
    sentences2=test['sentence2'],
    scores=test['score'],
    main_similarity=SimilarityFunction.COSINE,
    name='test-evaluator')
test_evaluator(model)
model.save_pretrained('models/model1')

关于上述代码,需要说明以下几点:

  • 训练和评估SBERT的数据类型必须是datasets.Datasetdatasets.DatasetDict
  • 数据集的格式必须和损失函数、评估器相匹配。如果损失函数需要标签字段,那么数据集必须有“label”或“score”字段;其他名称非“label”或“score”的字段将自动归属于Inputs字段。所以在进行后续步骤时,必须将数据集中的无法标签删除,同时要保证数据集中的字段顺序与对应损失函数中要求的顺序一致。
  • 需要根据具体的任务以及数据集的形式选择合适的损失函数,没有哪种损失函数可以解决所有的问题。SBERT提供的损失函数列表如下:
    https://www.sbert.net/docs/sentence_transformer/loss_overview.html
  • 微调后的模型可以和其他预训练的模型一样使用,比如计算文本相似度,这里不再赘述。

参考资料

  1. BERT基础教程: Transformer大模型实战
  2. https://baijiahao.baidu.com/s?id=1801193891938395467
  3. https://www.sbert.net

http://www.kler.cn/news/340451.html

相关文章:

  • Spring Boot大学生就业招聘系统的开发策略
  • 基于SSM的家庭理财系统的设计与实现
  • 掌握 ASP.NET Web 开发:从基础到身份验证
  • 【Java 并发编程】解决多线程中数据错乱问题
  • 前端vue-安装pinia,它和vuex的区别
  • Vue中watch监听属性的一些应用总结
  • 微信小程序启动不起来,报错凡是以~/包名/*.js路径的文件,都找不到,试过网上一切方法,最终居然这么解决的,【避坑】命运的齿轮开始转动
  • 【Linux】man手册安装使用
  • 以一个B站必剪应用Bug过一下CVSS 4.0评分
  • TCP网络通信——多线程
  • 重学SpringBoot3-集成Redis(二)之注解驱动
  • 408算法题leetcode--第26天
  • Kubernetes--深入理解Pod资源管理
  • (Linux驱动学习 - 9).设备树下platform的LED驱动
  • 如何通过jupyter调用服务器端的GPU资源
  • 微信小程序流量主
  • 字节青训营-技术训练营报名啦!!!
  • 外包干了6天,技术明显退步。。。
  • 项目-坦克大战学习-爆炸特效消除
  • 九大排序之交换排序