
NLP News Text Classification (6): Deep-Learning-Based Text Classification, Part 3


BERT

Fine-tuning takes the last-layer hidden vector of the first token, [CLS], as the sentence representation and feeds it into a softmax layer for classification.
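As a minimal sketch of this idea (separate from the full pipeline below, and assuming a transformers >= 4 API plus a locally available checkpoint directory; the class name BertClsClassifier and num_labels=14 are illustrative only):

import torch.nn as nn
from transformers import BertModel

class BertClsClassifier(nn.Module):
    """Sketch: classify a sentence from the last-layer [CLS] hidden state."""
    def __init__(self, bert_path, num_labels=14):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_path)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_rep = outputs.last_hidden_state[:, 0, :]  # hidden vector of [CLS]
        return self.fc(cls_rep)  # logits; softmax is applied inside CrossEntropyLoss

In the full code below, the same step appears in WordBertEncoder.forward, which takes sequence_output[:, 0, :] as the sentence representation.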

import logging
import random
import numpy as np
import torch

logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(levelname)s: %(message)s')

# set seed
seed = 666
random.seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)

# set cuda
gpu = 0
use_cuda = gpu >= 0 and torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(gpu)
    device = torch.device("cuda", gpu)
else:
    device = torch.device("cpu")
logging.info("Use cuda: %s, gpu id: %d.", use_cuda, gpu)

# split data to 10 fold
fold_num = 10
data_file = '../data/train_set.csv'
import pandas as pd
def all_data2fold(fold_num, num=10000):
    fold_data = []
    f = pd.read_csv(data_file, sep='\t', encoding='UTF-8')
    texts = f['text'].tolist()[:num]
    labels = f['label'].tolist()[:num]
    total = len(labels)

    index = list(range(total))
    np.random.shuffle(index)

    all_texts = []
    all_labels = []
    for i in index:
        all_texts.append(texts[i])
        all_labels.append(labels[i])

    label2id = {}
    for i in range(total):
        label = str(all_labels[i])
        if label not in label2id:
            label2id[label] = [i]
        else:
            label2id[label].append(i)

    all_index = [[] for _ in range(fold_num)]
    for label, data in label2id.items():
        # print(label, len(data))
        batch_size = int(len(data) / fold_num)
        other = len(data) - batch_size * fold_num
        for i in range(fold_num):
            cur_batch_size = batch_size + 1 if i < other else batch_size
            # print(cur_batch_size)
            batch_data = [data[i * batch_size + b] for b in range(cur_batch_size)]
            all_index[i].extend(batch_data)

    batch_size = int(total / fold_num)
    other_texts = []
    other_labels = []
    other_num = 0
    start = 0
    for fold in range(fold_num):
        num = len(all_index[fold])
        texts = [all_texts[i] for i in all_index[fold]]
        labels = [all_labels[i] for i in all_index[fold]]

        if num > batch_size:
            fold_texts = texts[:batch_size]
            other_texts.extend(texts[batch_size:])
            fold_labels = labels[:batch_size]
            other_labels.extend(labels[batch_size:])
            other_num += num - batch_size
        elif num < batch_size:
            end = start + batch_size - num
            fold_texts = texts + other_texts[start: end]
            fold_labels = labels + other_labels[start: end]
            start = end
        else:
            fold_texts = texts
            fold_labels = labels

        assert batch_size == len(fold_labels)

        # shuffle
        index = list(range(batch_size))
        np.random.shuffle(index)

        shuffle_fold_texts = []
        shuffle_fold_labels = []
        for i in index:
            shuffle_fold_texts.append(fold_texts[i])
            shuffle_fold_labels.append(fold_labels[i])

        data = {'label': shuffle_fold_labels, 'text': shuffle_fold_texts}
        fold_data.append(data)

    logging.info("Fold lens %s", str([len(data['label']) for data in fold_data]))

    return fold_data


fold_data = all_data2fold(10)
# build train, dev, test data
fold_id = 9

# dev
dev_data = fold_data[fold_id]

# train
train_texts = []
train_labels = []
for i in range(0, fold_id):
    data = fold_data[i]
    train_texts.extend(data['text'])
    train_labels.extend(data['label'])

train_data = {'label': train_labels, 'text': train_texts}

# test
test_data_file = '../data/test_a.csv'
f = pd.read_csv(test_data_file, sep='\t', encoding='UTF-8')
texts = f['text'].tolist()
test_data = {'label': [0] * len(texts), 'text': texts}
# build vocab
from collections import Counter
from transformers import BasicTokenizer

basic_tokenizer = BasicTokenizer()


class Vocab():
    def __init__(self, train_data):
        self.min_count = 5
        self.pad = 0
        self.unk = 1
        self._id2word = ['[PAD]', '[UNK]']
        self._id2extword = ['[PAD]', '[UNK]']

        self._id2label = []
        self.target_names = []

        self.build_vocab(train_data)

        reverse = lambda x: dict(zip(x, range(len(x))))
        self._word2id = reverse(self._id2word)
        self._label2id = reverse(self._id2label)

        logging.info("Build vocab: words %d, labels %d." % (self.word_size, self.label_size))

    def build_vocab(self, data):
        self.word_counter = Counter()

        for text in data['text']:
            words = text.split()
            for word in words:
                self.word_counter[word] += 1

        for word, count in self.word_counter.most_common():
            if count >= self.min_count:
                self._id2word.append(word)

        # label id -> Chinese category name (tech, stocks, sports, entertainment, politics,
        # society, education, finance, home, games, real estate, fashion, lottery, horoscope)
        label2name = {0: '科技', 1: '股票', 2: '体育', 3: '娱乐', 4: '时政', 5: '社会', 6: '教育', 7: '财经',
                      8: '家居', 9: '游戏', 10: '房产', 11: '时尚', 12: '彩票', 13: '星座'}

        self.label_counter = Counter(data['label'])

        for label in range(len(self.label_counter)):
            count = self.label_counter[label]
            self._id2label.append(label)
            self.target_names.append(label2name[label])

    def load_pretrained_embs(self, embfile):
        with open(embfile, encoding='utf-8') as f:
            lines = f.readlines()
            items = lines[0].split()
            word_count, embedding_dim = int(items[0]), int(items[1])

        index = len(self._id2extword)
        embeddings = np.zeros((word_count + index, embedding_dim))
        for line in lines[1:]:
            values = line.split()
            self._id2extword.append(values[0])
            vector = np.array(values[1:], dtype='float64')
            embeddings[self.unk] += vector
            embeddings[index] = vector
            index += 1

        embeddings[self.unk] = embeddings[self.unk] / word_count
        embeddings = embeddings / np.std(embeddings)

        reverse = lambda x: dict(zip(x, range(len(x))))
        self._extword2id = reverse(self._id2extword)

        assert len(set(self._id2extword)) == len(self._id2extword)

        return embeddings

    def word2id(self, xs):
        if isinstance(xs, list):
            return [self._word2id.get(x, self.unk) for x in xs]
        return self._word2id.get(xs, self.unk)

    def extword2id(self, xs):
        if isinstance(xs, list):
            return [self._extword2id.get(x, self.unk) for x in xs]
        return self._extword2id.get(xs, self.unk)

    def label2id(self, xs):
        if isinstance(xs, list):
            return [self._label2id.get(x, self.unk) for x in xs]
        return self._label2id.get(xs, self.unk)

    @property
    def word_size(self):
        return len(self._id2word)

    @property
    def extword_size(self):
        return len(self._id2extword)

    @property
    def label_size(self):
        return len(self._id2label)


vocab = Vocab(train_data)
# build module
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.weight.data.normal_(mean=0.0, std=0.05)

        self.bias = nn.Parameter(torch.Tensor(hidden_size))
        b = np.zeros(hidden_size, dtype=np.float32)
        self.bias.data.copy_(torch.from_numpy(b))

        self.query = nn.Parameter(torch.Tensor(hidden_size))
        self.query.data.normal_(mean=0.0, std=0.05)

    def forward(self, batch_hidden, batch_masks):
        # batch_hidden: b x len x hidden_size (2 * hidden_size of lstm)
        # batch_masks: b x len

        # linear
        key = torch.matmul(batch_hidden, self.weight) + self.bias  # b x len x hidden

        # compute attention
        outputs = torch.matmul(key, self.query)  # b x len

        masked_outputs = outputs.masked_fill((1 - batch_masks).bool(), float(-1e32))

        attn_scores = F.softmax(masked_outputs, dim=1)  # b x len

        # for an all-zero mask row, -1e32 gives 1/len while -inf would give nan,
        # so explicitly zero out the masked positions afterwards
        masked_attn_scores = attn_scores.masked_fill((1 - batch_masks).bool(), 0.0)

        # sum weighted sources
        batch_outputs = torch.bmm(masked_attn_scores.unsqueeze(1), key).squeeze(1)  # b x hidden

        return batch_outputs, attn_scores
# build word encoder
bert_path = '../emb/bert-mini/'
dropout = 0.15

from transformers import BertModel


class WordBertEncoder(nn.Module):
    def __init__(self):
        super(WordBertEncoder, self).__init__()
        self.dropout = nn.Dropout(dropout)

        self.tokenizer = WhitespaceTokenizer()
        self.bert = BertModel.from_pretrained(bert_path)

        self.pooled = False
        logging.info('Build Bert encoder with pooled {}.'.format(self.pooled))

    def encode(self, tokens):
        tokens = self.tokenizer.tokenize(tokens)
        return tokens

    def get_bert_parameters(self):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_parameters = [
            {'params': [p for n, p in self.bert.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in self.bert.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        return optimizer_parameters

    def forward(self, input_ids, token_type_ids):
        # input_ids: sen_num x bert_len
        # token_type_ids: sen_num x bert_len

        # sen_num x bert_len x 256, sen_num x 256
        # note: this tuple unpacking assumes an older transformers release; on transformers >= 4.x
        # pass return_dict=False or read outputs.last_hidden_state / outputs.pooler_output instead
        sequence_output, pooled_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids)

        if self.pooled:
            reps = pooled_output
        else:
            reps = sequence_output[:, 0, :]  # sen_num x 256

        if self.training:
            reps = self.dropout(reps)

        return reps
class WhitespaceTokenizer():
    """WhitespaceTokenizer with vocab."""

    def __init__(self):
        vocab_file = bert_path + 'vocab.txt'
        self._token2id = self.load_vocab(vocab_file)
        self._id2token = {v: k for k, v in self._token2id.items()}
        self.max_len = 256
        self.unk = 1

        logging.info("Build Bert vocab with size %d." % (self.vocab_size))

    def load_vocab(self, vocab_file):
        f = open(vocab_file, 'r')
        lines = f.readlines()
        lines = list(map(lambda x: x.strip(), lines))
        vocab = dict(zip(lines, range(len(lines))))
        return vocab

    def tokenize(self, tokens):
        assert len(tokens) <= self.max_len - 2
        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        output_tokens = self.token2id(tokens)
        return output_tokens

    def token2id(self, xs):
        if isinstance(xs, list):
            return [self._token2id.get(x, self.unk) for x in xs]
        return self._token2id.get(xs, self.unk)

    @property
    def vocab_size(self):
        return len(self._id2token)
# build sent encoder
sent_hidden_size = 256
sent_num_layers = 2


class SentEncoder(nn.Module):
    def __init__(self, sent_rep_size):
        super(SentEncoder, self).__init__()
        self.dropout = nn.Dropout(dropout)

        self.sent_lstm = nn.LSTM(
            input_size=sent_rep_size,
            hidden_size=sent_hidden_size,
            num_layers=sent_num_layers,
            batch_first=True,
            bidirectional=True
        )

    def forward(self, sent_reps, sent_masks):
        # sent_reps: b x doc_len x sent_rep_size
        # sent_masks: b x doc_len

        sent_hiddens, _ = self.sent_lstm(sent_reps)  # b x doc_len x hidden*2
        sent_hiddens = sent_hiddens * sent_masks.unsqueeze(2)

        if self.training:
            sent_hiddens = self.dropout(sent_hiddens)

        return sent_hiddens
# build model
class Model(nn.Module):
    def __init__(self, vocab):
        super(Model, self).__init__()
        self.sent_rep_size = 256
        self.doc_rep_size = sent_hidden_size * 2
        self.all_parameters = {}
        parameters = []

        self.word_encoder = WordBertEncoder()
        bert_parameters = self.word_encoder.get_bert_parameters()

        self.sent_encoder = SentEncoder(self.sent_rep_size)
        self.sent_attention = Attention(self.doc_rep_size)
        parameters.extend(list(filter(lambda p: p.requires_grad, self.sent_encoder.parameters())))
        parameters.extend(list(filter(lambda p: p.requires_grad, self.sent_attention.parameters())))

        self.out = nn.Linear(self.doc_rep_size, vocab.label_size, bias=True)
        parameters.extend(list(filter(lambda p: p.requires_grad, self.out.parameters())))

        if use_cuda:
            self.to(device)

        if len(parameters) > 0:
            self.all_parameters["basic_parameters"] = parameters
        self.all_parameters["bert_parameters"] = bert_parameters

        logging.info('Build model with bert word encoder, lstm sent encoder.')

        para_num = sum([np.prod(list(p.size())) for p in self.parameters()])
        logging.info('Model param num: %.2f M.' % (para_num / 1e6))

    def forward(self, batch_inputs):
        # batch_inputs(batch_inputs1, batch_inputs2): b x doc_len x sent_len
        # batch_masks : b x doc_len x sent_len
        batch_inputs1, batch_inputs2, batch_masks = batch_inputs
        batch_size, max_doc_len, max_sent_len = batch_inputs1.shape[0], batch_inputs1.shape[1], batch_inputs1.shape[2]
        batch_inputs1 = batch_inputs1.view(batch_size * max_doc_len, max_sent_len)  # sen_num x sent_len
        batch_inputs2 = batch_inputs2.view(batch_size * max_doc_len, max_sent_len)  # sen_num x sent_len
        batch_masks = batch_masks.view(batch_size * max_doc_len, max_sent_len)  # sen_num x sent_len

        sent_reps = self.word_encoder(batch_inputs1, batch_inputs2)  # sen_num x sent_rep_size

        sent_reps = sent_reps.view(batch_size, max_doc_len, self.sent_rep_size)  # b x doc_len x sent_rep_size
        batch_masks = batch_masks.view(batch_size, max_doc_len, max_sent_len)  # b x doc_len x max_sent_len
        sent_masks = batch_masks.bool().any(2).float()  # b x doc_len

        sent_hiddens = self.sent_encoder(sent_reps, sent_masks)  # b x doc_len x doc_rep_size
        doc_reps, atten_scores = self.sent_attention(sent_hiddens, sent_masks)  # b x doc_rep_size

        batch_outputs = self.out(doc_reps)  # b x num_labels

        return batch_outputs


model = Model(vocab)
# build optimizer
learning_rate = 2e-4
bert_lr = 5e-5
decay = .75
decay_step = 1000

from transformers import AdamW, get_linear_schedule_with_warmup


class Optimizer:
    def __init__(self, model_parameters, steps):
        self.all_params = []
        self.optims = []
        self.schedulers = []

        for name, parameters in model_parameters.items():
            if name.startswith("basic"):
                optim = torch.optim.Adam(parameters, lr=learning_rate)
                self.optims.append(optim)

                l = lambda step: decay ** (step // decay_step)
                scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=l)
                self.schedulers.append(scheduler)
                self.all_params.extend(parameters)
            elif name.startswith("bert"):
                optim_bert = AdamW(parameters, bert_lr, eps=1e-8)
                self.optims.append(optim_bert)

                scheduler_bert = get_linear_schedule_with_warmup(optim_bert, 0, steps)
                self.schedulers.append(scheduler_bert)

                for group in parameters:
                    for p in group['params']:
                        self.all_params.append(p)
            else:
                raise Exception("no named parameters.")

        self.num = len(self.optims)

    def step(self):
        for optim, scheduler in zip(self.optims, self.schedulers):
            optim.step()
            scheduler.step()
            optim.zero_grad()

    def zero_grad(self):
        for optim in self.optims:
            optim.zero_grad()

    def get_lr(self):
        lrs = tuple(map(lambda x: x.get_lr()[-1], self.schedulers))
        lr = ' %.5f' * self.num
        res = lr % lrs
        return res
# build dataset
def sentence_split(text, vocab, max_sent_len=256, max_segment=16):
    words = text.strip().split()
    document_len = len(words)

    index = list(range(0, document_len, max_sent_len))
    index.append(document_len)

    segments = []
    for i in range(len(index) - 1):
        segment = words[index[i]: index[i + 1]]
        assert len(segment) > 0
        segment = [word if word in vocab._id2word else '<UNK>' for word in segment]
        segments.append([len(segment), segment])

    assert len(segments) > 0
    if len(segments) > max_segment:
        segment_ = int(max_segment / 2)
        return segments[:segment_] + segments[-segment_:]
    else:
        return segments


def get_examples(data, word_encoder, vocab, max_sent_len=256, max_segment=8):
    label2id = vocab.label2id
    examples = []

    for text, label in zip(data['text'], data['label']):
        # label
        id = label2id(label)

        # words
        sents_words = sentence_split(text, vocab, max_sent_len - 2, max_segment)
        doc = []
        for sent_len, sent_words in sents_words:
            token_ids = word_encoder.encode(sent_words)
            sent_len = len(token_ids)
            token_type_ids = [0] * sent_len
            doc.append([sent_len, token_ids, token_type_ids])
        examples.append([id, len(doc), doc])

    logging.info('Total %d docs.' % len(examples))
    return examples
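Note: the Trainer further down fetches batches through a data_iter helper that does not appear in this listing. A minimal sketch of such a helper is given here, assuming the [label, doc_len, doc] example layout produced by get_examples; the batch_slice function and the noise parameter are assumptions, not part of the original post.

def batch_slice(data, batch_size):
    # cut the sorted example list into consecutive batches
    batch_num = int(np.ceil(len(data) / float(batch_size)))
    for i in range(batch_num):
        cur_batch_size = batch_size if i < batch_num - 1 else len(data) - batch_size * i
        yield [data[i * batch_size + b] for b in range(cur_batch_size)]


def data_iter(data, batch_size, shuffle=True, noise=1.0):
    # shuffle examples, sort by (noisy) document length to limit padding,
    # slice into batches, then yield the batches in random order
    if shuffle:
        np.random.shuffle(data)
        noisy_lengths = [-(example[1] + np.random.uniform(-noise, noise)) for example in data]
        sorted_indices = np.argsort(noisy_lengths).tolist()
        sorted_data = [data[i] for i in sorted_indices]
    else:
        sorted_data = data

    batched_data = list(batch_slice(sorted_data, batch_size))
    if shuffle:
        np.random.shuffle(batched_data)
    for batch in batched_data:
        yield batch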
# some function
from sklearn.metrics import f1_score, precision_score, recall_score


def get_score(y_true, y_pred):
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    f1 = f1_score(y_true, y_pred, average='macro') * 100
    p = precision_score(y_true, y_pred, average='macro') * 100
    r = recall_score(y_true, y_pred, average='macro') * 100

    return str((reformat(p, 2), reformat(r, 2), reformat(f1, 2))), reformat(f1, 2)


def reformat(num, n):
    return float(format(num, '0.' + str(n) + 'f'))
# build trainer
import time
from sklearn.metrics import classification_report

clip = 5.0
epochs = 1
early_stops = 3
log_interval = 50

test_batch_size = 16
train_batch_size = 16

save_model = './bert.bin'
save_test = './bert.csv'
class Trainer():
    def __init__(self, model, vocab):
        self.model = model
        self.report = True

        self.train_data = get_examples(train_data, model.word_encoder, vocab)
        self.batch_num = int(np.ceil(len(self.train_data) / float(train_batch_size)))
        self.dev_data = get_examples(dev_data, model.word_encoder, vocab)
        self.test_data = get_examples(test_data, model.word_encoder, vocab)

        # criterion
        self.criterion = nn.CrossEntropyLoss()

        # label name
        self.target_names = vocab.target_names

        # optimizer
        self.optimizer = Optimizer(model.all_parameters, steps=self.batch_num * epochs)

        # count
        self.step = 0
        self.early_stop = -1
        self.best_train_f1, self.best_dev_f1 = 0, 0
        self.last_epoch = epochs

    def train(self):
        logging.info('Start training...')
        for epoch in range(1, epochs + 1):
            train_f1 = self._train(epoch)

            dev_f1 = self._eval(epoch)

            if self.best_dev_f1 <= dev_f1:
                logging.info(
                    "Exceed history dev = %.2f, current dev = %.2f" % (self.best_dev_f1, dev_f1))
                torch.save(self.model.state_dict(), save_model)

                self.best_train_f1 = train_f1
                self.best_dev_f1 = dev_f1
                self.early_stop = 0
            else:
                self.early_stop += 1
                if self.early_stop == early_stops:
                    logging.info(
                        "Early stop in epoch %d, best train: %.2f, dev: %.2f" % (
                            epoch - early_stops, self.best_train_f1, self.best_dev_f1))
                    self.last_epoch = epoch
                    break

    def test(self):
        self.model.load_state_dict(torch.load(save_model))
        self._eval(self.last_epoch + 1, test=True)
    def _train(self, epoch):
        self.optimizer.zero_grad()
        self.model.train()

        start_time = time.time()
        epoch_start_time = time.time()
        overall_losses = 0
        losses = 0
        batch_idx = 1
        y_pred = []
        y_true = []
        for batch_data in data_iter(self.train_data, train_batch_size, shuffle=True):
            torch.cuda.empty_cache()
            batch_inputs, batch_labels = self.batch2tensor(batch_data)
            batch_outputs = self.model(batch_inputs)
            loss = self.criterion(batch_outputs, batch_labels)
            loss.backward()

            loss_value = loss.detach().cpu().item()
            losses += loss_value
            overall_losses += loss_value

            y_pred.extend(torch.max(batch_outputs, dim=1)[1].cpu().numpy().tolist())
            y_true.extend(batch_labels.cpu().numpy().tolist())

            nn.utils.clip_grad_norm_(self.optimizer.all_params, max_norm=clip)
            for optimizer, scheduler in zip(self.optimizer.optims, self.optimizer.schedulers):
                optimizer.step()
                scheduler.step()
            self.optimizer.zero_grad()

            self.step += 1

            if batch_idx % log_interval == 0:
                elapsed = time.time() - start_time

                lrs = self.optimizer.get_lr()
                logging.info(
                    '| epoch {:3d} | step {:3d} | batch {:3d}/{:3d} | lr{} | loss {:.4f} | s/batch {:.2f}'.format(
                        epoch, self.step, batch_idx, self.batch_num, lrs,
                        losses / log_interval,
                        elapsed / log_interval))

                losses = 0
                start_time = time.time()

            batch_idx += 1

        overall_losses /= self.batch_num
        during_time = time.time() - epoch_start_time

        # reformat
        overall_losses = reformat(overall_losses, 4)
        score, f1 = get_score(y_true, y_pred)

        logging.info(
            '| epoch {:3d} | score {} | f1 {} | loss {:.4f} | time {:.2f}'.format(epoch, score, f1,
                                                                                  overall_losses,
                                                                                  during_time))
        if set(y_true) == set(y_pred) and self.report:
            report = classification_report(y_true, y_pred, digits=4, target_names=self.target_names)
            logging.info('\n' + report)

        return f1
    def _eval(self, epoch, test=False):
        self.model.eval()
        start_time = time.time()
        data = self.test_data if test else self.dev_data
        y_pred = []
        y_true = []
        with torch.no_grad():
            for batch_data in data_iter(data, test_batch_size, shuffle=False):
                torch.cuda.empty_cache()
                batch_inputs, batch_labels = self.batch2tensor(batch_data)
                batch_outputs = self.model(batch_inputs)
                y_pred.extend(torch.max(batch_outputs, dim=1)[1].cpu().numpy().tolist())
                y_true.extend(batch_labels.cpu().numpy().tolist())

            score, f1 = get_score(y_true, y_pred)

            during_time = time.time() - start_time

            if test:
                df = pd.DataFrame({'label': y_pred})
                df.to_csv(save_test, index=False, sep=',')
            else:
                logging.info(
                    '| epoch {:3d} | dev | score {} | f1 {} | time {:.2f}'.format(epoch, score, f1,
                                                                                  during_time))
                if set(y_true) == set(y_pred) and self.report:
                    report = classification_report(y_true, y_pred, digits=4, target_names=self.target_names)
                    logging.info('\n' + report)

        return f1
    def batch2tensor(self, batch_data):
        '''
            [[label, doc_len, [[sent_len, [sent_id0, ...], [sent_id1, ...]], ...]], ...]
        '''
        batch_size = len(batch_data)
        doc_labels = []
        doc_lens = []
        doc_max_sent_len = []
        for doc_data in batch_data:
            doc_labels.append(doc_data[0])
            doc_lens.append(doc_data[1])
            sent_lens = [sent_data[0] for sent_data in doc_data[2]]
            max_sent_len = max(sent_lens)
            doc_max_sent_len.append(max_sent_len)

        max_doc_len = max(doc_lens)
        max_sent_len = max(doc_max_sent_len)

        batch_inputs1 = torch.zeros((batch_size, max_doc_len, max_sent_len), dtype=torch.int64)
        batch_inputs2 = torch.zeros((batch_size, max_doc_len, max_sent_len), dtype=torch.int64)
        batch_masks = torch.zeros((batch_size, max_doc_len, max_sent_len), dtype=torch.float32)
        batch_labels = torch.LongTensor(doc_labels)

        for b in range(batch_size):
            for sent_idx in range(doc_lens[b]):
                sent_data = batch_data[b][2][sent_idx]
                for word_idx in range(sent_data[0]):
                    batch_inputs1[b, sent_idx, word_idx] = sent_data[1][word_idx]
                    batch_inputs2[b, sent_idx, word_idx] = sent_data[2][word_idx]
                    batch_masks[b, sent_idx, word_idx] = 1

        if use_cuda:
            batch_inputs1 = batch_inputs1.to(device)
            batch_inputs2 = batch_inputs2.to(device)
            batch_masks = batch_masks.to(device)
            batch_labels = batch_labels.to(device)

        return (batch_inputs1, batch_inputs2, batch_masks), batch_labels


# train
trainer = Trainer(model, vocab)
trainer.train()

# test
trainer.test()

 

 
