当前位置:   article > 正文

NLPCDA —— 基于SimBERT的相似文本生成

nlpcda

                            NLPCDA —— 基于SimBERT的相似文本生成

             感谢苏神开源的SimBERT,笔者先前简单尝试了SimBERT在相似文本生成的应用。同时结合nlpcda作者开源的代码,所以才有了博客中的demo:NLPCDA——中文数据增强工具。估计是标题不够高大上,或者大家不知道NLPCDA这个工具,阅读量不大。

            最近,苏神又开源了RoFormer-Sim模型(SimBERT的升级版,简称SimBERTv2),链接:SimBERTv2来了!融合检索和生成的RoFormer-Sim模型

1. SimBERT与SimBERTv2的核心区别

  • A. 训练细节
  1. SimBERT = BERT + UniLM + 对比学习
  2. SimBERTv2 = RoFormer + UniLM + 对比学习 + BART + 蒸馏
  1. SimBERT:疑问类型相似句
  2. SimBERTv2:疑问类型相似句 + 通用类型相似句
  • D. 生成能力
  1. SimBERT:无BART
  2. SimBERTv2:基于BART的思想,“输入带噪声的句子,输出原句子的一个相似句“
  • E. 蒸馏
蒸馏的目的:在SimBERTv2训练完之后,进一步通过蒸馏的方式把SimBERT的检索效果转移到SimBERTv2上去,从而使得SimBERTv2的检索效果基本持平甚至优于SimBERT。

2.  SimBERTv2的相似文本生成demo

代码其实还是和上篇博客中一样,只是模型不一样,并且需要更新bert4keras的版本,以及修改源码中的generator函数。一步一步说吧。

(1)python的demo

  1. from nlpcda import Simbert
  2. from time import time
  3. def test_sing(simbert, N):
  4. """
  5. 功能: 单元测试
  6. :param simbert:
  7. :return:
  8. """
  9. while True:
  10. text = input("\n输入: ")
  11. ss = time()
  12. synonyms = simbert.replace(sent=text, create_num=N)
  13. for line in synonyms:
  14. print(line)
  15. print("总耗时{0}ms".format(round(1000 * (time() - ss), 3)))
  16. if __name__ == "__main__":
  17. # SimBERT模型: Simbert/chinese_simbert_L-12_H-768_A-12
  18. # SimBERTv2模型: Simbert/chinese_roformer-sim-char_L-12_H-768_A-12
  19. config = {
  20. 'model_path': 'Simbert/chinese_roformer-sim-char_L-12_H-768_A-12',
  21. 'device': 'cuda',
  22. 'max_len': 32,
  23. 'seed': 1
  24. }
  25. sim_bert = Simbert(config=config)
  26. test_sing(simbert=sim_bert, N=10) # 单元测试

说明:chinese_roformer-sim-char_L-12_H-768_A-12模型下载链接,苏神是提供了的。

(2)包版本更新

pip install bert4keras==0.10.6

(3)generator.py文件修改

改后的代码应该是

  1. # -*- coding: utf-8 -*-
  2. import os
  3. import numpy as np
  4. from bert4keras.backend import keras
  5. from bert4keras.models import build_transformer_model
  6. from bert4keras.tokenizers import Tokenizer
  7. from bert4keras.snippets import sequence_padding, AutoRegressiveDecoder
  8. def setup_seed(seed):
  9. try:
  10. import random
  11. import numpy as np
  12. np.random.seed(seed)
  13. random.seed(seed)
  14. except Exception as e:
  15. pass
  16. class SynonymsGenerator(AutoRegressiveDecoder):
  17. """seq2seq解码器
  18. """
  19. def __init__(self, model_path, max_len=32, seed=1):
  20. # super().__init__()
  21. setup_seed(seed)
  22. self.config_path = os.path.join(model_path, "bert_config.json")
  23. self.checkpoint_path = os.path.join(model_path, "bert_model.ckpt")
  24. self.dict_path = os.path.join(model_path, "vocab.txt")
  25. self.max_len = max_len
  26. self.tokenizer = Tokenizer(self.dict_path, do_lower_case=True)
  27. self.bert = build_transformer_model(
  28. self.config_path,
  29. self.checkpoint_path,
  30. # model='roformer', # SimBERTv2模型加载, SimBERT模型加载时, 注释该行
  31. with_pool='linear',
  32. application='unilm',
  33. return_keras_model=False,
  34. )
  35. self.encoder = keras.models.Model(self.bert.model.inputs,
  36. self.bert.model.outputs[0])
  37. self.seq2seq = keras.models.Model(self.bert.model.inputs,
  38. self.bert.model.outputs[1])
  39. super().__init__(start_id=None, end_id=self.tokenizer._token_end_id,
  40. maxlen=self.max_len)
  41. # @AutoRegressiveDecoder.set_rtype('probas') # bert4keras==0.7.7
  42. @AutoRegressiveDecoder.wraps(default_rtype='probas') # bert4keras==0.10.6
  43. def predict(self, inputs, output_ids, states):
  44. token_ids, segment_ids = inputs
  45. token_ids = np.concatenate([token_ids, output_ids], 1)
  46. segment_ids = np.concatenate(
  47. [segment_ids, np.ones_like(output_ids)], 1)
  48. return self.seq2seq.predict([token_ids, segment_ids])[:, -1]
  49. def generate(self, text, n=1, topk=5):
  50. # bert4keras==0.7.7
  51. # token_ids, segment_ids = self.tokenizer.encode(
  52. # text, max_length=self.max_len)
  53. # bert4keras==0.10.6
  54. token_ids, segment_ids = self.tokenizer.encode(
  55. text, maxlen=self.max_len)
  56. output_ids = self.random_sample([token_ids, segment_ids], n, topk)
  57. return [self.tokenizer.decode(ids) for ids in output_ids]
  58. def gen_synonyms(self, text, n=100, k=20, threhold=0.75):
  59. """"含义: 产生sent的n个相似句,然后返回最相似的k个。
  60. 做法:用seq2seq生成,并用encoder算相似度并排序。
  61. """
  62. r = self.generate(text, n)
  63. r = [i for i in set(r) if i != text]
  64. r = [text] + r
  65. X, S = [], []
  66. for t in r:
  67. x, s = self.tokenizer.encode(t)
  68. X.append(x)
  69. S.append(s)
  70. X = sequence_padding(X)
  71. S = sequence_padding(S)
  72. Z = self.encoder.predict([X, S])
  73. Z /= (Z ** 2).sum(axis=1, keepdims=True) ** 0.5
  74. scores = np.dot(Z[1:], Z[0])
  75. argsort = scores.argsort()
  76. scores = scores.tolist()
  77. # print(scores.shape)
  78. # return [(r[i + 1], scores[i]) for i in argsort[::-1][:k] if scores[i] > threhold]
  79. return [(r[i + 1], scores[i]) for i in argsort[::-1][:k]]

最后,我们再来运行下SimBERTv2模型的生成结果。

为了对比出SimBERTv2的优势,笔者试了3条一般问句在SimBERT和SimBERTv2的结果。

  1. 帮我关一下台灯
  2. 我想吃附近的火锅
  3. 我们一起去打羽毛球吧
  • SimBERTv2模型的相似文本生成结果

  •  SimBERT模型的相似文本生成结果

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/382961
推荐阅读
相关标签
  

闽ICP备14008679号