当前位置:   article > 正文

NLP实践——Few-shot事件抽取《Building an Event Extractor with Only a Few Examples》

event extractor with only a few examples


title

0. overview

今天介绍的是伊利诺伊大学Blender Lab的一个工作,发表在NAACL2022的workshop。
这篇论文包含了两个部分,第一部分是介绍怎样抽取事件的触发词,第二部分是介绍怎样在已知触发词的情况下抽取事件论元。本篇博客将主要介绍第二部分,即论元抽取的部分,其主要工作是针对few-shot事件论元抽取的,致力于以少量的标注数据,建立以融合了触发词与候选论元实体的特征空间,与论元角色名称的特征空间建立一个映射,然后试图以余弦相似度的方式判断某候选论元是否与该事件的此角色相对应。

在事件抽取方面,目前研究的热点,也是既有模型的痛点,可以总结为:

  • 解决长距离依赖的问题:篇章级论元抽取
  • 统一信息来源的问题:跨篇章、多语种、多模态
  • 解决数据来源的问题:few-shot/zero-shot/prompt-learning

从题目就可以直接看出这篇论文是针对数据来源问题的。

先谈一下个人的感受。Blender Lab我关注很久了,近两年的工作我基本上也都实验过,总的来说这个团队在事件抽取领域做的是相当不错的,非常具有开源精神,也帮到了我很多,但是few-shot这个工作,看完之后其实是有些失望的。为什么这么说呢,首先从创新点上讲,我个人认为这篇论文核心思想的创新性不太够,符合基本认知但是给我一种有点草率的感觉,另外,我试了一下作者发布的训练好的模型(时间和精力原因我没有自己去训练它),感觉效果不是很理想,对这个结果有点失望,原本是期待着模型能够解决一些few-shot的问题的。

为什么要介绍这篇论文呢,一句话讲就是,它的核心思想虽然很简单,但是很“合理”,符合思维逻辑,其实之前我也有过类似的想法,用论元的编码与事件角色的编码直接计算一个相似度来判断论元是否是该角色,按照万物皆可embedding的理念,这是一个非常容易产生的想法,所以这篇论文算是做了一件我一直想做但是无从下手的事情,让我比较感兴趣。

本文会对论文的主要思想进行简单的介绍,并且提供一套可以快速上手使用的代码。如果你对这篇论文感兴趣,想试着训练一个自己模型进一步评测它,请直接跳转作者给的git链接。由于代码比较简单,作者介绍的训练方法也比较详细,本文就不介绍怎样训练了。

以下是原文和项目地址:
https://blender.cs.illinois.edu/paper/weaksupervision2022.pdf
第一部分项目地址(获取触发词):
https://github.com/Perfec-Yu/efficient-event-extraction
第二部分项目地址(获取论元):
https://github.com/zhangzx-uiuc/zero-shot-event-arguments
两部分综合在一起的docker地址:
https://hub.docker.com/repository/docker/zixuan11/event-extractor

1. 论文思想介绍

1.1 关键词聚类与基于关键词簇找事件触发词

略。
trigger labeling

1.2 论元抽取

首先,用spacy或nltk抽取实体,作为事件论元的候选。

对于如何确定一个实体是否作为事件的某角色的论元,文章用BERT对所有的角色名称进行了编码,如,某事件类型下,有角色名称 R 1 , R 2 , . . . R_{1}, R_{2}, ... R1,R2,...,编码之后得到角色对应的编码 R 1 , R 2 , . . . {\mathbf R_{1}},{\mathbf R_{2}}, ... R1,R2,...

对于1.1中获取到的触发词 t i t_{i} ti和刚抽取出来的实体 e i e_{i} ei,构建触发词-实体对
x i = [ t i , e i ] x_i = [t_{i}, e_i] xi=[ti,ei]

于是训练的目标就是最小化hinge loss:

L o s s i = ∑ i ≠ j m a x ( m − C ( x i , r i ) + C ( x i , r j ) ) Loss_i = \sum_{i \neq j} max(m - C(x_i, r_i) + C(x_i, r_j)) Lossi=i=jmax(mC(xi,ri)+C(xi,rj))
也就是,如果一个实体是某个事件的某个角色的论元,则拉进实体与触发词编码contact的结果(即 x i x_i xi),与角色编码表征的距离,反之则使它们尽可能远离。

2. 代码及使用

