当前位置:   article > 正文

kbqa基于复旦大学的实现代码解析完成步骤 (二)_chatkbqa代码解析

chatkbqa代码解析

      一 已经对主函数做了部分解释,许多细节是做了一些规则,或者利用了论文里的方法,所以要把代码和论文结合,才能看懂。

      在main_qa用到了两个自己写的文件: 

  1. from KBQA_small_data_version1.kbqa.connectSQLServer import connectSQL
  2. from KBQA_small_data.kbqa.entity_recognize import Entity

那么分别对这两个文件进行解释,

一   connectSQLServer文件

                        不需要做过多的介绍,原因比较简单就是连接一些数据库的信息。

二   entity_recognize

            那么这个文件,其实注释已经写得很清楚了。在main_qa中主要调用了get_synonym1()函数和entity_connect()函数这两个函数,其实并不多。

  1. #! -*- coding:utf-8 -*-
  2. """
  3. 为了识别问题与答案中的实体;数据保存在sqlserver,
  4. 思路一:首先加载m2e.txt实体到用户词典;对问题进行切词;(1)通过命名实体识别识别实体,识别不出或者识别出来实体的用m2e,搜索实体,根据答案以及实体在KB中寻找三元组存成(q:{e1,e2,...,en}) 以及(e1:[property,v]) 用的函数是 save_evc 保存
  5. """
  6. import jieba.analyse
  7. import math
  8. from collections import Counter
  9. import jieba.posseg
  10. from time import time
  11. from stanfordcorenlp import StanfordCoreNLP
  12. from KBQA_small_data_version1.kbqa.connectSQLServer import connectSQL
  13. import pickle
  14. # jieba.load_userdict('./../data/user_dict.txt')
  15. # host = 'DQ26-000018Z29'ls
  16. # user = 'chen'
  17. # password = '123456'
  18. # host = '172.17.0.169'
  19. host = '172.16.211.128'
  20. user = 'sa'
  21. password = 'chentian184616_'
  22. database= 'chentian'
  23. querySQL = connectSQL(host, user, password, database)
  24. class Entity:
  25. def __init__(self):
  26. self.jieba_pos=['i','j','l' ,'m' ,'nr','nt','nz','b','nrfg']
  27. self.tf_idf=jieba.analyse.extract_tags
  28. self.nlp = StanfordCoreNLP(path_or_host='../../stanford-corenlp/stanford-corenlp-full-2017-06-09/',lang='zh')
  29. self.sql="SELECT * FROM [chentian].[dbo].[baike_triples1] WHERE entity in %(name)s "
  30. self.sql2="SELECT * FROM [chentian].[dbo].[baike_triples1] WHERE entity ='%s' "
  31. # self.question='D:/QA/answer.txt'
  32. self.sql1="SELECT real_entities FROM [chentian].[dbo].[m2e1] where entity='%s'"
  33. self.sql3="SELECT value FROM [chentian].[dbo].[baike_triples1] WHERE property='BaiduTAG' "
  34. # self.KB='./../data/baike_triples.txt'
  35. # self.m2e='./../data/m2e.txt'
  36. 一些数据库的配置以及数据库的语句,比较简单
  37. def name_entity(self,entity):
  38. """
  39. 把实体对应的属性全部返回,包括对应类别
  40. :param entity:
  41. :return:
  42. """
  43. with open(self.KB,'r',encoding='utf-8') as f:
  44. lines=f.readlines()
  45. for line in lines:
  46. words=line.split("\t")
  47. if entity in words[0] :
  48. print(line)
  49. def get_synonym(self,sentence):
  50. """
  51. 获取实体对应的多义词
  52. :param entity:
  53. :return:
  54. """
  55. entiies=[]
  56. for line in open(self.m2e,'r',encoding='utf-8'):
  57. words=line.strip('\n').split("\t")
  58. if words[0] in sentence:
  59. entiies.append(words[1])
  60. return entiies
  61. def get_synonym2(self, entity):
  62. """
  63. 获取实体对应的多义词
  64. :param entity:
  65. :return:
  66. """
  67. entiies = []
  68. for line in open(self.m2e, 'r', encoding='utf-8'):
  69. words = line.strip('\n').split("\t")
  70. if words[0] == entity:
  71. entiies.append(words[1])
  72. return entiies
  这个是真正被调用的函数,比较简单,相当于把问句中实体的多个候选集合拿出来。
  74. def get_synonym1(self,entity):
  75. """
  76. 获取实体对应的多义词
  77. :param entity:
  78. :return:
  79. """
  80. temp_sql = self.sql1 % entity
  81. result = querySQL.Query(temp_sql)
  82. return result
  83. def save_evc(self,sentence,answer):
  84. """
  85. 存储实体value以及对应类别
  86. :return: 返回问题为{key1 :{e1,p1,v1}, {e2,p2,v2}} 的形式
  87. """
  88. jieba_cut = "|".join(jieba.cut(sentence)).split("|")
  89. if "是谁唱的" in sentence or "是谁写的" in sentence or "谁唱" in sentence or "谁写" in sentence:
  90. question_entity = ''
  91. for e in sentence:
  92. if e == "是" or e=="谁": break
  93. question_entity += e
  94. question_entity=[question_entity]
  95. else:
  96. question_entity=self.nlp.ner(sentence) #获得Stanford的实体识别的结果,以及切词结
  97. # pos_re=self.nlp.pos_tag(sentence)
  98. print(question_entity,"2222222222222222")
  99. pos_jieba=jieba.posseg.cut(sentence)
  100. # print(pos_re)
  101. # print(question_entity)
  102. # print(jieba_cut)
  103. if len(jieba_cut)<len(question_entity):
  104. final_words = []
  105. for ele in jieba_cut:
  106. tem_word = ''
  107. flag = False
  108. for el in question_entity:
  109. if el[0] in ele:
  110. if el[1] !='O' and el[1]!='NT' and el[1]!='NUMBER': flag = True
  111. tem_word += el[0]
  112. if flag == True:
  113. final_words.append(tem_word)
  114. question_entity=final_words
  115. # print(question_entity,"^^^^^^^^^^^^^^^^^^^^^^^^")
  116. else:
  117. question_entity=self.entity_connect(question_entity)
  118. # print(question_entity,"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@2")
  119. for i in pos_jieba:
  120. # print(i.word, i.flag, "#################################################")
  121. if i.flag in self.jieba_pos:
  122. question_entity.append(i.word)
  123. # print(question_entity, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!1")
  124. # #对实体进行连接,相邻作为一个实体在kb中寻找,依次递减
  125. #如果整个句子中不包含实体,则需要从m2e中寻找且此后对应的实体,从名词‘NN’中作为备选实体
  126. if len(question_entity)==0:
  127. jieba_entity=[]
  128. jieba_pos = jieba.posseg.cut(sentence)
  129. for i in jieba_pos:
  130. if i.flag in self.jieba_pos:
  131. jieba_entity.append(i.word)
  132. question_entity=jieba_entity
  133. # print(question_entity,"###################################################")
  134. if len(question_entity)==0:
  135. tf_idf=jieba.analyse.extract_tags
  136. JIE=tf_idf(sentence)
  137. # print(JIE)
  138. words_tag_jieba=JIE[:math.ceil(len(JIE)*0.3)] #这是jieba切词结果,要比stanford更符合中文习惯,
  139. question_entities=[]
  140. try:
  141. words_tag = self.nlp.pos_tag("".join(words_tag_jieba))
  142. if len(words_tag_jieba) < len(words_tag):
  143. final_words = []
  144. for ele in words_tag_jieba:
  145. tem_word = ''
  146. for el in words_tag:
  147. if el[0] in ele:
  148. tem_word += el[0]
  149. final_words.append(tem_word)
  150. question_entity = final_words
  151. else:
  152. for value in words_tag:
  153. # print(value)
  154. # if value[1] == 'NN'or value[1]=='NR':
  155. question_entities.append(value[0])
  156. question_entity=question_entities
  157. # print(question_entity,"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$4")
  158. except:
  159. print(sentence,"$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$44")
  160. return 0
  161. question_e={}
  162. tf_idf = jieba.analyse.extract_tags
  163. JIE = tf_idf(sentence)
  164. # print(JIE[:2])
  165. # print(question_entity,"**************")
  166. extract={} #提取出问题中的实体以及答案中的value,还有对应的property ,类型为[entity,property,value]
  167. question_entity.extend(JIE[:2])
  168. question_entity=self.connect_entity(jieba_cut,question_entity)
  169. # print(question_entity, "**************")
  170. for entity in question_entity: #查找m2e文件把所有有关的实体全部找出
  171. # print(entity,"88888")
  172. temp_sql_origal = self.sql2 % entity # real_entity 是一个元组,
  173. result_origal = querySQL.Query(temp_sql_origal) # 用sqlserver的in (e1,e2,e3)元组中得到所有的结果,不用再对real_entity实体循环多次select查找
  174. if len(result_origal)!=0:
  175. values = result_origal['value']
  176. for index, va in enumerate(values):
  177. # print(va, answer, va.replace("<a>", '').replace("</a>", '') in answer)
  178. # print(va, answer, answer in va.replace("<a>", '').replace("</a>", ''))
  179. # print(va, answer, self.simple_similar(va.replace("<a>", '').replace("</a>", ''), answer))
  180. # print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
  181. # 对于搜索出来的实体有两个条件可以添加(e,p,v)一是kb被包含在答案中,或者两者简单相似度为0.9以上添加相似对
  182. if va.replace("<a>", '').replace("</a>", '') in answer or answer in va.replace("<a>", '').replace(
  183. "</a>", '') or self.simple_similar(va.replace("<a>", '').replace("</a>", ''), answer) > 0.8:
  184. if ' '.join(list(result_origal.loc[index])) in extract:
  185. extract['&&&&&'.join(list(result_origal.loc[index]))] += 1
  186. else:
  187. extract['&&&&&'.join(list(result_origal.loc[index]))] = 1
  188. entity=entity.replace("'","''")
  189. real_entity= [k.replace("'", "") for k in self.get_synonym1(entity)['real_entities']] #由于实体中可能包含',则替换为'' 在数据库中就认为是单引号
  190. if len(real_entity)==0:real_entity="('"+str(entity)+"')" #如果m2e文件中没有多义词,则实体自己为real_entity
  191. elif len(real_entity)==1:real_entity="('"+str(real_entity[0])+"')"
  192. else:real_entity=tuple(real_entity)
  193. # real_entity=self.get_synonym2(entity)
  194. temp_sql = self.sql % {'name':real_entity} #real_entity 是一个元组,
  195. result = querySQL.Query(temp_sql) #用sqlserver的in (e1,e2,e3)元组中得到所有的结果,不用再对real_entity实体循环多次select查找
  196. values=result['value']
  197. for index,va in enumerate(values):
  198. # print(va,answer,va.replace("<a>",'').replace("</a>",'') in answer)
  199. # print(va,answer,answer in va.replace("<a>",'').replace("</a>",''))
  200. # print(va,answer,self.simple_similar(va.replace("<a>",'').replace("</a>",''),answer))
  201. # print("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
  202. #对于搜索出来的实体有两个条件可以添加(e,p,v)一是kb被包含在答案中,或者两者简单相似度为0.9以上添加相似对
  203. if va.replace("<a>",'').replace("</a>",'') in answer or answer in va.replace("<a>",'').replace("</a>",'') or self.simple_similar(va.replace("<a>",'').replace("</a>",''),answer)>0.8:
  204. if ' '.join(list(result.loc[index])) in extract:
  205. extract['&&&&&'.join(list(result.loc[index]))]+=1
  206. else:
  207. extract['&&&&&'.join(list(result.loc[index]))]=1
  208. if len(extract)!=0:
  209. question_e[sentence]=extract
  210. print(question_e)
  211. return question_e
  212. else:
  213. # print(sentence,"%%%%",answer)
  214. # print("&&&&&&&&&&&")
  215. return 0
  216. def connect_entity(self,question,question_entity):
  217. prio = []
  218. real_enity=[]
  219. for question_e in question_entity:
  220. if question_e in question:
  221. prio.append(question.index(question_e))
  222. k=1
  223. print(question_entity)
  224. while k<len(prio):
  225. if prio[k]-prio[k-1]==1:
  226. temp_enity=question[prio[k-1]]+question[prio[k]]
  227. print(question[prio[k-1]])
  228. print(question[prio[k]])
  229. print(question_entity,"^^^^^^^^^^^")
  230. question_entity.remove(question[prio[k-1]])
  231. question_entity.remove(question[prio[k]])
  232. real_enity.append(temp_enity)
  233. k+=1
  234. real_enity.extend(question_entity)
  235. return real_enity
  这个也是会被调用的函数,非常简单,函数说明也比较清楚,就不再说明。
  237. def entity_connect(self,entity,flag=['O','NUMBER']):
  238. """
  239. 函数作用就是如果两个识别出来的实体相连就认为是一个,某则作为新的实体添加
  240. """
  241. entities = [] # 根据stanford找到所有问题中的实体
  242. temp_entity = ''
  243. for value in entity:
  244. if value[1] not in flag:
  245. temp_entity += value[0]
  246. else:
  247. if temp_entity != '':
  248. entities.append(temp_entity)
  249. temp_entity = ''
  250. if temp_entity != '':
  251. entities.append(temp_entity)
  252. return entities
  253. def simple_similar(self,answer, sent):
  254. """
  255. 比较两个字符串含有共同字符的个数的比例
  256. :return: 返回比例
  257. """
  258. count = 0
  259. answer_len = len(answer)
  260. sent_len = len(sent)
  261. min_len = 0
  262. if answer_len < sent_len:
  263. min_len = answer_len
  264. for an in answer:
  265. if an in sent:
  266. count += 1
  267. else:
  268. min_len = sent_len
  269. for an in sent:
  270. if an in answer:
  271. count += 1
  272. return count * 1.0 / min_len
  273. def get_pevq(self):
  274. """
  275. 这个函数是所有的主函数,把问题答案QA语料得到基于KB的EV对
  276. :return: 返回【{'奥巴马什么时候出生的': {'奥巴马(圣枪游侠) 其他名称 奥巴马': 1, '奥巴马(美国第44任总统) 出生日期 1961年8月4日': 1}}】
  277. 这样的列表形式,以后得存储形式,在效率不足的情况下,在进行讨论
  278. """
  279. final_pevq=[]
  280. i=0
  281. with open('./../data/train_questions_with_evidence1.txt','r',encoding='utf-8') as f:
  282. lines=f.readlines()
  283. start = time()
  284. for line in lines:
  285. # print(line)
  286. question,answer=line.strip().replace("\t","").split("&&&&&")
  287. question_dict=self.save_evc(question,answer)
  288. if question_dict!=0:
  289. final_pevq.append(question_dict)
  290. i+=1
  291. if i%100==0:
  292. end=time()
  293. print("消耗的时间为"+str(end-start)+"秒")
  294. output=open('./../data/pqev_final_update.pkl','wb')
  295. pickle.dump(final_pevq,output)
  296. output.close()
  297. def store_EV(self,file_path):
  298. """
  299. 本函数的作用是把pqev_final.pkl的构造成类似于e:{v1:频数,v2:频数,...,}和v:{e1:频数,e2:频数,...}
  300. :param file_path: 对应的pqev_final.pkl路径
  301. """
  302. entities_values={}
  303. value_entity={}
  304. file_path=open(file_path,"rb")
  305. train_data=pickle.load(file_path)
  306. for que1 in train_data:
  307. evi = list(que1.values())[0] # 问题中的所有(实体-属性-值)
  308. for key in evi.keys():
  309. value_temp={}
  310. entity_temp={}
  311. e, p, v = key.split("&&&&&") # 接下来对每一个v 遍历每一个问题中所有的相同v,得到对应的实体e,并且记录实体出现的频数 实体e可能出现多次,对第一个概率没有影响,但是对第二个有影响,本来有结果,
  312. if e in entities_values:
  313. if v!='':
  314. if v in entities_values[e]:
  315. entities_values[e][v]+=1
  316. else:
  317. entities_values[e][v]=1
  318. else:
  319. if v!='':
  320. value_temp[v]=1
  321. entities_values[e]=value_temp
  322. if v!='':
  323. if v in value_entity:
  324. if e !='':
  325. if e in value_entity[v]:
  326. value_entity[v][e]+=1
  327. else:
  328. value_entity[v][e]=1
  329. else:
  330. if e!='':
  331. entity_temp[e]=1
  332. value_entity[v]=entity_temp
  333. output = open('./../data/EV_two.pkl', 'wb')
  334. pickle.dump(entities_values, output)
  335. pickle.dump(value_entity,output)
  336. output.close()
  337. file_path.close()
  338. def get_baiduTag(self):
  339. """
  340. 此函数是获取到concept ,并且计数每一个概念的频数作为概念的权重
  341. :return:
  342. """
  343. tags = querySQL.Query(self.sql3) # 用sqlserver的in (e
  344. print(list(tags['value'])[:20])
  345. concept_count=Counter(list(tags['value']))
  346. concept_count=dict(concept_count)
  347. output = open('./../data/concept_count.pkl', 'wb')
  348. pickle.dump(concept_count, output)
  349. output.close()
  350. if __name__=="__main__":
  351. # entity=Entity()
  352. # entity.get_baiduTag()
  353. # entity.store_EV("E:\chenmingwei\KBQA_small_data\data\pqev_final.pkl")
  354. # entity.get_pevq()
  355. EV=open("E:\chenmingwei\KBQA_small_data\data\pqev_final.pkl",'rb')
  356. entity_value=pickle.load(EV)
  357. for key in entity_value:
  358. print(key)
  359. # value_entity=pickle.load(EV)
  360. # for key,value in entity_value.items():
  361. # print(key,value)
  362. # b='全面内战爆发后,国民党反动派在昆明杀害的民盟中央委员是: & & & & & 李公朴'
  363. # a='“昌黎先生”是?&&&&&韩愈'
  364. # que,ans=a.split("&&&&&")
  365. # print(len(ans))
  366. # result=entity.save_evc(que,ans)
  367. # print(result)
  368. # sentence='123广西贺州重大故意伤害案什么时候发生的'
  369. # words=' '.join(jieba.cut(sentence))
  370. # question = '奥巴马什么时候出生的'
  371. # answer = '奥巴马出生于1961年8月4日'
  372. # question='控制器原理'
  373. # answer='控制器(英文名称:controller)是指按照预定顺序改变主电路或控制电路的接线和改变<a>电路'
  374. # start1=datetime.datetime.now()
  375. # final_dict = entity.save_evc(question, answer)
  376. # print(final_dict)
  377. # result=entity.get_synonym1('蝴蝶')
  378. # result=tuple([k.replace("'",'"') for k in result['real_entities']])
  379. # temp_sql = entity.sql % {'name': result} # real_entity 是一个元组,
  380. # print(temp_sql)
  381. # result = querySQL.Query(temp_sql)
  382. # print(result)
  383. # end1=datetime.datetime.now()
  384. # entiies=entity.get_synonym() #用于获取所有问题的实体,不进行切词处理,防止因为切词造成实体的丢失
  385. #对于答案,
  386. # for entit in entiies:
  387. # entity.name_entity(entit,answer)
  当然这个文件中还有其他函数,就是训练使用的函数,至此差不多就完成了。在训练好模型的前提下,整个服务启动后能够使用的文件比较简单,主要依赖就是数据集、训练模型参数(即初始化函数加载的文件)以及依赖包的安装。用这么多函数,主要是 Stanford 的命名实体识别的局限性导致的。

接下来就讲解训练部分的代码,请看三。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/722764
推荐阅读
相关标签
  

闽ICP备14008679号