当前位置:   article > 正文

语义相似度模型SBERT ——一个挛生网络的优美范例_huggingfaceembeddings 判断两个文本相似度

huggingfaceembeddings 判断两个文本相似度

论文地址:https://arxiv.org/abs/1908.10084
论文中文翻译:https://www.cnblogs.com/gczr/p/12874409.html
源码下载:https://github.com/UKPLab/sentence-transformers
相关网站:https://www.sbert.net/

“论文中文翻译”已相当清楚,故本篇不再翻译,只简单介绍SBERT的原理,以及训练和使用中文相似度模型的方法和效果。

原理

挛生网络Siamese network(后简称SBERT),其中Siamese意为“连体人”,即两人共用部分器官。SBERT模型的子网络都使用BERT模型,且两个BERT模型共享参数。当对比A,B两个句子相似度时,它们分别输入BERT网络,输出是两组表征句子的向量,然后计算二者的相似度;利用该原理还可以使用向量聚类,实现无监督学习任务。

挛生网络有很多应用,比如使用图片搜索时,输入照片将其转换成一组向量,和库中的其它图片对比,找到相似度最高(距离最近)的图片;在问答场景中,找到与用户输入文字最相近的标准问题,然后给出相应解答;对各种文本标准化等等。

衡量语义相似度是自然语言处理中的一个重要应用,BERT源码中并未给出相应例程(run_glue.py只是在其示例框架内的简单示例),真实场景使用时需要做大量修改;而SBERT提供了现成的方法解决了相似度问题,并在速度上更有优势,直接使用更方便。

SBERT对Pytorch进行了封装,简单使用该工具时,不仅不需要了解太多BERT API的细节, Pytorch相关方法也不多,下面来看看其具体用法。

配置环境

需要注意的是机器需要能正常配置BERT运行环境,如GPU+CUDA+Pytorch+Transformer匹配版本。

$ pip install sentence_transformers

下载源码

$ git clone https://github.com/UKPLab/sentence-transformers.git

模型预测

在未进行调优(fine-tune)前,使用预训练的通用中文BERT模型也可以达到一定效果,下例是从几个选项中找到与目标最相近的字符串。

  1. from sentence_transformers import SentenceTransformer
  2. import scipy.spatial
  3. embedder = SentenceTransformer('bert-base-chinese')
  4. corpus = ['这是一支铅笔',
  5. '关节置换术',
  6. '我爱北京天安门',
  7. ]
  8. corpus_embeddings = embedder.encode(corpus)
  9. # 待查询的句子
  10. queries = ['心脏手术','中国首都在哪里']
  11. query_embeddings = embedder.encode(queries)
  12. # 对于每个句子,使用余弦相似度查询最接近的n个句子
  13. closest_n = 2
  14. for query, query_embedding in zip(queries, query_embeddings):
  15. distances = scipy.spatial.distance.cdist([query_embedding], corpus_embeddings, "cosine")[0]
  16. # 按照距离逆序
  17. results = zip(range(len(distances)), distances)
  18. results = sorted(results, key=lambda x: x[1])
  19. print("======================")
  20. print("Query:", query)
  21. print("Result:Top 5 most similar sentences in corpus:")
  22. for idx, distance in results[0:closest_n]:
  23. print(corpus[idx].strip(), "(Score: %.4f)" % (1-distance))

训练中文模型

模型训练方法

训练原理:https://www.sbert.net/docs/training/overview.html
训练示例说明:https://www.sbert.net/examples/training/sts/README.html
训练示例代码:examples/training/sts/training_stsbenchmark.py

训练中文模型

把示例中的bert-base-cased换成bert-base-chinese,即可下载和使用中文模型。需要注意的是:中文和英文词库不同,不能将中文模型用于英文数据训练。

下载中文训练数据

下载信贷相关数据,csv数据7M多,约10W条训练数据,可在下例中使用

  1. $ git clone https://github.com/lixuanhng/NLP_related_projects.git
  2. $ ls NLP_related_projects/BERT/Bert_sim/data

代码

  1. from torch.utils.data import DataLoader
  2. import math
  3. from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, util
  4. from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
  5. from sentence_transformers.readers import InputExample
  6. import logging
  7. from datetime import datetime
  8. import sys
  9. import os
  10. import pandas as pd
  11. model_name = 'bert-base-chinese'
  12. train_batch_size = 16
  13. num_epochs = 4
  14. model_save_path = 'test_output'
  15. logging.basicConfig(format='%(asctime)s - %(message)s',
  16. datefmt='%Y-%m-%d %H:%M:%S',
  17. level=logging.INFO,
  18. handlers=[LoggingHandler()])
  19. # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
  20. word_embedding_model = models.Transformer(model_name)
  21. # Apply mean pooling to get one fixed sized sentence vector
  22. pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
  23. pooling_mode_mean_tokens=True,
  24. pooling_mode_cls_token=False,
  25. pooling_mode_max_tokens=False)
  26. model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
  27. train_samples = []
  28. dev_samples = []
  29. test_samples = []
  30. def load(path):
  31. df = pd.read_csv(path)
  32. samples = []
  33. for idx,item in df.iterrows():
  34. samples.append(InputExample(texts=[item['sentence1'], item['sentence2']], label=float(item['label'])))
  35. return samples
  36. train_samples = load('/workspace/exports/git/NLP_related_projects/BERT/Bert_sim/data/train.csv')
  37. test_samples = load('/workspace/exports/git/NLP_related_projects/BERT/Bert_sim/data/test.csv')
  38. dev_samples = load('/workspace/exports/git/NLP_related_projects/BERT/Bert_sim/data/dev.csv')
  39. train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
  40. train_loss = losses.CosineSimilarityLoss(model=model)
  41. evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')
  42. warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) #10% of train data for warm-up
  43. # Train the model
  44. model.fit(train_objectives=[(train_dataloader, train_loss)],
  45. evaluator=evaluator,
  46. epochs=num_epochs,
  47. evaluation_steps=1000,
  48. warmup_steps=warmup_steps,
  49. output_path=model_save_path)
  50. model = SentenceTransformer(model_save_path)
  51. test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
  52. test_evaluator(model, output_path=model_save_path)

测试结果

  • 直接使用预训练的英文模型,测试集正确率21%
  • 直接使用预训练的中文模型,测试集正确率30%
  • 使用1000个用例的训练集,4次迭代,测试集正确率51%
  • 使用10000个用例的训练集,4次迭代,测试集正确率68%
  • 使用100000个用例的训练集,4次迭代,测试集正确率71%

一些技巧

除了设置超参数以外,也可通过构造训练数据来优化SBERT网络,比如:构造正例时,把知识“喂”给模型,如将英文缩写与对应中文作为正例对训练模型;构造反例时用容易混淆的句子对训练模型(文字相似但含义不同的句子;之前预测出错的实例,分析其原因,从而构造反例;使用知识构造容易出错的句子对),以替代之前的随机抽取反例。

参考

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/264286
推荐阅读
相关标签
  

闽ICP备14008679号