在作者提供的git项目中,提供了训练和评估的代码,但是没有提供已经训练好的模型以供测试,也就是说需要自己去训练一个模型出来。训练的数据格式与OneIE格式比较接近,但是也有点区别,所以我偷了个懒,没有去做数据格式转换,而是从作者提供的docker里把训练好的模型拿出来。

考虑到有些同学对docker不熟悉,我把整个docker项目传网盘了,里边有两个已经训练好的模型,是预测时需要用的,分别对应论文的两个部分,触发词抽取和论元抽取。

地址1:完整的文件,但是由于超出了4g的限制我上传不了,所以我把论元模型的文件拿出来,单独上传了。
链接:https://pan.baidu.com/s/1ZKz18tusPbmt-nAVdCZtLA
提取码:e55a

地址2:论元模型。
链接:https://pan.baidu.com/s/1IWCoOXhvTmo89c7XGijIFg
提取码:i3mv

2.1 触发词抽取模型

这里需要分别加载两个模型,先来看第一个,触发词抽取模型。

建一个utils.py,写入以下内容(其实就是docker文件里边的api.py):

import torch
import torch.nn as nn
import transformers
from transformers import BatchEncoding
from typing import *
import re

IO_match = re.compile(r'(?P<start>\d+)I-(?P<label>\S+)\s(?:(?P<end>\d+)I-(?P=label)\s)*')


def to_char(predictions: List[Union[List[Tuple[int, int, str]], Set[Tuple[int, int, str]]]],
            encodings: Union[List[BatchEncoding], BatchEncoding]) \
        -> List[Union[List[Tuple[int, int, str]], Set[Tuple[int, int, str]]]]:
    fw = None
    corpus_annotations = []
    for i, prediction in enumerate(predictions):
        if isinstance(encodings, list):
            encoding = encodings[i]
        annotations = []
        for annotation in prediction:
            start_pt = annotation[0]
            end_pt = annotation[1]
            if isinstance(encodings, list):
                start = encoding.token_to_chars(start_pt).start
                end = encoding.token_to_chars(end_pt - 1).end
            else:
                start = encodings.token_to_chars(i, start_pt).start
                end = encodings.token_to_chars(i, end_pt - 1).end
            annotations.append([start, end, annotation[2]])
        corpus_annotations.append(annotations)
    return corpus_annotations


class IEToken(nn.Module):
    def __init__(self, nclass: int, model_name: str, id2label: Dict[int, str], **kwargs):
        super().__init__()
        self.pretrained_lm = transformers.AutoModel.from_pretrained(model_name)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.linear_map = nn.Linear(2048, nclass)
        self.crit = nn.CrossEntropyLoss()
        self.id2label = id2label

    def compute_cross_entropy(self, logits, labels):
        mask = labels >= 0
        return self.crit(logits[mask], labels[mask])

    def forward(self, batch):
        token_ids, attention_masks, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
        encoded = self.pretrained_lm(token_ids, attention_masks, output_hidden_states=True)
        encoded = torch.cat((encoded.last_hidden_state, encoded.hidden_states[-3]), dim=-1)
        outputs = self.linear_map(encoded)
        loss = self.compute_cross_entropy(outputs, labels)
        preds = torch.argmax(outputs, dim=-1)
        preds[labels < 0] = labels[labels < 0]
        return {
            "loss": loss,
            "prediction": preds.long().detach(),
            "label": labels.long().detach()
        }


def find_offsets(seq_str: str, match: re.Pattern):
    annotations = []
    for annotation in match.finditer(seq_str):
        start = int(annotation.group('start'))
        label = annotation.group('label')
        end = annotation.group('end')
        end = start + 1 if end is None else int(end) + 1
        annotations.append((start, end, label))
    return annotations


def collect_spans(sequence: str, tag2label: Optional[Dict[str, str]] = None) -> Set[Tuple[int, int, str]]:
    spans = find_offsets(sequence, IO_match)
    if tag2label:
        label_spans = set()
        for span in spans:
            label_spans.add((span[0], span[1], tag2label[span[2]]))
    else:
        label_spans = set(spans)
    return label_spans


