当前位置:   article > 正文

小型中文版聊天机器人_基于pytorch的中文聊天机器人

基于pytorch的中文聊天机器人

入门小菜鸟,希望像做笔记记录自己学的东西,也希望能帮助到同样入门的人,更希望大佬们帮忙纠错啦~侵权立删。

目录

一、简单介绍与参考鸣谢

二、数据集介绍

三、数据预处理

1、重复标点符号表达

2、英文标点符号变为中文标点符号

3、繁体字转为简体字

4、限定长度

5、为后面制作词表做准备

6、代码实现(sol_data.py文件)

7、处理完的数据展示

四、词表制作以及转化(word2seq.py文件)

五、数据集加载(dataset.py)

六、GPT模型搭建(gpt_model.py)

1、原理解析

(1)Transformer与GPT

(2)多头注意力机制

 (3)多头注意力机制

2、代码

七、训练模型(train.py)

八、补充实现(utils.py)

九、完整代码与结果


一、简单介绍与参考鸣谢

自己用pytorch搭建模型,训练一个小型的中文闲聊机器人。

数据集和实现思路部分参考这位博主大大@weixin_44599230的博文GPT模型介绍并且使用pytorch实现一个小型GPT中文闲聊系统_tinygpt_weixin_44599230的博客-CSDN博客

这位博主真的 tql 呜呜呜,他后续还出了另外一个版本的闲聊系统,在此鸣谢!


二、数据集介绍

数据集地址百度网盘 提取码jk8d

这份数据集是纯中文,没有英文、颜文字、数字等之类的干扰呜呜呜。

这份数据集是多轮次对话数据集,数据规模为50w,每一轮次用空行隔开。


三、数据预处理

虽然说这份数据集已经很nice了,但还是有一丢丢需要处理

1、重复标点符号表达

        比如说:将“??????????????????”缩成“?”

        emmm其实这个重不重复木有太大关系,主要是我想让他长度短一点,这样少点点计算

2、英文标点符号变为中文标点符号

        比如说:“?"变为”?“

        这也是为了让他别太复杂,算力有限能省就省(穷.jpg)

3、繁体字转为简体字

       使用zhconv这个包。

       这个包的下载安装以及使用方法可以看看这篇python汉字简繁体转换方法_python繁体字转简体字_一位代码的博客-CSDN博客

这位博主写得很清楚,点赞

4、限定长度

       同样的算力有限哈哈哈,并且也希望数据集中样本长度相差别太大。我是限制每个样本(加上每句对话结束符后)长度不能超过100,这样数据集规模变为48w+。

5、为后面制作词表做准备

       每个词(包括每句对话的结束符)用空格分开。

       最早的时候我的想法是用jieba分词,将所得的所有分词作为我的词表,但是分词后词表的词频更加稀疏了,而且词表大小巨大(如果我没记错的话,词表大小是以万为单位的,害怕极了)。

      于是乎我放弃了这种想法,就单独一个字一个字作为词来分就好了,这样词表大小是6000+个,大大减小了词表大小,词频也不会过于稀疏。

6、代码实现(sol_data.py文件)

其中的"import config"是我将所有配置信息都写在了config.py文件中,方便调整。

其中config.max_len=100,即前面提到的限制长度不超过100;

config.data_path_txt,即预处理后的数据保存地址

  1. import re
  2. from tqdm import tqdm
  3. import zhconv
  4. import config
  5. #处理重复符号的表达,如替换多个重复符号
  6. def delete_repeat(s):
  7. #注释掉的是英文的表达
  8. #s = re.sub('[!]+','!', s)
  9. #s = re.sub('[?]+','?', s)
  10. #s = re.sub('[,]+',',', s)
  11. #s = re.sub('[:]+',':', s)
  12. #s = re.sub('[;]+',';', s)
  13. s = re.sub('[,]+',',', s)
  14. s = re.sub('[!]+','!', s)
  15. s = re.sub('[?]+','?', s)
  16. s = re.sub('[:]+',':', s)
  17. s = re.sub('[;]+',';', s)
  18. s = re.sub('[。]+','。', s)
  19. s = re.sub('[、]+','、', s)
  20. return s
  21. with open('data/origin_train.txt','r',encoding='utf-8') as f: #打开原始数据集
  22. lines = f.readlines()
  23. train_datas = []
  24. temp_data = ''
  25. #每个多轮对话中使用'<EOS>'将其划分
  26. for line in tqdm(lines):
  27. if line!='\n':
  28. line = line.strip() #去除前导后方空格
  29. #英文标点符号置换为中文标点符号
  30. line = line.replace('!','!')
  31. line = line.replace('?','?')
  32. line = line.replace(',',',')
  33. line = line.replace('.','。')
  34. line = line.replace(':',':')
  35. line = line.replace(';',';')
  36. line = zhconv.convert(line, 'zh-cn') #转为简体字
  37. line = " ".join(line)
  38. temp_data+=(line+' <EOS> ')
  39. else:
  40. if len(temp_data.split()) <= config.max_len: #限制长度
  41. train_datas.append(temp_data)
  42. temp_data=''
  43. with open(config.data_path_txt,'w',encoding='utf-8') as f: #将处理后的数据保存在另一个文件中
  44. for train_data in train_datas:
  45. f.write(train_data+'\n')

