
BERT Entity Extraction

The script below fine-tunes the pre-trained Chinese BERT model (chinese_L-12_H-768_A-12) for event entity extraction on event_type_entity_extract_train.csv. Each sample is encoded as '__type__text', and two linear pointer heads on top of BERT's sequence output predict the start and the end token of the entity; training runs as a plain feed_dict loop in TensorFlow 1.x.
import tensorflow as tf
import numpy as np
from bert import modeling
from bert import tokenization
from bert import optimization
import os
import pandas as pd

flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('train_batch_size', 32, 'define the train batch size')
flags.DEFINE_integer('num_train_epochs', 3, 'define the num train epochs')
flags.DEFINE_float('warmup_proportion', 0.1, 'define the warmup proportion')
flags.DEFINE_float('learning_rate', 5e-5, 'the initial learning rate for adam')
flags.DEFINE_bool('is_training', True, 'define whether to fine-tune the bert model')

# load the training set; column 1 is the text, column 2 the event type, column 3 the entity
data = pd.read_csv('data/event_type_entity_extract_train.csv', encoding='UTF-8', header=None)
data = data[data[2] != u'其他']          # drop samples whose event type is "其他" (other)
classes = set(data[2])

train_data = []
for t, c, n in zip(data[1], data[2], data[3]):
    train_data.append((t.strip(), c.strip(), n.strip()))
np.random.shuffle(train_data)

def get_start_end_index(text, subtext):
    """Return the (start, end) span of subtext inside text, or (-1, -1) if it is absent."""
    for i in range(len(text)):
        if text[i:i + len(subtext)] == subtext:
            return (i, i + len(subtext) - 1)
    return (-1, -1)

# keep only the samples whose entity actually occurs in the text
tmp_train_data = []
for item in train_data:
    start, end = get_start_end_index(item[0], item[2])
    if start != -1:
        tmp_train_data.append(item)
train_data = tmp_train_data
np.random.shuffle(train_data)

# the test samples carry only the text and the event type
data = pd.read_csv('data/event_type_entity_extract_train.csv', encoding='UTF-8', header=None)
test_data = []
for t, c in zip(data[1], data[2]):
    test_data.append((t.strip(), c.strip()))

config_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\bert_config.json'
checkpoint_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\bert_model.ckpt'
dict_path = r'D:\NLP_SOUNDAI\learnTensor\package9\bert\chinese_L-12_H-768_A-12\vocab.txt'

bert_config = modeling.BertConfig.from_json_file(config_path)
tokenizer = tokenization.FullTokenizer(vocab_file=dict_path, do_lower_case=False)

def input_str_concat(inputList):
    """Build the BERT inputs for one sample: the event type is prepended to the text as '__type__text'."""
    assert len(inputList) == 2
    t, c = inputList
    newStr = '__%s__%s' % (c, t)
    tokens = tokenizer.tokenize(newStr)
    tokens = ['[CLS]'] + tokens + ['[SEP]']
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)
    return tokens, (input_ids, input_mask, segment_ids)

# sanity check: print the encoding of one training sample and one test sample
for i in train_data:
    print(input_str_concat(i[:-1]))
    break
for i in test_data:
    print(input_str_concat(i))
    break

def sequence_padding(sequence):
    """Right-pad every item in a batch with zeros up to the length of the longest item."""
    lenlist = [len(item) for item in sequence]
    maxlen = max(lenlist)
    return np.array([
        np.concatenate([item, [0] * (maxlen - len(item))]) if len(item) < maxlen else item
        for item in sequence
    ])

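As a quick check of the padding helper (an illustrative snippet, not in the original script), two id lists of different length come back as one zero-padded array:

print(sequence_padding([[101, 102, 103], [101, 102]]))
# [[101 102 103]
#  [101 102   0]]
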
# generator that yields padded training batches
def get_data_batch():
    batch_size = FLAGS.train_batch_size
    epoch = FLAGS.num_train_epochs
    for oneEpoch in range(epoch):
        num_batches = ((len(train_data) - 1) // batch_size) + 1
        for i in range(num_batches):
            batch_data = train_data[i * batch_size:(i + 1) * batch_size]
            yield_batch_data = {
                'input_ids': [],
                'input_mask': [],
                'segment_ids': [],
                'start_ids': [],
                'end_ids': []
            }
            for item in batch_data:
                tokens, (input_ids, input_mask, segment_ids) = input_str_concat(item[:-1])
                # locate the entity at token level, so that the span indices line up with
                # the '[CLS] __type__ text [SEP]' sequence that is actually fed to BERT
                entity_tokens = tokenizer.tokenize(item[2])
                start, end = get_start_end_index(tokens, entity_tokens)
                if start == -1:
                    # entity tokens not found after wordpiece tokenization: skip the sample
                    continue
                start_ids = [0] * len(input_ids)
                end_ids = [0] * len(input_ids)
                start_ids[start] = 1
                end_ids[end] = 1
                yield_batch_data['input_ids'].append(input_ids)
                yield_batch_data['input_mask'].append(input_mask)
                yield_batch_data['segment_ids'].append(segment_ids)
                yield_batch_data['start_ids'].append(start_ids)
                yield_batch_data['end_ids'].append(end_ids)
            if not yield_batch_data['input_ids']:
                continue
            yield_batch_data['input_ids'] = sequence_padding(yield_batch_data['input_ids'])
            yield_batch_data['input_mask'] = sequence_padding(yield_batch_data['input_mask'])
            yield_batch_data['segment_ids'] = sequence_padding(yield_batch_data['segment_ids'])
            yield_batch_data['start_ids'] = sequence_padding(yield_batch_data['start_ids'])
            yield_batch_data['end_ids'] = sequence_padding(yield_batch_data['end_ids'])
            yield yield_batch_data

with tf.Graph().as_default(), tf.Session() as sess:
    input_ids_p = tf.placeholder(dtype=tf.int64, shape=[None, None], name='input_ids_p')
    input_mask_p = tf.placeholder(dtype=tf.int64, shape=[None, None], name='input_mask_p')
    segment_ids_p = tf.placeholder(dtype=tf.int64, shape=[None, None], name='segment_ids_p')
    start_p = tf.placeholder(dtype=tf.int64, shape=[None, None], name='start_p')
    end_p = tf.placeholder(dtype=tf.int64, shape=[None, None], name='end_p')

    model = modeling.BertModel(config=bert_config,
                               is_training=False,
                               input_ids=input_ids_p,
                               input_mask=input_mask_p,
                               token_type_ids=segment_ids_p,
                               use_one_hot_embeddings=False)
    # sequence output: [batch_size, sentence_max_len, word_dim]
    output_layer = model.get_sequence_output()
    batch_size = tf.shape(output_layer)[0]
    sentence_max_len = tf.shape(output_layer)[1]
    # the hidden size must be a static Python int because it is used in get_variable shapes below
    word_dim = bert_config.hidden_size
    output_reshape = tf.reshape(output_layer, shape=[-1, word_dim], name='output_reshape')

    with tf.variable_scope('weight_and_bias', reuse=tf.AUTO_REUSE,
                           initializer=tf.truncated_normal_initializer(mean=0., stddev=0.05)):
        weight_start = tf.get_variable(name='weight_start', shape=[word_dim, 1])
        bias_start = tf.get_variable(name='bias_start', shape=[1])
        weight_end = tf.get_variable(name='weight_end', shape=[word_dim, 1])
        bias_end = tf.get_variable(name='bias_end', shape=[1])

    with tf.name_scope('predict_start_and_end'):
        # one linear head per boundary: token-level logits for the start and the end position
        pred_start = tf.nn.bias_add(tf.matmul(output_reshape, weight_start), bias_start)
        pred_start = tf.reshape(pred_start, shape=[batch_size, sentence_max_len, 1])
        pred_start = tf.squeeze(pred_start, -1)
        pred_end = tf.nn.bias_add(tf.matmul(output_reshape, weight_end), bias_end)
        pred_end = tf.reshape(pred_end, shape=[batch_size, sentence_max_len, 1])
        pred_end = tf.squeeze(pred_end, -1)

    with tf.name_scope('loss'):
        # the labels arrive as one-hot int vectors; the softmax cross-entropy op wants float labels
        start_labels = tf.cast(start_p, tf.float32)
        end_labels = tf.cast(end_p, tf.float32)
        loss1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred_start, labels=start_labels))
        # mask the end logits (not the labels) so that only positions at or after the gold start remain candidates
        pred_end -= (1 - tf.cumsum(start_labels, axis=1)) * 1e10
        loss2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred_end, labels=end_labels))
        loss = loss1 + loss2
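        # A concrete example of what the mask above does: for a gold start vector
        # [0, 0, 1, 0, 0], tf.cumsum gives [0, 0, 1, 1, 1], so (1 - cumsum) * 1e10 is
        # [1e10, 1e10, 0, 0, 0]. Subtracting that pushes the end logits of every position
        # before the gold start towards -infinity, so argmax over pred_end can never
        # land in front of the start position.
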
    with tf.name_scope('acc_predict'):
        start_acc = tf.cast(tf.equal(tf.argmax(start_p, axis=1), tf.argmax(pred_start, axis=1)), dtype=tf.float32)
        end_acc = tf.cast(tf.equal(tf.argmax(end_p, axis=1), tf.argmax(pred_end, axis=1)), dtype=tf.float32)
        start_acc_val = tf.reduce_mean(start_acc)
        end_acc_val = tf.reduce_mean(end_acc)
        # a sample counts as fully correct only when both boundaries are right
        total_acc = tf.reduce_mean(start_acc * end_acc)

    with tf.name_scope('train_op'):
        num_train_steps = int(len(train_data) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
        train_op = optimization.create_optimizer(
            loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)

    # load the pre-trained BERT weights into the matching variables, then initialize everything
    tvars = tf.trainable_variables()
    (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_path)
    tf.train.init_from_checkpoint(checkpoint_path, assignment_map)
    sess.run(tf.variables_initializer(tf.global_variables()))

    total_steps = 0
    for yield_batch_data in get_data_batch():
        total_steps += 1
        feed_dict = {
            input_ids_p: yield_batch_data['input_ids'],
            input_mask_p: yield_batch_data['input_mask'],
            segment_ids_p: yield_batch_data['segment_ids'],
            start_p: yield_batch_data['start_ids'],
            end_p: yield_batch_data['end_ids']
        }
        fetches = [train_op, loss, start_acc_val, end_acc_val, total_acc]
        _, loss_val, start_acc_now, end_acc_now, total_acc_val = sess.run(fetches, feed_dict=feed_dict)
        print('i : %s, loss : %s, start_acc : %s, end_acc : %s, total_acc : %s' %
              (total_steps, loss_val, start_acc_now, end_acc_now, total_acc_val))
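
The original script stops after the training loop and only ever prints one encoded test sample. The sketch below shows one way to turn the start/end logits back into an entity string. It is an illustrative addition, not part of the original post: extract_entity is a hypothetical helper, and the code assumes it sits inside the same "with tf.Session() as sess" block, directly after the training loop (hence the indentation), so that the tensors defined above are still available.

    # --- inference sketch (illustrative addition, not in the original post) ---
    def extract_entity(text, event_type):
        tokens, (input_ids, input_mask, segment_ids) = input_str_concat((text, event_type))
        feed = {
            input_ids_p: [input_ids],
            input_mask_p: [input_mask],
            segment_ids_p: [segment_ids],
            # start_p only feeds the end-logit mask; a zero vector keeps that mask uniform
            start_p: [[0] * len(input_ids)],
        }
        # first pass: pick the most likely start token
        start_logits = sess.run(pred_start, feed_dict=feed)
        start = int(np.argmax(start_logits[0]))
        # second pass: feed the predicted start as a one-hot vector so the cumsum mask
        # rules out end positions before it, then pick the most likely end token
        start_onehot = [0] * len(input_ids)
        start_onehot[start] = 1
        feed[start_p] = [start_onehot]
        end_logits = sess.run(pred_end, feed_dict=feed)
        end = int(np.argmax(end_logits[0]))
        # join the predicted token span and strip wordpiece markers
        return ''.join(tokens[start:end + 1]).replace('##', '')

    # example call on the first test sample
    print(extract_entity(test_data[0][0], test_data[0][1]))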

 