def annotate(sentence, model: IEToken, batch_size=8, max_length=96):
    # print(sentence)
    # print(model.id2label)
    # return [[{"trigger": [80, 86, "Transaction:Transfer-Money"], "arguments": [[20, 30, "Giver"], [87, 92, "Recipient"]]}]]
    model.eval()
    label2tag = {
        v: v.replace("-", "_") for v in model.id2label.values()
    }
    tag2label = {
        v: k for k, v in label2tag.items()
    }

    def get_tag(id):
        if id == 0:
            return 'O'
        else:
            return f'I-{label2tag[model.id2label[id]]}'

    annotations = []
    with torch.no_grad():
        for i in range(0, len(sentence), batch_size):
            encoded = model.tokenizer(
                text=sentence[i:i + batch_size],
                max_length=max_length,
                is_split_into_words=isinstance(sentence[0], list),
                add_special_tokens=True,
                padding='longest',
                truncation=True,
                return_attention_mask=True,
                return_special_tokens_mask=False,
                return_tensors='pt'
            )
            encoded.to(model.linear_map.weight.device)
            input_ids, attention_masks = encoded["input_ids"], encoded["attention_mask"]
            hidden = model.pretrained_lm(input_ids, attention_masks, output_hidden_states=True)
            hidden = torch.cat((hidden.last_hidden_state, hidden.hidden_states[-3]), dim=-1)
            outputs = model.linear_map(hidden)
            preds = torch.argmax(outputs, dim=-1)
            preds[attention_masks == 0] = -100
            preds = preds.cpu().numpy()
            sequences = []
            for idx, sequence in enumerate(preds):
                sequence = sequence[sequence != -100]
                sequences.append(" ".join([f'{offset}{get_tag(token)}' for offset, token in enumerate(sequence)]) + " ")
            sequences = [list(collect_spans(sequence, tag2label)) for sequence in sequences]
            annotations.extend(to_char(sequences, encoded))

    results = []
    for k, trigger in enumerate(annotations[0]):
        result_k = {"trigger": trigger, "arguments": []}
        results.append(result_k)

    return [results]


def annotate_arguments(trigger_annotations, sentence, arg_model, tokenizer, spacy_model, nltk_tokenizer):
    data_input = {"sentence": sentence, "events": trigger_annotations[0]}
    output_res = arg_model.predict_one_example(tokenizer, data_input, spacy_model, nltk_tokenizer)
    return [output_res["events"]]


def load_ckpt(ckpt_path: str, device="cuda:0"):
    ckpt = torch.load(ckpt_path, map_location=torch.device(device))
    state_dict = ckpt['state_dict']
    nclass = state_dict['linear_map.weight'].size(0)
    id2label = ckpt['id2label']
    model = IEToken(nclass, "roberta-large", id2label)
    model.to(torch.device(device))
    model.load_state_dict(state_dict=state_dict)
    return model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153

然后就可以加载这个触发词模型:

如果你的显卡算力不是很高的话,可以直接load它的整个模型结构:

from utils import load_ckpt

# 模型文件是这个model.best
model_path = 'xxxxxx/instance/data/log/model.best'
model = load_ckpt(model_path)
# 注意这个过程需要联网,因为会自动去下载一个roberta-large
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

如果报错了,可能是因为你的环境和作者保存时候的环境不一致,比如你的显卡比较新,跟CUDA和驱动不匹配,例如3090就读不了这个模型,不过不要担心,我用一台相对旧一点的机器读了这个模型并且保存了statedict,可以从下面的链接下载。在这里也强烈建议大家,在写自己模型的时候,一定要让这个模型可以初始化,保存的时候存statedict,这样在加载的时候就不存在环境限制了。

statedict我放在网盘了:
链接:https://pan.baidu.com/s/16Ptm_sDEUQmM6NGTFZ3mxQ
提取码:ip4g

# 从刚刚建立的utils.py引用模型类
import torch
from utils import IEToken


id2label = {2: 'Movement:Transport',
 5: 'Personnel:Elect',
 13: 'Personnel:Start-Position',
 28: 'Personnel:Nominate',
 6: 'Personnel:End-Position',
 1: 'Conflict:Attack',
 4: 'Contact:Meet',
 19: 'Life:Marry',
 7: 'Transaction:Transfer-Money',
 16: 'Conflict:Demonstrate',
 23: 'Business:End-Org',
 17: 'Justice:Sue',
 8: 'Life:Injure',
 3: 'Life:Die',
 15: 'Justice:Arrest-Jail',
 9: 'Contact:Phone-Write',
 12: 'Transaction:Transfer-Ownership',
 25: 'Business:Start-Org',
 26: 'Justice:Execute',
 11: 'Justice:Trial-Hearing',
 21: 'Life:Be-Born',
 10: 'Justice:Charge-Indict',
 18: 'Justice:Convict',
 14: 'Justice:Sentence',
 22: 'Business:Declare-Bankruptcy',
 20: 'Justice:Release-Parole',
 27: 'Justice:Fine',
 33: 'Justice:Pardon',
 24: 'Justice:Appeal',
 31: 'Justice:Extradite',
 30: 'Life:Divorce',
 29: 'Business:Merge-Org',
 32: 'Justice:Acquit',
 0: 'NA'}