7、处理完的数据展示


四、词表制作以及转化(word2seq.py文件)

       先定义填充符<PAD>,未知符<UNK>和结束符<EOS>,然后再对数据集中的词进行标号,生成词表与转义词表,最后统计数据集中每个词出现的词频,生成一个词频表可以直观看看咱们的词表情况。

       并且定义词到标号,标号到词的转化的方法,方便后期训练以及测试时使用。

其中,config.word_sequence_dict是保存词典的位置

  1. #生成词表
  2. #构造文本序列化和反序列化方法(文本转数字)
  3. import pickle
  4. import config
  5. from tqdm import tqdm
  6. class Word2Sequence():
  7. PAD_TAG = "<PAD>" #填充编码
  8. UNK_TAG = "<UNK>" #未知编码
  9. EOS_TAG = "<EOS>" #句子结尾
  10. #上面四种情况的对应编号
  11. PAD = 0
  12. UNK = 1
  13. EOS = 2
  14. def __init__(self):
  15. #文字——标号字典
  16. self.dict = {
  17. self.PAD_TAG :self.PAD,
  18. self.UNK_TAG :self.UNK,
  19. self.EOS_TAG :self.EOS
  20. }
  21. #词频统计
  22. self.count = {}
  23. self.fited = False #是否统计过词典了
  24. #以下两个转换都不包括'\t'
  25. #文字转标号(针对单个词)
  26. def to_index(self,word):
  27. """word -> index"""
  28. assert self.fited == True,"必须先进行fit操作"
  29. return self.dict.get(word,self.UNK) #无这个词则用未知代替
  30. #标号转文字(针对单个词)
  31. def to_word(self,index):
  32. """index -> word"""
  33. assert self.fited == True, "必须先进行fit操作"
  34. if index in self.inversed_dict:
  35. return self.inversed_dict[index]
  36. return self.UNK_TAG
  37. # 获取词典长度
  38. def __len__(self):
  39. return len(self.dict)
  40. #统计词频生成词典
  41. def fit(self, sentence):
  42. """
  43. :param sentence:[word1,word2,word3]
  44. """
  45. for a in sentence:
  46. if a not in self.count:
  47. self.count[a] = 0
  48. self.count[a] += 1
  49. self.fited = True
  50. def build_vocab(self, min_count=config.min_count, max_count=None, max_feature=None):
  51. """
  52. :param min_count: 最小出现的次数
  53. :param max_count: 最大出现的次数
  54. :param max_feature: 总词语的最大数量
  55. """
  56. # 限定统计词频范围
  57. if min_count is not None:
  58. self.count = {k: v for k, v in self.count.items() if v >= min_count}
  59. if max_count is not None:
  60. self.count = {k: v for k, v in self.count.items() if v <= max_count}
  61. # 给对应词进行编号
  62. if isinstance(max_feature, int): #是否限制词典的词数
  63. #词频从大到小排序
  64. count = sorted(list(self.count.items()), key=lambda x: x[1])
  65. if max_feature is not None and len(count) > max_feature:
  66. count = count[-int(max_feature):]
  67. for w, _ in count:
  68. self.dict[w] = len(self.dict)
  69. else: #按字典序(方便debug查看)
  70. for w in sorted(self.count.keys()):
  71. self.dict[w] = len(self.dict)
  72. # 准备一个index->word的字典
  73. self.inversed_dict = dict(zip(self.dict.values(), self.dict.keys()))
  74. #debug专用
  75. f_debug_word = open("data/debug_word.txt","w",encoding='utf-8')
  76. t = 0
  77. for key,_ in self.dict.items():
  78. t = t + 1
  79. if t >= 4: #排除那3种情况(填充,未知,结尾)
  80. f_debug_word.write(key+"★ "+str(self.count[key]) + "\n") #使用★ 区分是为了防止其中的词语包含分隔符,对我们后续的操作不利
  81. f_debug_word.close()
  82. def transform(self, sentence,max_len=None,add_eos=True):
  83. """
  84. 实现把句子转化为向量
  85. :param max_len: 限定长度
  86. :param add_eos: 是否在最后再补上<EOS>结束符
  87. :return:
  88. """
  89. assert self.fited == True, "必须先进行fit操作"
  90. r = [self.to_index(i) for i in sentence]
  91. if max_len is not None: #限定长度
  92. if max_len>len(sentence):
  93. if add_eos:
  94. #添加结束符与填充符达到一定长度
  95. r+=[self.EOS]+[self.PAD for _ in range(max_len-len(sentence)-2)]
  96. else: #添加填充符达到一定长度
  97. r += [self.PAD for _ in range(max_len - len(sentence)-1)]
  98. else:
  99. if add_eos:
  100. r = r[:max_len-2]
  101. r += [self.EOS]
  102. else:
  103. r = r[:max_len-1]
  104. else:
  105. if add_eos:
  106. r += [self.EOS]
  107. return r
  108. def inverse_transform(self,indices):
  109. """
  110. 实现从句子向量 转化为 词(文字)
  111. :param indices: [1,2,3....]
  112. :return:[word1,word2.....]
  113. """
  114. sentence = []
  115. for i in indices:
  116. word = self.to_word(i)
  117. sentence.append(word)
  118. return sentence
  119. #以下可供第一次运行,下一次就可以注释掉了
  120. #初始
  121. word_sequence = Word2Sequence()
  122. #词语导入
  123. for line in tqdm(open(config.data_path.txt,encoding='utf-8').readlines()):
  124. word_sequence.fit(line.strip().split())
  125. print("生成词典...")
  126. word_sequence.build_vocab(min_count=None,max_count=None,max_feature=None)
  127. print("词典大小:",len(word_sequence.dict))
  128. pickle.dump(word_sequence,open(config.word_sequence_dict,"wb")) #保存词典

