
NLP - Information Extraction - Relation Extraction - 2016: Attention-BiLSTM Relation Classifier [Relation Classification Based on Bidirectional LSTM and Attention] [Dataset: SemEval-2010 Task 8]

Original paper: Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification


I. Overview

1. Motivation

Most traditional approaches rely on existing lexical resources (e.g., WordNet), NLP systems, or hand-crafted features. This can increase computational complexity, the feature engineering itself costs a great deal of time and effort, and the quality of the extracted features strongly affects the final results.

To address this, the paper proposes the Att-BLSTM network structure for end-to-end relation classification.

Starting from this point, the paper proposes a bidirectional LSTM model with an attention mechanism for relation extraction. The attention mechanism automatically finds the words that matter most for classification, so the model can capture the most important semantic information in each sentence without depending on any external knowledge or NLP systems.

2. Significance

The paper neatly adds an attention mechanism on top of a bidirectional LSTM for relation extraction, avoiding the complex feature engineering of traditional approaches. This greatly simplifies the pipeline while still achieving strong results, and it offers a practical recipe for related research.

The paper is clearly organized around its motivation; the overall idea is simple and the model is easy to follow, yet the results are strong.

3. Abstract highlights

  1. Existing relation classification methods depend on NLP tools to extract features;
  2. The paper proposes Att-BLSTM, a relation classification method that needs no complex preprocessing;
  3. Experiments show the method is effective and achieves state-of-the-art results.

II. Attention-BiLSTM Model Architecture

1. Model structure

[Figure 1: Att-BLSTM model architecture]
The Att-BLSTM network is built on top of word embeddings, with position indicator tokens marking the entities; the attention-over-BLSTM structure lets the model dynamically focus on the words that matter for relation classification.
As shown in Figure 1, the model proposed in this paper contains five components:

  1. Input layer: input the sentence to the model;
  2. Embedding layer: map each word into a low-dimensional vector;
  3. LSTM layer: use a BLSTM to obtain high-level features from step (2);
  4. Attention layer: produce a weight vector and merge the word-level features from each time step into a sentence-level feature vector by multiplying by the weight vector (the corresponding equations are sketched after this list);
  5. Output layer: the sentence-level feature vector is finally used for relation classification.
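Following the notation in the paper, with H the matrix of BLSTM output vectors, w a trained parameter vector, and h* the sentence representation fed to the softmax classifier, the attention layer computes:

$$
M = \tanh(H), \qquad \alpha = \mathrm{softmax}(w^{\top} M), \qquad r = H\alpha^{\top}, \qquad h^{*} = \tanh(r)
$$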

2. How attention works

Attention in a nutshell: an attention mechanism lets the model assign a different weight to each part of the input X, so it can pick out the most relevant information and make a more accurate decision without adding much computational or memory cost.
Depending on the region over which attention is computed, it can be divided into the following types (a minimal PyTorch sketch of soft attention follows the list):

  1. Soft attention / global attention: the most common form. A weight is computed over all keys, and every key receives a probability, so it is a global computation (hence also called global attention).
  2. Hard attention: directly picks a single key and ignores the rest, i.e., that key gets probability 1 and every other key gets probability 0. This requires the alignment to be exactly right in one step, and a wrong alignment is very costly. Because it is non-differentiable, it is usually trained with reinforcement learning.
  3. Local attention: a compromise between the two. A hard step first locates a position, a window is taken around that point, and soft attention is then computed within the window.
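A minimal soft-attention sketch in PyTorch; the tensors h and w below are illustrative stand-ins rather than variables from the repository code later in this post:

import torch
import torch.nn.functional as F

h = torch.randn(2, 5, 8)                 # BLSTM outputs: (batch, seq_len, hidden)
w = torch.randn(8)                       # trainable attention vector

scores = torch.tanh(h) @ w               # unnormalized scores: (batch, seq_len)
alpha = F.softmax(scores, dim=1)         # soft attention: every position gets a weight
rep = (alpha.unsqueeze(-1) * h).sum(1)   # weighted sum -> sentence vector: (batch, hidden)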

3. Tricks

Special indicator tokens are added before and after each entity to mark its position.

A regularized loss (cross-entropy with an L2 constraint on the parameters) is used.
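Written out in the standard form (with t the one-hot target, y the predicted distribution over relations, and λ the regularization strength; the notation may differ slightly from the paper's):

$$
J(\theta) = -\sum_{i} t_i \log y_i + \lambda \lVert \theta \rVert_2^2
$$

In the training code below, this L2 penalty shows up as the weight_decay=config.L2_decay argument passed to the Adadelta optimizer.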

III. Experimental Results

The paper compares various model configurations on the SemEval-2010 Task 8 dataset:
[Table: comparison of model configurations on SemEval-2010 Task 8]

IV. Conclusions

1. Key point

The model does not depend on any external NLP tools.

2. Innovation

Introduces the Attention-BiLSTM architecture for relation classification.

3. Takeaway

The network does not rely on any NLP tools or lexical resources at all; it only needs raw text with position indicators as input.

V. Code

1. Dataset

1.1 Raw dataset

train_file.txt (samples 1-8000)

1	"The system as described above has its greatest application in an arrayed <e1>configuration</e1> of antenna <e2>elements</e2>."
Component-Whole(e2,e1)
Comment: Not a collection: there is structure here, organisation.

2	"The <e1>child</e1> was carefully wrapped and bound into the <e2>cradle</e2> by means of a cord."
Other
Comment:

3	"The <e1>author</e1> of a keygen uses a <e2>disassembler</e2> to look at the raw assembly code."
Instrument-Agency(e2,e1)
Comment:

4	"A misty <e1>ridge</e1> uprises from the <e2>surge</e2>."
Other
Comment:

5	"The <e1>student</e1> <e2>association</e2> is the voice of the undergraduate student population of the State University of New York at Buffalo."
Member-Collection(e1,e2)
Comment:

6	"This is the sprawling <e1>complex</e1> that is Peru's largest <e2>producer</e2> of silver."
Other
Comment:

7	"The current view is that the chronic <e1>inflammation</e1> in the distal part of the stomach caused by Helicobacter pylori <e2>infection</e2> results in an increased acid production from the non-infected upper corpus region of the stomach."
Cause-Effect(e2,e1)
Comment:

8	"<e1>People</e1> have been moving back into <e2>downtown</e2>."
Entity-Destination(e1,e2)
Comment:

9	"The <e1>lawsonite</e1> was contained in a <e2>platinum crucible</e2> and the counter-weight was a plastic crucible with metal pieces."
Content-Container(e1,e2)
Comment: prototypical example

10	"The solute was placed inside a beaker and 5 mL of the <e1>solvent</e1> was pipetted into a 25 mL glass <e2>flask</e2> for each trial."
Entity-Destination(e1,e2)
Comment:
......

test_file.txt (samples 8001-10717)

8001	"The most common <e1>audits</e1> were about <e2>waste</e2> and recycling."
Message-Topic(e1,e2)
Comment: Assuming an audit = an audit document.

8002	"The <e1>company</e1> fabricates plastic <e2>chairs</e2>."
Product-Producer(e2,e1)
Comment: (a) is satisfied

8003	"The school <e1>master</e1> teaches the lesson with a <e2>stick</e2>."
Instrument-Agency(e2,e1)
Comment:

8004	"The suspect dumped the dead <e1>body</e1> into a local <e2>reservoir</e2>."
Entity-Destination(e1,e2)
Comment:

8005	"Avian <e1>influenza</e1> is an infectious disease of birds caused by type A strains of the influenza <e2>virus</e2>."
Cause-Effect(e2,e1)
Comment:

8006	"The <e1>ear</e1> of the African <e2>elephant</e2> is significantly larger--measuring 183 cm by 114 cm in the bush elephant."
Component-Whole(e1,e2)
Comment:

8007	"A child is told a <e1>lie</e1> for several years by their <e2>parents</e2> before he/she realizes that a Santa Claus does not exist."
Product-Producer(e1,e2)
Comment: (a) is satisfied; negation is outside

8008	"Skype, a free software, allows a <e1>hookup</e1> of multiple computer <e2>users</e2> to join in an online conference call without incurring any telephone costs."
Member-Collection(e2,e1)
Comment:

8009	"The disgusting scene was retaliation against her brother Philip who rents the <e1>room</e1> inside this apartment <e2>house</e2> on Lombard street."
Component-Whole(e1,e2)
Comment:

8010	"This <e1>thesis</e1> defines the <e2>clinical characteristics</e2> of amyloid disease."
Message-Topic(e1,e2)
Comment: may be we could leave clinical out of e2.

1.2 Preprocessed data

preprocess.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import json
import re
from nltk.tokenize import word_tokenize


def search_entity(sentence):
    # extract the entity mentions, then pad the <e1>/<e2> tags with spaces so that
    # word_tokenize keeps them as separate tokens
    e1 = re.findall(r'<e1>(.*)</e1>', sentence)[0]
    e2 = re.findall(r'<e2>(.*)</e2>', sentence)[0]
    sentence = sentence.replace('<e1>' + e1 + '</e1>', ' <e1> ' + e1 + ' </e1> ', 1)
    sentence = sentence.replace('<e2>' + e2 + '</e2>', ' <e2> ' + e2 + ' </e2> ', 1)
    sentence = word_tokenize(sentence)
    sentence = ' '.join(sentence)
    sentence = sentence.replace('< e1 >', '<e1>')
    sentence = sentence.replace('< e2 >', '<e2>')
    sentence = sentence.replace('< /e1 >', '</e1>')
    sentence = sentence.replace('< /e2 >', '</e2>')
    sentence = sentence.split()

    assert '<e1>' in sentence
    assert '<e2>' in sentence
    assert '</e1>' in sentence
    assert '</e2>' in sentence

    return sentence


def convert(path_src, path_des):
    with open(path_src, 'r', encoding='utf-8') as fr:
        data = fr.readlines()
    with open(path_des, 'w', encoding='utf-8') as fw:
        for i in range(0, len(data), 4):  # each sample spans 4 lines: sentence, relation, comment, blank
            id_s, sentence = data[i].strip().split('\t')
            sentence = sentence[1:-1]  # strip the surrounding double quotes
            sentence = search_entity(sentence)
            meta = dict(
                id=id_s,
                relation=data[i+1].strip(),
                sentence=sentence,
                comment=data[i+2].strip()[8:]
            )
            json.dump(meta, fw, ensure_ascii=False)
            fw.write('\n')


if __name__ == '__main__':
    path_train = './SemEval2010_task8_all_data/SemEval2010_task8_training/TRAIN_FILE.TXT'
    path_test = './SemEval2010_task8_all_data/SemEval2010_task8_testing_keys/TEST_FILE_FULL.TXT'

    convert(path_train, 'train.json')
    convert(path_test, 'test.json')
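As a quick sanity check (assuming NLTK and its punkt tokenizer data are installed), search_entity turns a tagged sentence into the token list stored in train.json below:

from preprocess import search_entity

tokens = search_entity('The <e1>child</e1> was carefully wrapped and bound into the <e2>cradle</e2> by means of a cord.')
print(tokens[:6])  # ['The', '<e1>', 'child', '</e1>', 'was', 'carefully']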


train.json

{"id": "1", "relation": "Component-Whole(e2,e1)", "sentence": ["The", "system", "as", "described", "above", "has", "its", "greatest", "application", "in", "an", "arrayed", "<e1>", "configuration", "</e1>", "of", "antenna", "<e2>", "elements", "</e2>", "."], "comment": " Not a collection: there is structure here, organisation."}
{"id": "2", "relation": "Other", "sentence": ["The", "<e1>", "child", "</e1>", "was", "carefully", "wrapped", "and", "bound", "into", "the", "<e2>", "cradle", "</e2>", "by", "means", "of", "a", "cord", "."], "comment": ""}
{"id": "3", "relation": "Instrument-Agency(e2,e1)", "sentence": ["The", "<e1>", "author", "</e1>", "of", "a", "keygen", "uses", "a", "<e2>", "disassembler", "</e2>", "to", "look", "at", "the", "raw", "assembly", "code", "."], "comment": ""}
{"id": "4", "relation": "Other", "sentence": ["A", "misty", "<e1>", "ridge", "</e1>", "uprises", "from", "the", "<e2>", "surge", "</e2>", "."], "comment": ""}
{"id": "5", "relation": "Member-Collection(e1,e2)", "sentence": ["The", "<e1>", "student", "</e1>", "<e2>", "association", "</e2>", "is", "the", "voice", "of", "the", "undergraduate", "student", "population", "of", "the", "State", "University", "of", "New", "York", "at", "Buffalo", "."], "comment": ""}
......

test.json

{"id": "8001", "relation": "Message-Topic(e1,e2)", "sentence": ["The", "most", "common", "<e1>", "audits", "</e1>", "were", "about", "<e2>", "waste", "</e2>", "and", "recycling", "."], "comment": " Assuming an audit = an audit document."}
{"id": "8002", "relation": "Product-Producer(e2,e1)", "sentence": ["The", "<e1>", "company", "</e1>", "fabricates", "plastic", "<e2>", "chairs", "</e2>", "."], "comment": " (a) is satisfied"}
{"id": "8003", "relation": "Instrument-Agency(e2,e1)", "sentence": ["The", "school", "<e1>", "master", "</e1>", "teaches", "the", "lesson", "with", "a", "<e2>", "stick", "</e2>", "."], "comment": ""}
{"id": "8004", "relation": "Entity-Destination(e1,e2)", "sentence": ["The", "suspect", "dumped", "the", "dead", "<e1>", "body", "</e1>", "into", "a", "local", "<e2>", "reservoir", "</e2>", "."], "comment": ""}
{"id": "8005", "relation": "Cause-Effect(e2,e1)", "sentence": ["Avian", "<e1>", "influenza", "</e1>", "is", "an", "infectious", "disease", "of", "birds", "caused", "by", "type", "A", "strains", "of", "the", "influenza", "<e2>", "virus", "</e2>", "."], "comment": ""}
......

1.3 relation2id

Other	0
Cause-Effect(e1,e2)	1
Cause-Effect(e2,e1)	2
Component-Whole(e1,e2)	3
Component-Whole(e2,e1)	4
Content-Container(e1,e2)	5
Content-Container(e2,e1)	6
Entity-Destination(e1,e2)	7
Entity-Destination(e2,e1)	8
Entity-Origin(e1,e2)	9
Entity-Origin(e2,e1)	10
Instrument-Agency(e1,e2)	11
Instrument-Agency(e2,e1)	12
Member-Collection(e1,e2)	13
Member-Collection(e2,e1)	14
Message-Topic(e1,e2)	15
Message-Topic(e2,e1)	16
Product-Producer(e1,e2)	17
Product-Producer(e2,e1)	18

2. Pre-trained word embeddings: static HLBL vectors

Note that config.py below defaults to 100-dimensional GloVe vectors (glove.6B.100d.txt); the 50-dimensional HLBL vectors shown here are an alternative static embedding, usable via --embedding_path together with --word_dim 50.

hlbl-embeddings-scaled.EMBEDDING_SIZE=50

*UNKNOWN* -0.166038776479 0.104395984608 0.163119732357 0.0899594154863 -0.0192271099805 -0.0417631572501 -0.0163376687927 0.0357616216019 0.0536077591673 0.0127688536503 -0.00284508433021 -0.0626207031228 -0.0379452734015 -0.103548297666 0.0381169119981 0.00199421074321 -0.0474636488659 -0.0127526851513 0.016404178535 -0.12759853361 -0.0292937037717 -0.0512566352549 0.0233097445983 0.0360505083995 0.00229317984472 -0.0771565284227 0.0071461584378 -0.051608090196 -0.0267547654304 0.0492994451068 -0.0531630844999 0.00787191810391 0.082280106873 0.066908641868 -0.0283930612982 0.216840166248 0.164923151267 0.00188498983723 0.0328679039324 -0.00175432516758 0.0614261774935 0.0987773071377 0.0548423375506 -0.0307057922059 0.053074241476 0.04982054279 -0.0572485864016 0.132236444766 -0.0379717035014 -0.120915939814
the -0.0841015569168 0.145263825738 0.116945121935 -0.0754618634155 0.17901499611 -0.000652852605208 -0.0713783879233 0.207273704502 0.060711721477 0.0366727701165 -0.0269791566731 -0.156993473526 -0.0393947453024 0.00749161628231 -0.332851634057 -0.1708430781 -0.275163605231 -0.266592614101 0.43349041466 -0.00779248211778 0.031101796379 -0.0257114150838 0.174856713352 -0.0543054233622 -0.0846669459476 -0.006234398456 0.00414488584462 0.119738648443 -0.0914876936952 -0.317381121871 -0.27471439742 0.234269597998 0.170305945138 -0.0282815073325 -0.10127814458 0.156451476203 0.154703520781 -0.0014827085612 0.164287521114 0.0328582913203 0.0356570354049 -0.190254406793 -0.112029936115 -0.198875312619 0.00102875631152 -0.00161517169984 -0.125210890327 0.196903181061 -0.112017915766 -0.00838804375065
. -0.0875932389444 -0.0586365253633 0.0729727126603 0.32072000431 0.0745620569276 -0.0494709138174 0.208708067552 -0.025035364294 -0.197531050237 0.177318202028 0.297077745222 -0.0256369072571 0.182364658364 0.189089099105 0.0589179494006 -0.0627276310572 0.0682898379459 0.241161712515 0.253510796291 -0.0325139691451 -0.0129081882483 -0.083367340352 0.0276167362372 -0.00757124183183 -0.0905801885623 0.305015208385 0.0755474920504 -0.00516459185438 -0.0412876867803 0.105047372601 -0.718674456034 0.184682477295 0.232732814491 0.0929975692214 0.0999329447708 -0.0968008990987 0.421525505372 -0.136460066398 -0.323294448817 0.118318915141 0.415411774103 -0.135770867168 0.0404792691614 0.264279769529 -0.133076243622 0.195087919022 -0.087589323012 0.0335223022065 -0.0365650611956 -0.0163760300203
, -0.023019838485 0.277215570968 0.241932261453 -0.105403438907 0.247316949736 0.0859618436243 -0.0130132156599 0.123988163629 -0.150741462418 0.129993766762 0.0766431623839 0.0547135456598 0.187342182554 0.176303102861 -0.121401723217 0.0458278230666 0.0339804870854 -0.0619606057248 0.0514787739809 0.00732501266557 0.0879996990484 -0.369288823679 0.235222707122 -0.0528783055204 0.0121891472663 -0.165169815904 -0.136829953355 -0.0750751223049 -0.0503433833321 0.0782539868365 -0.400940778018 -0.099745222007 -0.152448498545 -0.0815002789835 -0.010575616616 0.331604536668 -0.0124179474775 0.00173559407939 -0.230971231526 0.0162523457081 0.213848645598 0.184698023693 0.158368229826 0.0975422545404 -0.0307127563081 0.093420146492 -0.0377856184872 -0.0181716170654 0.43322993915 -0.113289957059
to 0.134693667961 0.392203653086 0.0346151199225 0.135354475458 0.0719918082372 0.118667933013 -0.0698386234679 -0.0139927084407 0.144452931939 0.0383223273458 -0.0491954394553 -0.126435975874 0.23979196724 -0.186550477314 0.0602616605691 -0.0875395769807 0.0788848675161 0.132691898026 0.155618778336 0.00680378469567 -0.126513561203 -0.436124771467 0.132675129426 -0.0946286638801 0.0986847070674 -0.354397304845 -0.196909463175 -0.0911408611189 0.134975690877 0.0625931974859 0.0108112360985 -0.107933544401 -0.166545488854 0.0137397678012 -0.0268394211932 -0.260328038765 0.0745185746772 0.020864049205 0.133485534344 -0.0479098207297 0.145382061477 -0.116284346216 0.0822848147919 -0.00621959258902 0.0135679910959 -0.0723116375013 -0.422793539068 0.144456402991 -0.119019192402 0.0659297394103
......

3. config.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import argparse
import torch
import os
import random
import json
import numpy as np


class Config(object):
    def __init__(self):
        # get init config
        args = self.__get_config()
        for key in args.__dict__:
            setattr(self, key, args.__dict__[key])

        # select device
        self.device = None
        if self.cuda >= 0 and torch.cuda.is_available():
            self.device = torch.device('cuda:{}'.format(self.cuda))
        else:
            self.device = torch.device('cpu')

        # determine the model name and model dir
        if self.model_name is None:
            self.model_name = 'Att_BLSTM'
        self.model_dir = os.path.join(self.output_dir, self.model_name)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        # backup data
        self.__config_backup(args)

        # set the random seed
        self.__set_seed(self.seed)

    def __get_config(self):
        parser = argparse.ArgumentParser()
        parser.description = 'config for models'

        # several key selective parameters
        parser.add_argument('--data_dir', type=str,
                            default='./data',
                            help='dir to load data')
        parser.add_argument('--output_dir', type=str,
                            default='./output',
                            help='dir to save output')

        # word embedding
        parser.add_argument('--embedding_path', type=str,
                            default='./embedding/glove.6B.100d.txt',
                            help='pre_trained word embedding')
        parser.add_argument('--word_dim', type=int,
                            default=100,
                            help='dimension of word embedding')

        # train settings
        parser.add_argument('--model_name', type=str,
                            default=None,
                            help='model name')
        parser.add_argument('--mode', type=int,
                            default=1,
                            choices=[0, 1],
                            help='running mode: 1 for training; otherwise testing')
        parser.add_argument('--seed', type=int,
                            default=5782,
                            help='random seed')
        parser.add_argument('--cuda', type=int,
                            default=0,
                            help='GPU device index; -1 selects the CPU')
        parser.add_argument('--epoch', type=int,
                            default=30,
                            help='max epochs during training')

        # hyper parameters
        parser.add_argument('--batch_size', type=int,
                            default=10,
                            help='batch size')
        parser.add_argument('--lr', type=float,
                            default=1.0,
                            help='learning rate')
        parser.add_argument('--max_len', type=int,
                            default=100,
                            help='max length of sentence')

        parser.add_argument('--emb_dropout', type=float,
                            default=0.3,
                            help='dropout probability in the embedding layer')
        parser.add_argument('--lstm_dropout', type=float,
                            default=0.3,
                            help='dropout probability in the (Bi)LSTM layer')
        parser.add_argument('--linear_dropout', type=float,
                            default=0.5,
                            help='dropout probability in the linear layer')
        parser.add_argument('--hidden_size', type=int,
                            default=100,
                            help='the dimension of hidden units in (Bi)LSTM layer')
        parser.add_argument('--layers_num', type=int,
                            default=1,
                            help='num of RNN layers')

        parser.add_argument('--L2_decay', type=float, default=1e-5,
                            help='L2 weight decay')

        args = parser.parse_args()
        return args

    def __set_seed(self, seed=1234):
        os.environ['PYTHONHASHSEED'] = '{}'.format(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)  # set seed for cpu
        torch.cuda.manual_seed(seed)  # set seed for current gpu
        torch.cuda.manual_seed_all(seed)  # set seed for all gpu

    def __config_backup(self, args):
        config_backup_path = os.path.join(self.model_dir, 'config.json')
        with open(config_backup_path, 'w', encoding='utf-8') as fw:
            json.dump(vars(args), fw, ensure_ascii=False)

    def print_config(self):
        for key in self.__dict__:
            print(key, end=' = ')
            print(self.__dict__[key])


if __name__ == '__main__':
    config = Config()
    config.print_config()


4. model.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Att_BLSTM(nn.Module):
    def __init__(self, word_vec, class_num, config):
        super().__init__()
        self.word_vec = word_vec
        self.class_num = class_num

        # hyper parameters and others
        self.max_len = config.max_len
        self.word_dim = config.word_dim
        self.hidden_size = config.hidden_size
        self.layers_num = config.layers_num
        self.emb_dropout_value = config.emb_dropout
        self.lstm_dropout_value = config.lstm_dropout
        self.linear_dropout_value = config.linear_dropout

        # net structures and operations
        self.word_embedding = nn.Embedding.from_pretrained(
            embeddings=self.word_vec,
            freeze=False,
        )
        self.lstm = nn.LSTM(
            input_size=self.word_dim,
            hidden_size=self.hidden_size,
            num_layers=self.layers_num,
            bias=True,
            batch_first=True,
            dropout=0,
            bidirectional=True,
        )
        self.tanh = nn.Tanh()
        self.emb_dropout = nn.Dropout(self.emb_dropout_value)
        self.lstm_dropout = nn.Dropout(self.lstm_dropout_value)
        self.linear_dropout = nn.Dropout(self.linear_dropout_value)

        self.att_weight = nn.Parameter(torch.randn(1, self.hidden_size, 1))
        self.dense = nn.Linear(
            in_features=self.hidden_size,
            out_features=self.class_num,
            bias=True
        )

        # initialize weight
        init.xavier_normal_(self.dense.weight)
        init.constant_(self.dense.bias, 0.)

    def lstm_layer(self, x, mask):
        # actual (unpadded) sentence lengths; newer PyTorch versions require them on the CPU
        lengths = torch.sum(mask.gt(0), dim=-1).cpu()
        x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        h, (_, _) = self.lstm(x)
        h, _ = pad_packed_sequence(h, batch_first=True, padding_value=0.0, total_length=self.max_len)
        h = h.view(-1, self.max_len, 2, self.hidden_size)
        h = torch.sum(h, dim=2)  # B*L*H
        return h

    def attention_layer(self, h, mask):
        att_weight = self.att_weight.expand(mask.shape[0], -1, -1)  # B*H*1
        att_score = torch.bmm(self.tanh(h), att_weight)  # B*L*H  *  B*H*1 -> B*L*1

        # mask, remove the effect of 'PAD'
        mask = mask.unsqueeze(dim=-1)  # B*L*1
        att_score = att_score.masked_fill(mask.eq(0), float('-inf'))  # B*L*1
        att_weight = F.softmax(att_score, dim=1)  # B*L*1

        reps = torch.bmm(h.transpose(1, 2), att_weight).squeeze(dim=-1)  # B*H*L *  B*L*1 -> B*H*1 -> B*H
        reps = self.tanh(reps)  # B*H
        return reps

    def forward(self, data):
        token = data[:, 0, :].view(-1, self.max_len)
        mask = data[:, 1, :].view(-1, self.max_len)
        emb = self.word_embedding(token)  # B*L*word_dim
        emb = self.emb_dropout(emb)
        h = self.lstm_layer(emb, mask)  # B*L*H
        h = self.lstm_dropout(h)
        reps = self.attention_layer(h, mask)  # B*reps
        reps = self.linear_dropout(reps)
        logits = self.dense(reps)
        return logits
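A minimal shape check of the model with random inputs; the SimpleNamespace below stands in for the Config object, and the vocabulary size of 5000 is made up purely for illustration:

import torch
from types import SimpleNamespace
from model import Att_BLSTM

# hypothetical settings mirroring the defaults in config.py
cfg = SimpleNamespace(max_len=100, word_dim=100, hidden_size=100, layers_num=1,
                      emb_dropout=0.3, lstm_dropout=0.3, linear_dropout=0.5)
word_vec = torch.randn(5000, 100)                  # stand-in for the loaded embedding matrix
model = Att_BLSTM(word_vec=word_vec, class_num=19, config=cfg)

data = torch.zeros(4, 2, 100, dtype=torch.long)    # (batch, [token ids, mask], max_len)
data[:, 0, :20] = torch.randint(1, 5000, (4, 20))  # 20 real tokens per sentence
data[:, 1, :20] = 1                                # mask: 1 for real tokens, 0 for padding
print(model(data).shape)                           # torch.Size([4, 19])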

5. train_or_test.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import os
import torch
import torch.nn as nn
import torch.optim as optim

from config import Config
from utils import WordEmbeddingLoader, RelationLoader, SemEvalDataLoader
from model import Att_BLSTM
from evaluate import Eval


def print_result(predict_label, id2rel, start_idx=8001):
    with open('predicted_result.txt', 'w', encoding='utf-8') as fw:
        for i in range(0, predict_label.shape[0]):
            fw.write('{}\t{}\n'.format(start_idx+i, id2rel[int(predict_label[i])]))


def train(model, criterion, loader, config):
    train_loader, dev_loader, _ = loader
    optimizer = optim.Adadelta(model.parameters(), lr=config.lr, weight_decay=config.L2_decay)

    print(model)
    print('training model parameters:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            print('%s :  %s' % (name, str(param.data.shape)))
    print('--------------------------------------')
    print('start to train the model ...')

    eval_tool = Eval(config)
    best_f1 = -float('inf')
    for epoch in range(1, config.epoch+1):
        for step, (data, label) in enumerate(train_loader):
            model.train()
            data = data.to(config.device)
            label = label.to(config.device)

            optimizer.zero_grad()
            logits = model(data)
            loss = criterion(logits, label)
            loss.backward()
            nn.utils.clip_grad_value_(model.parameters(), clip_value=5)
            optimizer.step()

        _, train_loss, _ = eval_tool.evaluate(model, criterion, train_loader)
        f1, dev_loss, _ = eval_tool.evaluate(model, criterion, dev_loader)

        print('[%03d] train_loss: %.3f | dev_loss: %.3f | micro f1 on dev: %.4f'
              % (epoch, train_loss, dev_loss, f1), end=' ')
        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), os.path.join(config.model_dir, 'model.pkl'))
            print('>>> save models!')
        else:
            print()


def test(model, criterion, loader, config):
    print('--------------------------------------')
    print('start test ...')

    _, _, test_loader = loader
    model.load_state_dict(torch.load(os.path.join(config.model_dir, 'model.pkl')))
    eval_tool = Eval(config)
    f1, test_loss, predict_label = eval_tool.evaluate(model, criterion, test_loader)
    print('test_loss: %.3f | micro f1 on test:  %.4f' % (test_loss, f1))
    return predict_label


if __name__ == '__main__':
    config = Config()
    print('--------------------------------------')
    print('some config:')
    config.print_config()

    print('--------------------------------------')
    print('start to load data ...')
    word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
    rel2id, id2rel, class_num = RelationLoader(config).get_relation()
    loader = SemEvalDataLoader(rel2id, word2id, config)

    train_loader, dev_loader = None, None
    if config.mode == 1:  # train mode
        train_loader = loader.get_train()
        dev_loader = loader.get_dev()
    test_loader = loader.get_test()
    loader = [train_loader, dev_loader, test_loader]
    print('finish!')

    print('--------------------------------------')
    model = Att_BLSTM(word_vec=word_vec, class_num=class_num, config=config)
    model = model.to(config.device)
    criterion = nn.CrossEntropyLoss()

    if config.mode == 1:  # train mode
        train(model, criterion, loader, config)
    predict_label = test(model, criterion, loader, config)
    print_result(predict_label, id2rel)
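With the data and embeddings in place, training (followed by evaluation on the test set) can be launched directly; the flag names come from config.py, and the paths should be adjusted to your local layout:

python train_or_test.py --mode 1 --epoch 30 --batch_size 10 --lr 1.0 --embedding_path ./embedding/glove.6B.100d.txt --word_dim 100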


6. evaluate.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6

import numpy as np
import torch


def semeval_scorer(predict_label, true_label, class_num=10):
    # Approximates the official SemEval-2010 Task 8 scorer: the 18 directional labels are
    # collapsed into 9 relation types (plus 'Other'); a prediction with the correct type but
    # the wrong direction is counted in xDIRx, and the final score is the macro-averaged F1
    # over the 9 relation types, with 'Other' excluded.
    import math
    assert true_label.shape[0] == predict_label.shape[0]
    confusion_matrix = np.zeros(shape=[class_num, class_num], dtype=np.float32)
    xDIRx = np.zeros(shape=[class_num], dtype=np.float32)
    for i in range(true_label.shape[0]):
        true_idx = math.ceil(true_label[i]/2)
        predict_idx = math.ceil(predict_label[i]/2)
        if true_label[i] == predict_label[i]:
            confusion_matrix[predict_idx][true_idx] += 1
        else:
            if true_idx == predict_idx:
                xDIRx[predict_idx] += 1
            else:
                confusion_matrix[predict_idx][true_idx] += 1

    col_sum = np.sum(confusion_matrix, axis=0).reshape(-1)
    row_sum = np.sum(confusion_matrix, axis=1).reshape(-1)
    f1 = np.zeros(shape=[class_num], dtype=np.float32)

    for i in range(0, class_num):  # per-type F1 ('Other' is excluded from the average below)
        try:
            p = float(confusion_matrix[i][i]) / float(col_sum[i] + xDIRx[i])
            r = float(confusion_matrix[i][i]) / float(row_sum[i] + xDIRx[i])
            f1[i] = (2 * p * r / (p + r))
        except:
            pass
    actual_class = 0
    total_f1 = 0.0
    for i in range(1, class_num):
        if f1[i] > 0.0:  # classes that not in the predict label are not considered
            actual_class += 1
            total_f1 += f1[i]
    try:
        macro_f1 = total_f1 / actual_class
    except:
        macro_f1 = 0.0
    return macro_f1


class Eval(object):
    def __init__(self, config):
        self.device = config.device

    def evaluate(self, model, criterion, data_loader):
        predict_label = []
        true_label = []
        total_loss = 0.0
        with torch.no_grad():
            model.eval()
            for _, (data, label) in enumerate(data_loader):
                data = data.to(self.device)
                label = label.to(self.device)

                logits = model(data)
                loss = criterion(logits, label)
                total_loss += loss.item() * logits.shape[0]

                _, pred = torch.max(logits, dim=1)  # replace softmax with max function, same impacts
                pred = pred.cpu().detach().numpy().reshape((-1, 1))
                label = label.cpu().detach().numpy().reshape((-1, 1))
                predict_label.append(pred)
                true_label.append(label)
        predict_label = np.concatenate(predict_label, axis=0).reshape(-1).astype(np.int64)
        true_label = np.concatenate(true_label, axis=0).reshape(-1).astype(np.int64)
        eval_loss = total_loss / predict_label.shape[0]

        f1 = semeval_scorer(predict_label, true_label)
        return f1, eval_loss, predict_label
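A toy check of the scorer, hand-traced from the function above (one exact Cause-Effect match, one Cause-Effect prediction with the wrong direction, one correct Other, and one exact Entity-Destination match give per-type F1 values of 0.5 and 1.0, hence a macro F1 of 0.75); the label ids follow relation2id.txt:

import numpy as np
from evaluate import semeval_scorer

pred = np.array([1, 2, 0, 7])
true = np.array([1, 1, 0, 7])
print(semeval_scorer(pred, true))  # ~0.75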


7. utils.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Version : Python 3.6


import os
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


class WordEmbeddingLoader(object):
    """
    A loader for pre-trained word embedding
    """

    def __init__(self, config):
        self.path_word = config.embedding_path  # path of pre-trained word embedding
        self.word_dim = config.word_dim  # dimension of word embedding

    def load_embedding(self):
        word2id = dict()  # word to wordID
        word_vec = list()  # wordID to word embedding

        word2id['PAD'] = len(word2id)  # PAD character
        word2id['UNK'] = len(word2id)  # out of vocabulary
        word2id['<e1>'] = len(word2id)
        word2id['<e2>'] = len(word2id)
        word2id['</e1>'] = len(word2id)
        word2id['</e2>'] = len(word2id)

        with open(self.path_word, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = line.strip().split()
                if len(line) != self.word_dim + 1:
                    continue
                word2id[line[0]] = len(word2id)
                word_vec.append(np.asarray(line[1:], dtype=np.float32))

        word_vec = np.stack(word_vec)
        vec_mean, vec_std = word_vec.mean(), word_vec.std()
        special_emb = np.random.normal(vec_mean, vec_std, (6, self.word_dim))
        special_emb[0] = 0  # 'PAD' is initialized as a zero vector

        word_vec = np.concatenate((special_emb, word_vec), axis=0)
        word_vec = word_vec.astype(np.float32).reshape(-1, self.word_dim)
        word_vec = torch.from_numpy(word_vec)
        return word2id, word_vec


class RelationLoader(object):
    def __init__(self, config):
        self.data_dir = config.data_dir

    def __load_relation(self):
        relation_file = os.path.join(self.data_dir, 'relation2id.txt')
        rel2id = {}
        id2rel = {}
        with open(relation_file, 'r', encoding='utf-8') as fr:
            for line in fr:
                relation, id_s = line.strip().split()
                id_d = int(id_s)
                rel2id[relation] = id_d
                id2rel[id_d] = relation
        return rel2id, id2rel, len(rel2id)

    def get_relation(self):
        return self.__load_relation()


class SemEvalDataset(Dataset):
    def __init__(self, filename, rel2id, word2id, config):
        self.filename = filename
        self.rel2id = rel2id
        self.word2id = word2id
        self.max_len = config.max_len
        self.data_dir = config.data_dir
        self.dataset, self.label = self.__load_data()

    def __symbolize_sentence(self, sentence):
        """
            Args:
                sentence (list)
        """
        mask = [1] * len(sentence)
        words = []
        length = min(self.max_len, len(sentence))
        mask = mask[:length]

        for i in range(length):
            words.append(self.word2id.get(sentence[i].lower(), self.word2id['UNK']))

        if length < self.max_len:
            for i in range(length, self.max_len):
                mask.append(0)  # 'PAD' mask is zero
                words.append(self.word2id['PAD'])

        unit = np.asarray([words, mask], dtype=np.int64)
        unit = np.reshape(unit, newshape=(1, 2, self.max_len))
        return unit

    def __load_data(self):
        path_data_file = os.path.join(self.data_dir, self.filename)
        data = []
        labels = []
        with open(path_data_file, 'r', encoding='utf-8') as fr:
            for line in fr:
                line = json.loads(line.strip())
                label = line['relation']
                sentence = line['sentence']
                label_idx = self.rel2id[label]

                one_sentence = self.__symbolize_sentence(sentence)
                data.append(one_sentence)
                labels.append(label_idx)
        return data, labels

    def __getitem__(self, index):
        data = self.dataset[index]
        label = self.label[index]
        return data, label

    def __len__(self):
        return len(self.label)


class SemEvalDataLoader(object):
    def __init__(self, rel2id, word2id, config):
        self.rel2id = rel2id
        self.word2id = word2id
        self.config = config

    def __collate_fn(self, batch):
        data, label = zip(*batch)  # unzip the batch data
        data = list(data)
        label = list(label)
        data = torch.from_numpy(np.concatenate(data, axis=0))
        label = torch.from_numpy(np.asarray(label, dtype=np.int64))
        return data, label

    def __get_data(self, filename, shuffle=False):
        dataset = SemEvalDataset(filename, self.rel2id, self.word2id, self.config)
        loader = DataLoader(
            dataset=dataset,
            batch_size=self.config.batch_size,
            shuffle=shuffle,
            num_workers=2,
            collate_fn=self.__collate_fn
        )
        return loader

    def get_train(self):
        return self.__get_data('train.json', shuffle=True)

    def get_dev(self):
        # no separate dev split is provided; the test set doubles as the dev set
        return self.__get_data('test.json', shuffle=False)

    def get_test(self):
        return self.__get_data('test.json', shuffle=False)


if __name__ == '__main__':
    from config import Config
    config = Config()
    word2id, word_vec = WordEmbeddingLoader(config).load_embedding()
    rel2id, id2rel, class_num = RelationLoader(config).get_relation()
    loader = SemEvalDataLoader(rel2id, word2id, config)
    test_loader = loader.get_train()

    for step, (data, label) in enumerate(test_loader):
        print(type(data), data.shape)
        print(type(label), label.shape)
        break