model = IEToken(nclass=34, model_name='roberta-large', id2label=id2label)
# 这里的roberta-large如果你已经下载过了,可以直接传它的路径

model.load_state_dict(torch.load('ie_model.bin'))  # 'ie_model.bin'就是我保存的statedict的路径
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

现在你可以抽取触发词了:

from utils import annotate

sent = """Bagri was also charged with trying to murder Tara Singh Hayer, editor of The Indo-Canadian Times, North America's largest Punjabi newspaper, in 1998."""

trigger_annotations = annotate([sent], model)
# [[{'trigger': [15, 22, 'Justice:Charge-Indict'], 'arguments': []},
#   {'trigger': [38, 44, 'Life:Die'], 'arguments': []}]]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

2.2 论元抽取模型

读取ontology:

import json


label_info_path = 'xxxxxxxx/instance/data/label_info.json'
with open(label_info_path, 'r', encoding='utf-8') as f:
    label_info = json.loads(f.read())
ontology = {}
for type in label_info:
    ontology[type] = label_info[type]["roles"]

event_type_idxs, role_type_idxs = {}, {"unrelated object": -1}
event_num, role_num = 0, 0
for event_type in ontology:
    if event_type not in event_type_idxs:
        event_type_idxs[event_type] = event_num
        event_num += 1

for event_type in ontology:
    roles = ontology[event_type]
    for role in roles:
        if role not in role_type_idxs:
            role_type_idxs[role] = role_num
            role_num += 1
event_type_idxs = event_type_idxs
role_type_idxs = role_type_idxs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

建立模型:
创建一个arg_utils.py,写入内容如下:

from transformers import BertModel, BertTokenizerFast

import torch
import torch.nn as nn
import numpy as np

import nltk
import math

from copy import deepcopy

def get_spacy_entities(sent, model):
    entity_type_list = ["ORG", "PERSON", "GPE", "LOC"]
    token_offsets = []
    doc = model(sent)
    for ent in doc.ents:
        if ent.label_ in entity_type_list:
            token_offsets.append([ent.start_char, ent.end_char])
    return token_offsets


def get_noun_entities(sent, nltk_tokenizer):
    offsets = []
    tokens = nltk_tokenizer.tokenize(sent)
    spans = list(nltk_tokenizer.span_tokenize(sent))
    pos_tags = nltk.pos_tag(tokens)
    pos_tag_num = len(pos_tags)
    pos_tags.append(("null", "null"))
    for i in range(pos_tag_num):
        if pos_tags[i][1].startswith("NN") and (not pos_tags[i+1][1].startswith("NN")):
            offsets.append(spans[i])
    return offsets  


def get_entities(sent, way, tokenizer, spacy_model):
    if way == "spacy":
        return get_spacy_entities(sent, spacy_model)
    else:
        return get_noun_entities(sent, tokenizer)


def transform_offsets(start, end, offsets_list):
    curr_list = offsets_list[1:-1].copy()
    length = len(offsets_list)
    curr_list.append((math.inf, math.inf))
    start_idx, end_idx = 0, 1
    for i in range(length - 1):
        if start > curr_list[i][0] and start <= curr_list[i+1][0]:
            start_idx = i+1
        if end > curr_list[i][0] and end <= curr_list[i+1][0]:
            end_idx = i+1
    return start_idx, end_idx


def transform_to_list(start, end, seq_length):
    output_list = [0.0 for _ in range(seq_length)]
    for i in range(start, end):
        output_list[i] = 1.0
    return output_list


def padding_mask_list(input_mask_list):
    mask_list = deepcopy(input_mask_list)
    max_subwords_num = max([len(mask) for mask in mask_list])
    for mask in mask_list:
        mask[0] = 0
        mask[-1] = 0
    padded_attn_mask_list = [mask+[0 for _ in range(max_subwords_num-len(mask))] for mask in mask_list]
    return padded_attn_mask_list