五、数据集加载(dataset.py)

       定义一个ChatDataset类,可以逐一取出数据,并且获取数据集大小。

       并且定义一个处理数据的方法——将句子中的词转为标号,并且进行填充。这里并不是整份数据集都是一样的样本长度,只要保证一个batch里的样本长度一致就好了(不一致就填充),这样设计的原因见后面的模型原理分析。

  1. #构建数据集
  2. import torch
  3. import pickle
  4. import config
  5. from torch.utils.data import Dataset,DataLoader
  6. from tqdm import tqdm
  7. from word2seq import Word2Sequence
  8. word_sequence = pickle.load(open(config.word_sequence_dict,"rb")) #词典加载
  9. class ChatDataset(Dataset):
  10. def __init__(self):
  11. super(ChatDataset,self).__init__()
  12. #读取内容
  13. data_path = config.data_path_txt
  14. self.data_lines = open(data_path,encoding='utf-8').readlines()
  15. #获取对应索引的问答
  16. def __getitem__(self, index):
  17. input = self.data_lines[index].strip().split()[:-1]
  18. target = self.data_lines[index].strip().split()[1:]
  19. #为空则默认读取下一条
  20. if len(input) == 0 or len(target)==0:
  21. input = self.data_lines[index+1].split()[:-1]
  22. target = self.data_lines[index+1].split()[1:]
  23. #此处句子的长度如果大于max_len,那么应该返回max_len
  24. return input,target,len(input),len(target)
  25. #获取数据长度
  26. def __len__(self):
  27. return len(self.data_lines)
  28. # 整理数据————数据集处理方法
  29. def collate_fn(batch):
  30. # 排序
  31. batch = sorted(batch,key=lambda x:x[2],reverse=True) #输入长度排序
  32. input, target, input_length, target_length = zip(*batch)
  33. max_len = max(input_length[0],target_length[0]) #这里只需要固定每个batch里面的样本长度一致就好,并不需要整个数据集的所有样本长度一致
  34. # 词变成词向量,并进行padding的操作
  35. input = torch.LongTensor([word_sequence.transform(i, max_len=max_len, add_eos=False) for i in input])
  36. target = torch.LongTensor([word_sequence.transform(i, max_len=max_len, add_eos=False) for i in target])
  37. input_length = torch.LongTensor(input_length)
  38. target_length = torch.LongTensor(target_length)
  39. return input, target
  40. print("数据集装载...")
  41. data_loader = DataLoader(dataset=ChatDataset(),batch_size=config.batch_size,shuffle=True,collate_fn=collate_fn,drop_last=True)
  42. '''''
  43. #测试专用(debug)
  44. if __name__ == '__main__':
  45. for idx, (input, target) in enumerate(data_loader):
  46. print(idx)
  47. print(input)
  48. print(target)
  49. '''''

六、GPT模型搭建(gpt_model.py)

1、原理解析

(1)Transformer与GPT

      说到GPT就要提到Transformer啦。GPT是Transformer的Decoder部分。

Transformer的网络结构如下:(图是网上找的,侵权立删)

而GPT则如下:

 因为其没有encoder的输出作为另一个输入分支,所以去掉了encoder-decoder的attention机制。

(2)多头注意力机制

A、提出原因

       self attention是通过某种运算来直接计算得到句子在编码过程中每个位置上的注意力权重,然后再以权重和的形式来计算得到整个句子的隐含向量表示(self attention提出原因:在深度学习领域,模型往往需要接收和处理大量的数据,然而在特定的某个时刻,往往只有少部分的数据是重要的。这种情况下应该让模型更加关注这些重要数据,这样他就可以在计算能力有限的情况下,将计算资源分配给更重要的任务,同时解决信息超载问题)。

       但self attention的缺陷是:模型在对当前位置的信息进行编码时,会过度的将注意力集中于自身的位置, 因此提出了通过多头注意力机制来解决这一问题。

       注:为了更好发挥并行输入的特点首先要解决的问题就是要让输入的内容具有一定的位置信息,因此引入位置编码。

 B、注意力机制

     键值对注意力机制公式如下:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/363574
推荐阅读
相关标签