Original paper: "Attention-based bidirectional long short-term memory networks for relation classification"
Most traditional approaches rely on existing lexical resources (such as WordNet), NLP systems, or hand-crafted features. This can increase computational complexity, the feature-extraction work itself costs a great deal of time and effort, and the quality of the extracted features strongly affects the experimental results.
The paper proposes the Att-BLSTM network architecture to solve relation classification end to end.
Starting from this observation, the paper proposes an attention-based bidirectional LSTM model for relation extraction. The attention mechanism automatically discovers the words that play a decisive role in classification, allowing the model to capture the most important semantic information in each sentence without relying on any external knowledge or NLP system.
By neatly adding an attention mechanism to a bidirectional LSTM for the relation extraction task, the paper avoids the complex feature engineering of traditional pipelines, greatly simplifies the experimental procedure while still achieving strong results, and provides an actionable idea for related research.
The paper's overall logic is very clear and stays tightly focused on its motivation. The idea is simple and the model is easy to grasp at a glance, yet its results are excellent.
The Att-BLSTM network is built on word embeddings; entity marker tokens are added to the input, and the Att-BLSTM structure lets the model dynamically pick out the words that matter most for relation classification.
As shown in Figure 1 of the paper, the proposed model contains five components: (1) an input layer that feeds the sentence into the model, (2) an embedding layer that maps each word to a low-dimensional vector, (3) an LSTM layer that uses a bidirectional LSTM to extract high-level features, (4) an attention layer that produces a weight vector and merges the word-level features into a sentence-level vector, and (5) an output layer that uses the sentence vector for relation classification.
How attention works: an attention mechanism helps the model assign a different weight to each part of the input X and extract the more critical and important information, so that the model can make more accurate judgments without adding significant computation or memory overhead.
Attention mechanisms can be further categorized by the region of the input over which the attention weights are computed.
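The specific attention used in the paper is a soft attention over the BLSTM outputs of one sentence. As a sketch of the formulation in the original paper (H denotes the matrix of BLSTM output vectors, w is a learned parameter vector, α the attention weights, and h* the sentence representation fed to the softmax classifier):

M = \tanh(H)
\alpha = \operatorname{softmax}(w^{\top} M)
r = H \alpha^{\top}
h^{*} = \tanh(r)

In the model code shown later, att_weight plays the role of w and attention_layer implements exactly these steps, with padded positions masked out before the softmax.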
Adds specific marker symbols before and after each entity to indicate its position.
Uses a regularized loss function with constraints.
Compares various model configurations on the SemEval-2010 Task 8 dataset.
Does not rely on any other NLP tools.
Introduces the Attention-BiLSTM structure.
The network structure does not rely on any NLP tools or lexical resources; it only needs raw text with position indicators as input.
train_file.txt (samples 1-8000)
1 "The system as described above has its greatest application in an arrayed <e1>configuration</e1> of antenna <e2>elements</e2>." Component-Whole(e2,e1) Comment: Not a collection: there is structure here, organisation. 2 "The <e1>child</e1> was carefully wrapped and bound into the <e2>cradle</e2> by means of a cord." Other Comment: 3 "The <e1>author</e1> of a keygen uses a <e2>disassembler</e2> to look at the raw assembly code." Instrument-Agency(e2,e1) Comment: 4 "A misty <e1>ridge</e1> uprises from the <e2>surge</e2>." Other Comment: 5 "The <e1>student</e1> <e2>association</e2> is the voice of the undergraduate student population of the State University of New York at Buffalo." Member-Collection(e1,e2) Comment: 6 "This is the sprawling <e1>complex</e1> that is Peru's largest <e2>producer</e2> of silver." Other Comment: 7 "The current view is that the chronic <e1>inflammation</e1> in the distal part of the stomach caused by Helicobacter pylori <e2>infection</e2> results in an increased acid production from the non-infected upper corpus region of the stomach." Cause-Effect(e2,e1) Comment: 8 "<e1>People</e1> have been moving back into <e2>downtown</e2>." Entity-Destination(e1,e2) Comment: 9 "The <e1>lawsonite</e1> was contained in a <e2>platinum crucible</e2> and the counter-weight was a plastic crucible with metal pieces." Content-Container(e1,e2) Comment: prototypical example 10 "The solute was placed inside a beaker and 5 mL of the <e1>solvent</e1> was pipetted into a 25 mL glass <e2>flask</e2> for each trial." Entity-Destination(e1,e2) Comment: ......
test_file.txt (samples 8001-10717)
8001 "The most common <e1>audits</e1> were about <e2>waste</e2> and recycling." Message-Topic(e1,e2) Comment: Assuming an audit = an audit document. 8002 "The <e1>company</e1> fabricates plastic <e2>chairs</e2>." Product-Producer(e2,e1) Comment: (a) is satisfied 8003 "The school <e1>master</e1> teaches the lesson with a <e2>stick</e2>." Instrument-Agency(e2,e1) Comment: 8004 "The suspect dumped the dead <e1>body</e1> into a local <e2>reservoir</e2>." Entity-Destination(e1,e2) Comment: 8005 "Avian <e1>influenza</e1> is an infectious disease of birds caused by type A strains of the influenza <e2>virus</e2>." Cause-Effect(e2,e1) Comment: 8006 "The <e1>ear</e1> of the African <e2>elephant</e2> is significantly larger--measuring 183 cm by 114 cm in the bush elephant." Component-Whole(e1,e2) Comment: 8007 "A child is told a <e1>lie</e1> for several years by their <e2>parents</e2> before he/she realizes that a Santa Claus does not exist." Product-Producer(e1,e2) Comment: (a) is satisfied; negation is outside 8008 "Skype, a free software, allows a <e1>hookup</e1> of multiple computer <e2>users</e2> to join in an online conference call without incurring any telephone costs." Member-Collection(e2,e1) Comment: 8009 "The disgusting scene was retaliation against her brother Philip who rents the <e1>room</e1> inside this apartment <e2>house</e2> on Lombard street." Component-Whole(e1,e2) Comment: 8010 "This <e1>thesis</e1> defines the <e2>clinical characteristics</e2> of amyloid disease." Message-Topic(e1,e2) Comment: may be we could leave clinical out of e2.
preprocess.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import json
import re
from nltk.tokenize import word_tokenize


def search_entity(sentence):
    e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]
    e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]
    sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)
    sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)
    sentence = word_tokenize(sentence)
    sentence = ' '.join(sentence)
    sentence = sentence.replace('< e1 >', '<e1>')
    sentence = sentence.replace('< e2 >', '<e2>')
    sentence = sentence.replace('< /e1 >', '</e1>')
    sentence = sentence.replace('< /e2 >', '</e2>')
    sentence = sentence.split()

    assert '<e1>' in sentence
    assert '<e2>' in sentence
    assert '</e1>' in sentence
    assert '</e2>' in sentence
    return sentence


def convert(path_src, path_des):
    with open(path_src, 'r', encoding='utf-8') as fr:
        data = fr.readlines()

    with open(path_des, 'w', encoding='utf-8') as fw:
        for i in range(0, len(data), 4):
            id_s, sentence = data[i].strip().split('\t')
            sentence = sentence[1:-1]  # strip the surrounding quotation marks
            sentence = search_entity(sentence)
            meta = dict(
                id=id_s,
                relation=data[i+1].strip(),
                sentence=sentence,
                comment=data[i+2].strip()[8:]  # drop the leading 'Comment:'
            )
            json.dump(meta, fw, ensure_ascii=False)
            fw.write('\n')


if __name__ == '__main__':
    path_train = './SemEval2010_task8_all_data/SemEval2010_task8_training/TRAIN_FILE.TXT'
    path_test = './SemEval2010_task8_all_data/SemEval2010_task8_testing_keys/TEST_FILE_FULL.TXT'
    convert(path_train, 'train.json')
    convert(path_test, 'test.json')
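Running python preprocess.py converts the raw SemEval files into one JSON object per line: each sentence is tokenized with NLTK while the <e1>, </e1>, <e2>, </e2> markers are kept as standalone tokens, and the relation label and comment are carried along, producing the train.json and test.json shown below.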
train.json
{"id": "1", "relation": "Component-Whole(e2,e1)", "sentence": ["The", "system", "as", "described", "above", "has", "its", "greatest", "application", "in", "an", "arrayed", "<e1>", "configuration", "</e1>", "of", "antenna", "<e2>", "elements", "</e2>", "."], "comment": " Not a collection: there is structure here, organisation."}
{"id": "2", "relation": "Other", "sentence": ["The", "<e1>", "child", "</e1>", "was", "carefully", "wrapped", "and", "bound", "into", "the", "<e2>", "cradle", "</e2>", "by", "means", "of", "a", "cord", "."], "comment": ""}
{"id": "3", "relation": "Instrument-Agency(e2,e1)", "sentence": ["The", "<e1>", "author", "</e1>", "of", "a", "keygen", "uses", "a", "<e2>", "disassembler", "</e2>", "to", "look", "at", "the", "raw", "assembly", "code", "."], "comment": ""}
{"id": "4", "relation": "Other", "sentence": ["A", "misty", "<e1>", "ridge", "</e1>", "uprises", "from", "the", "<e2>", "surge", "</e2>", "."], "comment": ""}
{"id": "5", "relation": "Member-Collection(e1,e2)", "sentence": ["The", "<e1>", "student", "</e1>", "<e2>", "association", "</e2>", "is", "the", "voice", "of", "the", "undergraduate", "student", "population", "of", "the", "State", "University", "of", "New", "York", "at", "Buffalo", "."], "comment": ""}
......
test.json
{"id": "8001", "relation": "Message-Topic(e1,e2)", "sentence": ["The", "most", "common", "<e1>", "audits", "</e1>", "were", "about", "<e2>", "waste", "</e2>", "and", "recycling", "."], "comment": " Assuming an audit = an audit document."}
{"id": "8002", "relation": "Product-Producer(e2,e1)", "sentence": ["The", "<e1>", "company", "</e1>", "fabricates", "plastic", "<e2>", "chairs", "</e2>", "."], "comment": " (a) is satisfied"}
{"id": "8003", "relation": "Instrument-Agency(e2,e1)", "sentence": ["The", "school", "<e1>", "master", "</e1>", "teaches", "the", "lesson", "with", "a", "<e2>", "stick", "</e2>", "."], "comment": ""}
{"id": "8004", "relation": "Entity-Destination(e1,e2)", "sentence": ["The", "suspect", "dumped", "the", "dead", "<e1>", "body", "</e1>", "into", "a", "local", "<e2>", "reservoir", "</e2>", "."], "comment": ""}
{"id": "8005", "relation": "Cause-Effect(e2,e1)", "sentence": ["Avian", "<e1>", "influenza", "</e1>", "is", "an", "infectious", "disease", "of", "birds", "caused", "by", "type", "A", "strains", "of", "the", "influenza", "<e2>", "virus", "</e2>", "."], "comment": ""}
......
relation2id.txt
Other 0
Cause-Effect(e1,e2) 1
Cause-Effect(e2,e1) 2
Component-Whole(e1,e2) 3
Component-Whole(e2,e1) 4
Content-Container(e1,e2) 5
Content-Container(e2,e1) 6
Entity-Destination(e1,e2) 7
Entity-Destination(e2,e1) 8
Entity-Origin(e1,e2) 9
Entity-Origin(e2,e1) 10
Instrument-Agency(e1,e2) 11
Instrument-Agency(e2,e1) 12
Member-Collection(e1,e2) 13
Member-Collection(e2,e1) 14
Message-Topic(e1,e2) 15
Message-Topic(e2,e1) 16
Product-Producer(e1,e2) 17
Product-Producer(e2,e1) 18
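The 19 labels correspond to the 9 directed relation types of SemEval-2010 Task 8, each annotated in both argument orders, plus the undirected Other class.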
hlbl-embeddings-scaled.EMBEDDING_SIZE=50
*UNKNOWN* -0.166038776479 0.104395984608 0.163119732357 0.0899594154863 -0.0192271099805 -0.0417631572501 -0.0163376687927 0.0357616216019 0.0536077591673 0.0127688536503 -0.00284508433021 -0.0626207031228 -0.0379452734015 -0.103548297666 0.0381169119981 0.00199421074321 -0.0474636488659 -0.0127526851513 0.016404178535 -0.12759853361 -0.0292937037717 -0.0512566352549 0.0233097445983 0.0360505083995 0.00229317984472 -0.0771565284227 0.0071461584378 -0.051608090196 -0.0267547654304 0.0492994451068 -0.0531630844999 0.00787191810391 0.082280106873 0.066908641868 -0.0283930612982 0.216840166248 0.164923151267 0.00188498983723 0.0328679039324 -0.00175432516758 0.0614261774935 0.0987773071377 0.0548423375506 -0.0307057922059 0.053074241476 0.04982054279 -0.0572485864016 0.132236444766 -0.0379717035014 -0.120915939814
the -0.0841015569168 0.145263825738 0.116945121935 -0.0754618634155 0.17901499611 -0.000652852605208 -0.0713783879233 0.207273704502 0.060711721477 0.0366727701165 -0.0269791566731 -0.156993473526 -0.0393947453024 0.00749161628231 -0.332851634057 -0.1708430781 -0.275163605231 -0.266592614101 0.43349041466 -0.00779248211778 0.031101796379 -0.0257114150838 0.174856713352 -0.0543054233622 -0.0846669459476 -0.006234398456 0.00414488584462 0.119738648443 -0.0914876936952 -0.317381121871 -0.27471439742 0.234269597998 0.170305945138 -0.0282815073325 -0.10127814458 0.156451476203 0.154703520781 -0.0014827085612 0.164287521114 0.0328582913203 0.0356570354049 -0.190254406793 -0.112029936115 -0.198875312619 0.00102875631152 -0.00161517169984 -0.125210890327 0.196903181061 -0.112017915766 -0.00838804375065
. -0.0875932389444 -0.0586365253633 0.0729727126603 0.32072000431 0.0745620569276 -0.0494709138174 0.208708067552 -0.025035364294 -0.197531050237 0.177318202028 0.297077745222 -0.0256369072571 0.182364658364 0.189089099105 0.0589179494006 -0.0627276310572 0.0682898379459 0.241161712515 0.253510796291 -0.0325139691451 -0.0129081882483 -0.083367340352 0.0276167362372 -0.00757124183183 -0.0905801885623 0.305015208385 0.0755474920504 -0.00516459185438 -0.0412876867803 0.105047372601 -0.718674456034 0.184682477295 0.232732814491 0.0929975692214 0.0999329447708 -0.0968008990987 0.421525505372 -0.136460066398 -0.323294448817 0.118318915141 0.415411774103 -0.135770867168 0.0404792691614 0.264279769529 -0.133076243622 0.195087919022 -0.087589323012 0.0335223022065 -0.0365650611956 -0.0163760300203
, -0.023019838485 0.277215570968 0.241932261453 -0.105403438907 0.247316949736 0.0859618436243 -0.0130132156599 0.123988163629 -0.150741462418 0.129993766762 0.0766431623839 0.0547135456598 0.187342182554 0.176303102861 -0.121401723217 0.0458278230666 0.0339804870854 -0.0619606057248 0.0514787739809 0.00732501266557 0.0879996990484 -0.369288823679 0.235222707122 -0.0528783055204 0.0121891472663 -0.165169815904 -0.136829953355 -0.0750751223049 -0.0503433833321 0.0782539868365 -0.400940778018 -0.099745222007 -0.152448498545 -0.0815002789835 -0.010575616616 0.331604536668 -0.0124179474775 0.00173559407939 -0.230971231526 0.0162523457081 0.213848645598 0.184698023693 0.158368229826 0.0975422545404 -0.0307127563081 0.093420146492 -0.0377856184872 -0.0181716170654 0.43322993915 -0.113289957059
to 0.134693667961 0.392203653086 0.0346151199225 0.135354475458 0.0719918082372 0.118667933013 -0.0698386234679 -0.0139927084407 0.144452931939 0.0383223273458 -0.0491954394553 -0.126435975874 0.23979196724 -0.186550477314 0.0602616605691 -0.0875395769807 0.0788848675161 0.132691898026 0.155618778336 0.00680378469567 -0.126513561203 -0.436124771467 0.132675129426 -0.0946286638801 0.0986847070674 -0.354397304845 -0.196909463175 -0.0911408611189 0.134975690877 0.0625931974859 0.0108112360985 -0.107933544401 -0.166545488854 0.0137397678012 -0.0268394211932 -0.260328038765 0.0745185746772 0.020864049205 0.133485534344 -0.0479098207297 0.145382061477 -0.116284346216 0.0822848147919 -0.00621959258902 0.0135679910959 -0.0723116375013 -0.422793539068 0.144456402991 -0.119019192402 0.0659297394103
......
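A note on the embedding file: the snippet above is from the 50-dimensional HLBL embeddings, whereas the configuration below defaults to ./embedding/glove.6B.100d.txt with word_dim=100. Whichever file is used, --embedding_path and --word_dim must be consistent, because the loader in utils.py silently skips any line whose token count is not word_dim + 1.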
config.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import argparse
import torch
import os
import random
import json
import numpy as np


class Config(object):
    def __init__(self):
        # get init config
        args = self.__get_config()
        for key in args.__dict__:
            setattr(self, key, args.__dict__[key])

        # select device
        self.device = None
        if self.cuda >= 0 and torch.cuda.is_available():
            self.device = torch.device('cuda:{}'.format(self.cuda))
        else:
            self.device = torch.device('cpu')

        # determine the model name and model dir
        if self.model_name is None:
            self.model_name = 'Att_BLSTM'
        self.model_dir = os.path.join(self.output_dir, self.model_name)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        # backup data
        self.__config_backup(args)

        # set the random seed
        self.__set_seed(self.seed)

    def __get_config(self):
        parser = argparse.ArgumentParser()
        parser.description = 'config for models'

        # several key selective parameters
        parser.add_argument('--data_dir', type=str, default='./data', help='dir to load data')
        parser.add_argument('--output_dir', type=str, default='./output', help='dir to save output')

        # word embedding
        parser.add_argument('--embedding_path', type=str, default='./embedding/glove.6B.100d.txt',
                            help='pre_trained word embedding')
        parser.add_argument('--word_dim', type=int, default=100, help='dimension of word embedding')

        # train settings
        parser.add_argument('--model_name', type=str, default=None, help='model name')
        parser.add_argument('--mode', type=int, default=1, choices=[0, 1],
                            help='running mode: 1 for training; otherwise testing')
        parser.add_argument('--seed', type=int, default=5782, help='random seed')
        parser.add_argument('--cuda', type=int, default=0, help='num of gpu device, if -1, select cpu')
        parser.add_argument('--epoch', type=int, default=30, help='max epochs during training')

        # hyper parameters
        parser.add_argument('--batch_size', type=int, default=10, help='batch size')
        parser.add_argument('--lr', type=float, default=1.0, help='learning rate')
        parser.add_argument('--max_len', type=int, default=100, help='max length of sentence')
        parser.add_argument('--emb_dropout', type=float, default=0.3, help='the probability of dropout in embedding layer')
        parser.add_argument('--lstm_dropout', type=float, default=0.3, help='the probability of dropout in (Bi)LSTM layer')
        parser.add_argument('--linear_dropout', type=float, default=0.5, help='the probability of dropout in linear layer')
        parser.add_argument('--hidden_size', type=int, default=100, help='the dimension of hidden units in (Bi)LSTM layer')
        parser.add_argument('--layers_num', type=int, default=1, help='num of RNN layers')
        parser.add_argument('--L2_decay', type=float, default=1e-5, help='L2 weight decay')

        args = parser.parse_args()
        return args

    def __set_seed(self, seed=1234):
        os.environ['PYTHONHASHSEED'] = '{}'.format(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)  # set seed for cpu
        torch.cuda.manual_seed(seed)  # set seed for current gpu
        torch.cuda.manual_seed_all(seed)  # set seed for all gpu

    def __config_backup(self, args):
        config_backup_path = os.path.join(self.model_dir, 'config.json')
        with open(config_backup_path, 'w', encoding='utf-8') as fw:
            json.dump(vars(args), fw, ensure_ascii=False)

    def print_config(self):
        for key in self.__dict__:
            print(key, end=' = ')
            print(self.__dict__[key])


if __name__ == '__main__':
    config = Config()
    config.print_config()
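These defaults appear to mirror the hyper-parameter choices reported in the paper: dropout of 0.3 on the embedding layer, 0.3 after the BLSTM, and 0.5 before the classifier, a hidden size of 100, AdaDelta as the optimizer with learning rate 1.0, and an L2 regularization strength of 1e-5. Batch size, maximum sentence length, and the embedding path can all be overridden on the command line.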
model.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Att_BLSTM(nn.Module):
    def __init__(self, word_vec, class_num, config):
        super().__init__()
        self.word_vec = word_vec
        self.class_num = class_num

        # hyper parameters and others
        self.max_len = config.max_len
        self.word_dim = config.word_dim
        self.hidden_size = config.hidden_size
        self.layers_num = config.layers_num
        self.emb_dropout_value = config.emb_dropout
        self.lstm_dropout_value = config.lstm_dropout
        self.linear_dropout_value = config.linear_dropout

        # net structures and operations
        self.word_embedding = nn.Embedding.from_pretrained(
            embeddings=self.word_vec,
            freeze=False,
        )
        self.lstm = nn.LSTM(
            input_size=self.word_dim,
            hidden_size=self.hidden_size,
            num_layers=self.layers_num,
            bias=True,
            batch_first=True,
            dropout=0,
            bidirectional=True,
        )
        self.tanh = nn.Tanh()
        self.emb_dropout = nn.Dropout(self.emb_dropout_value)
        self.lstm_dropout = nn.Dropout(self.lstm_dropout_value)
        self.linear_dropout = nn.Dropout(self.linear_dropout_value)

        self.att_weight = nn.Parameter(torch.randn(1, self.hidden_size, 1))
        self.dense = nn.Linear(
            in_features=self.hidden_size,
            out_features=self.class_num,
            bias=True
        )

        # initialize weight
        init.xavier_normal_(self.dense.weight)
        init.constant_(self.dense.bias, 0.)

    def lstm_layer(self, x, mask):
        lengths = torch.sum(mask.gt(0), dim=-1)
        x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        h, (_, _) = self.lstm(x)
        h, _ = pad_packed_sequence(h, batch_first=True, padding_value=0.0, total_length=self.max_len)
        h = h.view(-1, self.max_len, 2, self.hidden_size)
        h = torch.sum(h, dim=2)  # B*L*H
        return h

    def attention_layer(self, h, mask):
        att_weight = self.att_weight.expand(mask.shape[0], -1, -1)  # B*H*1
        att_score = torch.bmm(self.tanh(h), att_weight)  # B*L*H * B*H*1 -> B*L*1

        # mask, remove the effect of 'PAD'
        mask = mask.unsqueeze(dim=-1)  # B*L*1
        att_score = att_score.masked_fill(mask.eq(0), float('-inf'))  # B*L*1
        att_weight = F.softmax(att_score, dim=1)  # B*L*1

        reps = torch.bmm(h.transpose(1, 2), att_weight).squeeze(dim=-1)  # B*H*L * B*L*1 -> B*H*1 -> B*H
        reps = self.tanh(reps)  # B*H
        return reps

    def forward(self, data):
        token = data[:, 0, :].view(-1, self.max_len)
        mask = data[:, 1, :].view(-1, self.max_len)
        emb = self.word_embedding(token)  # B*L*word_dim
        emb = self.emb_dropout(emb)
        h = self.lstm_layer(emb, mask)  # B*L*H
        h = self.lstm_dropout(h)
        reps = self.attention_layer(h, mask)  # B*H
        reps = self.linear_dropout(reps)
        logits = self.dense(reps)
        return logits
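One implementation detail worth noting: lstm_layer reshapes the BLSTM output into its forward and backward halves and combines them by element-wise summation (torch.sum(h, dim=2)), which follows the paper's choice of adding the two directions rather than concatenating them, so the attention layer and classifier operate on hidden_size-dimensional vectors. The next block is the training and testing entry script.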
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import os
import torch
import torch.nn as nn
import torch.optim as optim

from config import Config
from utils import WordEmbeddingLoader, RelationLoader, SemEvalDataLoader
from model import Att_BLSTM
from evaluate import Eval


def print_result(predict_label, id2rel, start_idx=8001):
    with open('predicted_result.txt', 'w', encoding='utf-8') as fw:
        for i in range(0, predict_label.shape[0]):
            fw.write('{}\t{}\n'.format(start_idx+i, id2rel[int(predict_label[i])]))


def train(model, criterion, loader, config):
    train_loader, dev_loader, _ = loader
    optimizer = optim.Adadelta(model.parameters(), lr=config.lr, weight_decay=config.L2_decay)

    print(model)
    print('training model parameters:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            print('%s : %s' % (name, str(param.data.shape)))
    print('--------------------------------------')
    print('start to train the model ...')

    eval_tool = Eval(config)
    min_f1 = -float('inf')
    for epoch in range(1, config.epoch+1):
        for step, (data, label) in enumerate(train_loader):
            model.train()
            data = data.to(config.device)
            label = label.to(config.device)

            optimizer.zero_grad()
            logits = model(data)
            loss = criterion(logits, label)
            loss.backward()
            nn.utils.clip_grad_value_(model.parameters(), clip_value=5)
            optimizer.step()

        _, train_loss, _ = eval_tool.evaluate(model, criterion, train_loader)
        f1, dev_loss, _ = eval_tool.evaluate(model, criterion, dev_loader)

        print('[%03d] train_loss: %.3f | dev_loss: %.3f | macro f1 on dev: %.4f'
              % (epoch, train_loss, dev_loss, f1), end=' ')
        if f1 > min_f1:
            min_f1 = f1
            torch.save(model.state_dict(), os.path.join(config.model_dir, 'model.pkl'))
            print('>>> save models!')
        else:
            print()


def test(model, criterion, loader, config):
    print('--------------------------------------')
    print('start test ...')
    _, _, test_loader = loader
    model.load_state_dict(torch.load(os.path.join(config.model_dir, 'model.pkl')))
    eval_tool = Eval(config)
    f1, test_loss, predict_label = eval_tool.evaluate(model, criterion, test_loader)
    print('test_loss: %.3f | macro f1 on test: %.4f' % (test_loss, f1))
    return predict_label


if __name__ == '__main__':
    config = Config()
    print('--------------------------------------')
    print('some config:')
    config.print_config()

    print('--------------------------------------')
    print('start to load data ...')
    word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
    rel2id, id2rel, class_num = RelationLoader(config).get_relation()
    loader = SemEvalDataLoader(rel2id, word2id, config)

    train_loader, dev_loader = None, None
    if config.mode == 1:  # train mode
        train_loader = loader.get_train()
        dev_loader = loader.get_dev()
    test_loader = loader.get_test()
    loader = [train_loader, dev_loader, test_loader]
    print('finish!')

    print('--------------------------------------')
    model = Att_BLSTM(word_vec=word_vec, class_num=class_num, config=config)
    model = model.to(config.device)
    criterion = nn.CrossEntropyLoss()

    if config.mode == 1:  # train mode
        train(model, criterion, loader, config)
    predict_label = test(model, criterion, loader, config)
    print_result(predict_label, id2rel)
evaluate.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import numpy as np
import torch


def semeval_scorer(predict_label, true_label, class_num=10):
    import math
    assert true_label.shape[0] == predict_label.shape[0]
    confusion_matrix = np.zeros(shape=[class_num, class_num], dtype=np.float32)
    xDIRx = np.zeros(shape=[class_num], dtype=np.float32)
    for i in range(true_label.shape[0]):
        true_idx = math.ceil(true_label[i]/2)
        predict_idx = math.ceil(predict_label[i]/2)
        if true_label[i] == predict_label[i]:
            confusion_matrix[predict_idx][true_idx] += 1
        else:
            if true_idx == predict_idx:
                xDIRx[predict_idx] += 1
            else:
                confusion_matrix[predict_idx][true_idx] += 1

    col_sum = np.sum(confusion_matrix, axis=0).reshape(-1)
    row_sum = np.sum(confusion_matrix, axis=1).reshape(-1)
    f1 = np.zeros(shape=[class_num], dtype=np.float32)

    for i in range(0, class_num):  # ignore the 'Other'
        try:
            p = float(confusion_matrix[i][i]) / float(col_sum[i] + xDIRx[i])
            r = float(confusion_matrix[i][i]) / float(row_sum[i] + xDIRx[i])
            f1[i] = (2 * p * r / (p + r))
        except:
            pass
    actual_class = 0
    total_f1 = 0.0
    for i in range(1, class_num):
        if f1[i] > 0.0:  # classes that not in the predict label are not considered
            actual_class += 1
            total_f1 += f1[i]
    try:
        macro_f1 = total_f1 / actual_class
    except:
        macro_f1 = 0.0
    return macro_f1


class Eval(object):
    def __init__(self, config):
        self.device = config.device

    def evaluate(self, model, criterion, data_loader):
        predict_label = []
        true_label = []
        total_loss = 0.0
        with torch.no_grad():
            model.eval()
            for _, (data, label) in enumerate(data_loader):
                data = data.to(self.device)
                label = label.to(self.device)

                logits = model(data)
                loss = criterion(logits, label)
                total_loss += loss.item() * logits.shape[0]

                _, pred = torch.max(logits, dim=1)  # replace softmax with max function, same impacts
                pred = pred.cpu().detach().numpy().reshape((-1, 1))
                label = label.cpu().detach().numpy().reshape((-1, 1))
                predict_label.append(pred)
                true_label.append(label)
        predict_label = np.concatenate(predict_label, axis=0).reshape(-1).astype(np.int64)
        true_label = np.concatenate(true_label, axis=0).reshape(-1).astype(np.int64)
        eval_loss = total_loss / predict_label.shape[0]

        f1 = semeval_scorer(predict_label, true_label)
        return f1, eval_loss, predict_label
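A few words on the scorer: semeval_scorer folds the 19 directed labels into 10 coarse classes via math.ceil(label/2); predictions with the correct relation but the wrong direction are accumulated in xDIRx and count against both precision and recall, and the returned value is the macro-averaged F1 over the 9 relation classes excluding Other, which approximates the official SemEval-2010 Task 8 evaluation metric.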
utils.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6
import os
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


class WordEmbeddingLoader(object):
    """
    A loader for pre-trained word embedding
    """
    def __init__(self, config):
        self.path_word = config.embedding_path  # path of pre-trained word embedding
        self.word_dim = config.word_dim  # dimension of word embedding

    def load_embedding(self):
        word2id = dict()  # word to wordID
        word_vec = list()  # wordID to word embedding

        word2id['PAD'] = len(word2id)  # PAD character
        word2id['UNK'] = len(word2id)  # out of vocabulary
        word2id['<e1>'] = len(word2id)
        word2id['<e2>'] = len(word2id)
        word2id['</e1>'] = len(word2id)
        word2id['</e2>'] = len(word2id)

        with open(self.path_word, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = line.strip().split()
                if len(line) != self.word_dim + 1:
                    continue
                word2id[line[0]] = len(word2id)
                word_vec.append(np.asarray(line[1:], dtype=np.float32))

        word_vec = np.stack(word_vec)
        vec_mean, vec_std = word_vec.mean(), word_vec.std()
        special_emb = np.random.normal(vec_mean, vec_std, (6, self.word_dim))
        special_emb[0] = 0  # <pad> is initialize as zero

        word_vec = np.concatenate((special_emb, word_vec), axis=0)
        word_vec = word_vec.astype(np.float32).reshape(-1, self.word_dim)
        word_vec = torch.from_numpy(word_vec)
        return word2id, word_vec


class RelationLoader(object):
    def __init__(self, config):
        self.data_dir = config.data_dir

    def __load_relation(self):
        relation_file = os.path.join(self.data_dir, 'relation2id.txt')
        rel2id = {}
        id2rel = {}
        with open(relation_file, 'r', encoding='utf-8') as fr:
            for line in fr:
                relation, id_s = line.strip().split()
                id_d = int(id_s)
                rel2id[relation] = id_d
                id2rel[id_d] = relation
        return rel2id, id2rel, len(rel2id)

    def get_relation(self):
        return self.__load_relation()


class SemEvalDateset(Dataset):
    def __init__(self, filename, rel2id, word2id, config):
        self.filename = filename
        self.rel2id = rel2id
        self.word2id = word2id
        self.max_len = config.max_len
        self.data_dir = config.data_dir
        self.dataset, self.label = self.__load_data()

    def __symbolize_sentence(self, sentence):
        """
        Args:
            sentence (list)
        """
        mask = [1] * len(sentence)
        words = []
        length = min(self.max_len, len(sentence))
        mask = mask[:length]
        for i in range(length):
            words.append(self.word2id.get(sentence[i].lower(), self.word2id['UNK']))

        if length < self.max_len:
            for i in range(length, self.max_len):
                mask.append(0)  # 'PAD' mask is zero
                words.append(self.word2id['PAD'])

        unit = np.asarray([words, mask], dtype=np.int64)
        unit = np.reshape(unit, newshape=(1, 2, self.max_len))
        return unit

    def __load_data(self):
        path_data_file = os.path.join(self.data_dir, self.filename)
        data = []
        labels = []
        with open(path_data_file, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = json.loads(line.strip())
                label = line['relation']
                sentence = line['sentence']
                label_idx = self.rel2id[label]

                one_sentence = self.__symbolize_sentence(sentence)
                data.append(one_sentence)
                labels.append(label_idx)
        return data, labels

    def __getitem__(self, index):
        data = self.dataset[index]
        label = self.label[index]
        return data, label

    def __len__(self):
        return len(self.label)


class SemEvalDataLoader(object):
    def __init__(self, rel2id, word2id, config):
        self.rel2id = rel2id
        self.word2id = word2id
        self.config = config

    def __collate_fn(self, batch):
        data, label = zip(*batch)  # unzip the batch data
        data = list(data)
        label = list(label)
        data = torch.from_numpy(np.concatenate(data, axis=0))
        label = torch.from_numpy(np.asarray(label, dtype=np.int64))
        return data, label

    def __get_data(self, filename, shuffle=False):
        dataset = SemEvalDateset(filename, self.rel2id, self.word2id, self.config)
        loader = DataLoader(
            dataset=dataset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=2,
            collate_fn=self.__collate_fn
        )
        return loader

    def get_train(self):
        return self.__get_data('train.json', shuffle=True)

    def get_dev(self):
        return self.__get_data('test.json', shuffle=False)

    def get_test(self):
        return self.__get_data('test.json', shuffle=False)


if __name__ == '__main__':
    from config import Config
    config = Config()
    word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
    rel2id, id2rel, class_num = RelationLoader(config).get_relation()
    loader = SemEvalDataLoader(rel2id, word2id, config)
    test_loader = loader.get_train()
    for step, (data, label) in enumerate(test_loader):
        print(type(data), data.shape)
        print(type(label), label.shape)
        break
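Putting it all together (the file name of the training entry script is not shown above; assume it is saved as, e.g., run.py): first run preprocess.py and place the resulting train.json and test.json, together with relation2id.txt, under the configured --data_dir, and put the pre-trained embeddings at --embedding_path; then python run.py --mode 1 trains the model, saves the best checkpoint (by dev F1) as model.pkl under ./output/Att_BLSTM, evaluates it on the test set, and writes the predictions to predicted_result.txt, while --mode 0 only loads the saved checkpoint and evaluates.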