def get_bert_embeddings(word_list, model, tokenizer, gpu):
    # input: a list of words, bert_model, bert_tokenizer
    # output: numpy tensor (word_num, dim)
    segments = tokenizer(word_list)
    attn_mask_list = segments["attention_mask"]
    padded_attn_mask_list = padding_mask_list(attn_mask_list)

    padded_segments = tokenizer.pad(segments)
    input_ids, attn_mask = padded_segments["input_ids"], padded_segments["attention_mask"]

    if gpu == "cpu":
        batch_input_ids, batch_attn_mask = torch.LongTensor(input_ids), torch.LongTensor(attn_mask)
        batch_padded_mask = torch.FloatTensor(padded_attn_mask_list)
    else:
        batch_input_ids, batch_attn_mask = torch.LongTensor(input_ids).to(gpu), torch.LongTensor(attn_mask).to(gpu)
        batch_padded_mask = torch.FloatTensor(padded_attn_mask_list).to(gpu)
    
    encodes = model(batch_input_ids, attention_mask=batch_attn_mask)[0]

    avg_padded_mask = batch_padded_mask / (torch.sum(batch_padded_mask, 1).unsqueeze(-1))
    output_embeds = torch.sum(torch.stack([avg_padded_mask for _ in range(encodes.shape[-1])], 2) * encodes, 1)
    return output_embeds


def send_to_gpu(batch, gpu):
    for key,item in batch.items():
        if hasattr(item, "shape"):
            batch[key] = item.to(gpu)


if __name__ == "__main__":
    b = BertModel.from_pretrained("bert-large-uncased")
    t = BertTokenizerFast.from_pretrained("bert-large-uncased")

    roles = ['Adjudicator', 'Agent', 'Artifact', 'Attacker', 'Beneficiary', 'Buyer', 'Defendant', 'Destination', 'Entity', 'Giver', 'Instrument', 'Organization', 'Origin', 'Person', 'Place', 'Plaintiff', 'Prosecutor', 'Recipient', 'Seller', 'Target', 'Vehicle', 'Victim']
    words = [role.lower() for role in roles]

    word_embeds = get_bert_embeddings(words, b, t, "cpu")
    role_num = len(roles)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110

然后建一个model.py,写入以下内容:

from arg_utils import *
from copy import deepcopy

def get_numberized_onto(onto, event_type_idx, role_type_idx):
  num_onto = {}
  for event_type in onto:
      et_idx = event_type_idx[event_type]
      roles = onto[event_type]
      num_roles = [role_type_idx[role] for role in roles]
      # num_roles.append(0)
      num_onto.update({et_idx: num_roles.copy()})
  return num_onto


class Linears(nn.Module):
  """Multiple linear layers with Dropout."""
  def __init__(self, dimensions, activation='tanh', dropout_prob=0.3, bias=True):
      super().__init__()
      assert len(dimensions) > 1
      self.layers = nn.ModuleList([nn.Linear(dimensions[i], dimensions[i + 1], bias=bias)
                                   for i in range(len(dimensions) - 1)])
      self.activation = getattr(torch, activation)
      self.dropout = nn.Dropout(dropout_prob)

  def forward(self, inputs):
      for i, layer in enumerate(self.layers):
          if i > 0:
              inputs = self.activation(inputs)
              inputs = self.dropout(inputs)
          inputs = layer(inputs)
      return inputs


class ZeroShotModel(nn.Module):
  def __init__(self, bert_name, train_role_types, test_role_types, train_onto, test_onto, train_event_types, test_event_types, dropout, bert_dim, ft_hidden_dim, role_hidden_dim, output_dim, alpha, device):
      super(ZeroShotModel, self).__init__()

      nltk.download('averaged_perceptron_tagger')

      self.train_onto = train_onto
      self.test_onto = test_onto
      # ontology: {"Conflict:Attack": ["Attacker", "Victim", "Target", "Place"]}

      self.train_idx_to_type = {v:k for k,v in train_event_types.items()}
      self.test_idx_to_type = {v:k for k,v in test_event_types.items()}
      # idx_to_type: {0: "Conflict:Attack"}

      self.train_event_types = train_event_types
      self.test_event_types = test_event_types
      # event_types: {"Conflict:Attack": 0}

      self.train_role_type_idx = train_role_types # {role_name: idx}
      self.train_role_type_num = len(self.train_role_type_idx) - 1

      self.test_role_type_idx = test_role_types # {role_name: idx}
      self.test_role_type_num = len(self.test_role_type_idx) - 1
      self.test_rev_role_type_idx = {v:k for k,v in self.test_role_type_idx.items()}

      self.numberized_test_onto = get_numberized_onto(test_onto, test_event_types, test_role_types)

      self.device = device
      self.dropout = dropout
      self.bert_dim = bert_dim
      self.ft_hidden_dim = ft_hidden_dim
      self.role_hidden_dim = role_hidden_dim
      self.output_dim = output_dim
      self.alpha = alpha

      self.bert = BertModel.from_pretrained(bert_name)
      self.role_name_encoder = Linears([bert_dim, role_hidden_dim, output_dim], dropout_prob=dropout)
      self.text_ft_encoder = Linears([2*bert_dim, ft_hidden_dim, output_dim], dropout_prob=dropout)
      self.cosine_sim_2 = nn.CosineSimilarity(dim=2)
      self.cosine_sim_3 = nn.CosineSimilarity(dim=3)

      self.bert.to(device)

  def compute_train_role_reprs(self, tokenizer):
      role_names = sorted(self.train_role_type_idx.items(), key=lambda x:x[1])
      names = []
      for name in role_names:
          names.append(name[0])
      train_role_reprs = get_bert_embeddings(names, self.bert, tokenizer, self.device)
      self.train_role_reprs = train_role_reprs.detach()[1:, :]
  
  def compute_test_role_reprs(self, tokenizer):
      role_names = sorted(self.test_role_type_idx.items(), key=lambda x:x[1])
      names = []
      for name in role_names:
          names.append(name[0])

      test_role_reprs = get_bert_embeddings(names, self.bert, tokenizer, self.device)
      self.test_role_reprs = test_role_reprs.detach()[1:, :]

  def span_encode(self, bert_output, span_input):
      # bert_output: (batch, seq_len, dim)
      # span_input: (batch, num, seq_len)
      # OUTPUT: (batch, num, dim)
      dim = bert_output.shape[2]
      num = span_input.shape[1]
      avg_span_input = span_input / torch.sum(span_input, 2).unsqueeze(2)
      avg_weights = avg_span_input.unsqueeze(3).repeat(1, 1, 1, dim)
      bert_repeated = bert_output.unsqueeze(1).repeat(1, num, 1, 1)
      span_output = torch.sum(bert_repeated * avg_weights, 2)
      return span_output
  
  def forward(self, batch):
      # batch: {"input_ids", "attn_mask", "trigger_spans", "entity_spans", "label_idxs", "neg_label_idxs", "pair_mask"}
      # pairs_num = batch["pair_mask"].sum()
      bert_outputs = self.bert(batch["input_ids"], attention_mask=batch["attn_mask"])[0]
      trigger_reprs = self.span_encode(bert_outputs, batch["trigger_spans"])
      entity_reprs = self.span_encode(bert_outputs, batch["entity_spans"])
      ta_reprs = torch.cat((trigger_reprs, entity_reprs), 2) 

      # minimize the distance between correct pairs
      role_reprs = self.role_name_encoder(self.train_role_reprs) # (role_num, output_dim)
      label_reprs = role_reprs[batch["label_idxs"]] # (bs, num, output_dim)
      # print() 
      output_ta_reprs = self.text_ft_encoder(ta_reprs) # (bs, num, output_dim)
      pos_cos_sim = self.cosine_sim_2(output_ta_reprs, label_reprs) # (bs, num)

      # print(self.train_role_reprs.shape)
      neg_label_reprs = role_reprs[batch["neg_label_idxs"]] # (bs, num, neg_role_num, output_dim)
      repeated_ta_reprs = output_ta_reprs.unsqueeze(2).repeat(1, 1, self.train_role_type_num-1, 1)

      neg_cos_sim = self.cosine_sim_3(repeated_ta_reprs, neg_label_reprs) # (bs, num, neg_role_num)
      pos_cos_sims = pos_cos_sim.unsqueeze(2).repeat(1, 1, self.train_role_type_num-1)

      hinge_matrix = torch.sum(torch.clamp(neg_cos_sim - pos_cos_sims + self.alpha, min=0), 2)

      hinge_loss = (hinge_matrix * batch["train_pair_mask"]).sum()

      return hinge_loss

  def predict(self, batch):
      # batch: {"input_ids", "attn_mask", "trigger_spans", "entity_spans", "label_idxs", "neg_label_idxs", "trigger_idxs", "pair_mask"}
      output_list = []
      bs, max_num = batch["trigger_spans"].shape[0], batch["trigger_spans"].shape[1]
      trigger_idxs = batch["trigger_idxs"] # (batch_num, max_num)

      with torch.no_grad():
          bert_outputs = self.bert(batch["input_ids"], attention_mask=batch["attn_mask"])[0]
          trigger_reprs = self.span_encode(bert_outputs, batch["trigger_spans"])
          entity_reprs = self.span_encode(bert_outputs, batch["entity_spans"])
          ta_reprs = torch.cat((trigger_reprs, entity_reprs), 2) 
          output_ta_reprs = self.text_ft_encoder(ta_reprs) # (batch_size, pair_num, output_dim)
          role_reprs = self.role_name_encoder(self.test_role_reprs) # (role_num, output_dim)
          sum_mask = torch.sum(batch["pair_mask"], 1).long().tolist()

          role_num = role_reprs.shape[0]

          for i in range(bs):
              output_i = []
              pair_num_i = sum_mask[i]

              ta_reprs_i = output_ta_reprs[i][0:pair_num_i]
              repeated_ta_reprs_i = ta_reprs_i.unsqueeze(1).repeat(1, role_num, 1)
              repeated_role_reprs = role_reprs.unsqueeze(0).repeat(pair_num_i, 1, 1)

              cos_sim = self.cosine_sim_2(repeated_ta_reprs_i, repeated_role_reprs) # (pair_num_i, role_num)
              event_type_idx_i = trigger_idxs[i].tolist()

              for j in range(pair_num_i):
                  cos_sim_j = cos_sim[j]
                  event_type = event_type_idx_i[j]
                  role_idxs = self.numberized_test_onto[event_type]
                  role_scores = [cos_sim_j[idx].item() for idx in role_idxs]

                  idxs = np.argsort(-np.array(role_scores))
                  if role_scores[idxs[0]] - role_scores[idxs[1]] > 1.5 * self.alpha:
                      output_i.append(role_idxs[idxs[0]])
                  else:
                      output_i.append(-1)
  
              output_list.append(output_i.copy())
      
      return output_list
  
  def change_test_ontology(self, test_ontology, test_event_types, test_role_types, tokenizer):
      self.test_event_types = test_event_types
      self.test_idx_to_type = {v:k for k,v in test_event_types.items()}
      self.test_onto = test_ontology
      self.test_role_type_idx = test_role_types # {role_name: idx}
      self.test_role_type_idx.update({"unrelated object": -1})
      self.test_rev_role_type_idx = {v:k for k,v in self.test_role_type_idx.items()}
      self.test_role_type_num = len(self.test_role_type_idx) - 1
      self.numberized_test_onto = get_numberized_onto(test_ontology, test_event_types, test_role_types)
      self.compute_test_role_reprs(tokenizer)
  
  def predict_one_example(self, tokenizer, data_item, spacy_model, nltk_tokenizer):
      if len(data_item["events"]) == 0:
          return deepcopy(data_item)

      bert_inputs = tokenizer(data_item["sentence"], return_offsets_mapping=True)
      input_ids = torch.LongTensor([bert_inputs["input_ids"]]).to(self.device)
      attn_mask = torch.LongTensor([bert_inputs["attention_mask"]]).to(self.device)
      offset_mapping = bert_inputs["offset_mapping"]
      batch_input = {"input_ids": input_ids, "attn_mask": attn_mask}
      triggers = data_item["events"]

      entity_offsets = get_entities(data_item["sentence"], "nltk", nltk_tokenizer, spacy_model)

      if len(entity_offsets) == 0:
          return deepcopy(data_item)

      seq_len = len(bert_inputs["input_ids"])

      trigger_span, entity_span = [], []
      trigger_idxs = []
      for i,trig in enumerate(triggers):
          trigger = trig["trigger"]
          for j,entity in enumerate(entity_offsets):
              trig_s, trig_e = transform_offsets(trigger[0], trigger[1], offset_mapping)
              ent_s, ent_e = transform_offsets(entity[0], entity[1], offset_mapping)
              
              trig_list = transform_to_list(trig_s+1, trig_e+1, seq_len)
              ent_list = transform_to_list(ent_s+1, ent_e+1, seq_len)

              trigger_span.append(trig_list)
              entity_span.append(ent_list)
              trigger_idxs.append(self.test_event_types[trigger[-1]])
      
      batch_input["trigger_spans"] = torch.FloatTensor([trigger_span]).to(self.device)
      batch_input["entity_spans"] = torch.FloatTensor([entity_span]).to(self.device)
      batch_input["trigger_idxs"] = torch.LongTensor([trigger_idxs]).to(self.device)
      batch_input["pair_mask"] = torch.FloatTensor([[1.0 for _ in range(len(trigger_span))]])

      output_list = self.predict(batch_input)[0]
      
      output_item = deepcopy(data_item)
      for i,trigger in enumerate(triggers):
          args = []
          for j,entity in enumerate(entity_offsets):
              output_idx = i * len(entity_offsets) + j
              res = output_list[output_idx]
              if res != -1:
                  arg = [entity[0], entity[1], self.test_rev_role_type_idx[res]]
                  args.append(arg)
          output_item["events"][i].update({"arguments": args})

      return output_item
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240

然后,实例化这个模型:

from model import ZeroShotModel   # 刚刚创建的model.py
from transformers import BertTokenizerFast

arg_model = ZeroShotModel("bert-large-uncased", role_type_idxs, role_type_idxs, ontology,
                          ontology, event_type_idxs, event_type_idxs, 0.3, 1024, 256, 128, 128,
                          0.1, "cpu")
# 上面的bert-large-uncased也可以传模型路径

# 加载模型参数
arg_path = 'checkpoint.pt'   # 这个文件是网盘的第二个文件
arg_model.load_state_dict(torch.load(arg_path, map_location='cpu'))

# 初始化角色和触发词编码
tokenizer = BertTokenizerFast.from_pretrained('bert-large-uncased')   # 这里也可以写本地路径
arg_model.compute_test_role_reprs(tokenizer)
arg_model.compute_train_role_reprs(tokenizer)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

此外,还需要实例化一个nltk tokenizer和一个spacy model:

import spacy
from nltk.tokenize import TreebankWordTokenizer

spacy_model = spacy.load("en_core_web_sm")
nltk_tokenizer = TreebankWordTokenizer()
  • 1
  • 2
  • 3
  • 4
  • 5

然后就可以用这个模型取预测论元了。需要在刚刚预测的触发词的基础上进行。

annotations = annotate_arguments(trigger_annotations, sent, arg_model, tokenizer, spacy_model,
                                 nltk_tokenizer)
# [[{'trigger': [15, 22, 'Justice:Charge-Indict'],
#    'arguments': [[0, 5, 'Defendant'],
#     [91, 96, 'Defendant'],
#     [130, 139, 'Defendant']]},
#   {'trigger': [38, 44, 'Life:Die'], 'arguments': [[104, 111, 'Place']]}]]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

3. 个人思考

首先我觉得这个思路没什么问题,但是直接拼接触发词和候选实体的编码,总感觉心里有点不踏实。这种做法并没有对模型结构做出实质的调整,而是直接拼接特征然后做了一个映射,相当于在两个相对独立的特征空间里边打锚点,利用这些锚点去拉进两个特征空间,而不是联合训练一个整体的特征空间。

熟悉transformers模块的同学可能用过其自带的快捷应用pipeline模块,其中就包含了zero-shot分类模型,只需要输入想要分类的类别去实例化一个分类器,然后再输入文本就可以实现分类了。其实那个zero-shot分类的原理,就是对类别名称进行了编码,然后再用每个类别的编码与文本编码去计算相似度,思想与今天介绍的论文其实是很接近的。

对于实际应用的场景,可能不如直接匹配和基于句法规则的弱监督方法。另外,个人认为QA结构的模型,无论是生成式还是判别式,在处理这类特征的能力上,应该比这个方法要更擅长一些。

另外就是候选实体的问题,作者利用的是NLP工具直接抽的,实际应用中一般会训练一个针对自己实际场景的NER模型,毕竟NER任务算是一个简单任务,在候选实体准确的前提下,再去算相似度找论元,效果或许会好一点。

如果是以研究角度,这个实验还是挺有价值的,以相似度去找论元,的确是一个非常值得一试的思路,后续可能也还有很多可以继续做的空间。

以上仅为个人观点。总之,这篇文章的思想不难,代码也不复杂,如果你在做事件抽取方面的工作,还是值得一试的。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/365868
推荐阅读
相关标签
  

闽ICP备14008679号