
Hands-On Entity-Relation Extraction with BERT (Medical Domain)


BERT Entity-Relation Extraction

Preface

This project draws on two BERT-based projects: BioBERT and Entity-Relation-Extraction (https://github.com/yuanxiaosc/Entity-Relation-Extraction). BioBERT is a biomedical language representation model designed for biomedical text-mining tasks such as biomedical named entity recognition, relation extraction, and question answering. Since this project studies and applies BERT in the biomedical domain, a trained BioBERT model is used as the initial model, and entity-relation extraction is implemented by fine-tuning it within the Entity-Relation-Extraction project.

Project Setup

Environment Requirements

PyCharm, TensorFlow 1.11.0, and Python 2 or Python 3 (verified to run with TensorFlow 1.12.0 and Python 3.6)

Directory Structure

--Entity-Relation-Extraction(Medical)
	--.github
	--bert
	--bin
		--evaluation
		--predicate_classifiction
		--subject_object_labeling
			--ner_data
	--output
	--pretrained_model
		--biobert_v1.1_pubmed
	--raw_data
	--produce_submit_json_file.py
	--run_predicate_classification.py
	--run_sequnce_labeling.py

evaluation: model evaluation scripts
predicate_classifiction: data preprocessing for relation classification
subject_object_labeling: data preprocessing for sequence labeling
output: output directory
pretrained_model: pretrained model files
raw_data: raw data files

Running the Project

Model Training

First, let's look at the format of the raw data in raw_data: train_data.tsv, test1_data_postag.tsv, and dev_data.tsv (the training, prediction, and evaluation data), plus the relation definition file relation.txt. The relation file stores the relations we define together with their corresponding entity types, e.g. {"object_type": "GENE", "predicate": "treat", "subject_type": "DISEASE"}.

  • train_data.tsv data format:
{ "text": "SAMD11 is used to treat diabetes", "spo_list": [{"predicate": "treat", "object_type": "GENE", "subject_type": "DISEASE", "object": "SAMD11", "subject": "diabetes"}]}

Here, "text" is the sentence we want the model to predict on, and "spo_list" holds the relations and entities contained in that sentence. By definition, each relation links two entities, i.e. <entity 1, relation, entity 2>: "predicate" is the relation (e.g. treat), "object_type" and "subject_type" are the entity types on either side of that relation (e.g. GENE and DISEASE), and "object" and "subject" are the concrete entities in the sentence (e.g. SAMD11 and diabetes).
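
For illustration, here is a minimal sketch of reading one line of train_data.tsv and unpacking its triples; the file format and keys follow the example above, while the parsing code itself is not part of the project:

import json

# One line of train_data.tsv, exactly as in the example above.
line = '{ "text": "SAMD11 is used to treat diabetes", "spo_list": [{"predicate": "treat", "object_type": "GENE", "subject_type": "DISEASE", "object": "SAMD11", "subject": "diabetes"}]}'

record = json.loads(line)
for spo in record["spo_list"]:
    # Each spo entry is one <entity 1, relation, entity 2> triple.
    print(spo["object"], spo["predicate"], spo["subject"])  # -> SAMD11 treat diabetes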

  • test1_data_postag.tsv data format:
{ "text": "SAMD11 is used to treat diabetes"}

Compared with train_data.tsv, the prediction data contains only the text fed into the model, with no target values.

run_predicate_classification.py

This model fine-tunes the Google BERT data-processing module and downstream task; its job is to predict which relations appear in the input text.

Before training, we first preprocess the raw data into the model's input format:

python bin/predicate_classifiction/predicate_data_manager.py

After running the preprocessing script, a classification_data folder is created under predicate_classifiction to hold the processed data, split into test, train, and valid sets. The train and valid folders contain predicate_out.txt, text.txt, token_in.txt, and token_in_not_UNK.txt; the test folder is the same except that it lacks predicate_out.txt.

  • predicate_out.txt stores the relation(s) found in each sentence
treat
treat
treat
treat
treat
  • text.txt stores the input sentences
SAMD11 is used to treat diabetes
CD105 is used to treat neurodegenerative
CD34 is used to treat cardiovascular
Gata4 is used to treat auto-immunes diseases
FAM41C is used to treat myocardial infarction
  • token_in.txt stores the WordPiece tokenization of the input text
SA ##MD ##11 is used to treat diabetes
CD ##10 ##5 is used to treat ne ##uro ##de ##gene ##rative
CD ##34 is used to treat card ##iovascular
G ##ata ##4 is used to treat auto - immune ##s diseases
FA ##M ##41 ##C is used to treat my ##oc ##ard ##ial in ##far ##ction
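
As a side note, a minimal sketch of how such WordPiece output could be produced with the BERT repo's tokenizer; it assumes the bundled bert/ package is importable as in the project's scripts and uses the BioBERT vocab path from this project, while the actual preprocessing script may differ in details:

from bert import tokenization  # the bert/ directory bundled with the project

# Assumption: the cased BioBERT vocab, hence do_lower_case=False.
tokenizer = tokenization.FullTokenizer(
    vocab_file="pretrained_model/biobert_v1.1_pubmed/vocab.txt",
    do_lower_case=False)

print(" ".join(tokenizer.tokenize("SAMD11 is used to treat diabetes")))
# e.g. SA ##MD ##11 is used to treat diabetes
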
  • token_in_not_UNK.txt likewise stores the tokenized results
SA ##MD ##11 is used to treat diabetes
CD ##10 ##5 is used to treat ne ##uro ##de ##gene ##rative
CD ##34 is used to treat card ##iovascular
G ##ata ##4 is used to treat auto - immune ##s diseases
FA ##M ##41 ##C is used to treat my ##oc ##ard ##ial in ##far ##ction

With the processed data in place, we can train the relation classification model run_predicate_classification.py with the following parameters:

python run_predicate_classification.py
	--task_name=SKE_2019
	--do_train=true
	--do_eval=false
	--data_dir=bin/predicate_classifiction/classification_data
	--vocab_file=pretrained_model/biobert_v1.1_pubmed/vocab.txt
	--bert_config_file=pretrained_model/biobert_v1.1_pubmed/bert_config.json
	--init_checkpoint=pretrained_model/biobert_v1.1_pubmed/model.ckpt-1000000
	--max_seq_length=128
	--train_batch_size=32
	--learning_rate=2e-5
	--num_train_epochs=6.0
	--output_dir=./output/predicate_classification_model/epochs6/

This model applies fine-tuning on top of the Google BERT model, mainly modifying the data preprocessing module and the downstream task module.

  • Data processing module
    # First, define the labels to be predicted; extend the list as needed.
    def get_labels(self):
        return ['treat', 'cause', 'unlabel']

    # Data processing module
    def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
    """Converts a single `InputExample` into a single `InputFeatures`."""

    if isinstance(example, PaddingInputExample):
        return InputFeatures(
            input_ids=[0] * max_seq_length,
            input_mask=[0] * max_seq_length,
            segment_ids=[0] * max_seq_length,
            label_ids=[0] * len(label_list),
            is_real_example=False)

    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i
    text = example.text_a

    tokens_a = example.text_a.split(" ")
    tokens_b = None
    if example.text_b:
        tokens_b = tokenizer.tokenize(example.text_b)

    if tokens_b:
        # Modifies `tokens_a` and `tokens_b` in place so that the total
        # length is less than the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    else:
        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0:(max_seq_length - 2)]
    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(0)
    for token in tokens_a:
        tokens.append(token)
        segment_ids.append(0)
    tokens.append("[SEP]")
    segment_ids.append(0)

    if tokens_b:
        for token in tokens_b:
            tokens.append(token)
            segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    label_list = example.label.split(" ")
    label_ids = _predicate_label_to_id(label_list, label_map)

    if ex_index < 5:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % (example.guid))
        tf.logging.info("tokens: %s" % " ".join(
            [tokenization.printable_text(x) for x in tokens]))
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids]))

    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
        is_real_example=True)
    return feature

Note that because the input is a single sentence, only text_a is used and text_b is set to None. After preprocessing we obtain guid, tokens, input_ids, input_mask, segment_ids, and label_ids; for example, for the input "SA ##MD ##11 is used to treat diabetes":

guid:'train-0'
tokens:['[CLS]', 'SA', '##MD', '##11', 'is', 'used', 'to', 'treat', 'diabetes', '[SEP]']
input_ids:[101, 13411, 18219, 14541, 1110, 1215, 1106, 7299, 17972, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
input_mask:[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
segment_ids:[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
label_ids:[1, 0]
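
The label_ids field is a multi-hot vector with one slot per relation label. As an illustration, here is a minimal sketch of what a helper like _predicate_label_to_id (referenced in the code above) might do; the body below is an assumption, not the project's exact implementation, and the vector length simply follows the label list in use:

def predicate_label_to_id(labels, label_map):
    # Multi-hot encoding: one slot per relation label, set to 1 where the label applies.
    label_ids = [0] * len(label_map)
    for label in labels:
        label_ids[label_map[label]] = 1
    return label_ids

label_map = {'treat': 0, 'cause': 1, 'unlabel': 2}
print(predicate_label_to_id(['treat'], label_map))  # -> [1, 0, 0]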
  • Model downstream task
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, num_labels, use_one_hot_embeddings):
    """Creates a classification model."""
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)
      
    output_layer = model.get_pooled_output()  # pooled [CLS] feature vector
    hidden_size = output_layer.shape[-1].value  # feature dimension, e.g. 768
    # Classification weight matrix of shape (num_labels, hidden_size), e.g. (2, 768)
    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    # Bias vector for the classifier
    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())
    with tf.variable_scope("loss"):
        if is_training:
            # I.e., 0.1 dropout
            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
        logits_wx = tf.matmul(output_layer, output_weights, transpose_b=True)  # features times classification weights
        logits = tf.nn.bias_add(logits_wx, output_bias)  # add the bias
        probabilities = tf.sigmoid(logits)  # map logits to per-label probabilities with sigmoid
        label_ids = tf.cast(labels, tf.float32)  # cast the gold multi-hot labels to float
        per_example_loss = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=label_ids), axis=-1)  # sigmoid cross-entropy loss
        loss = tf.reduce_mean(per_example_loss)
        return loss, per_example_loss, logits, probabilities

Because predicting the relations in a sentence is a straightforward multi-label classification problem, only get_pooled_output() is needed to obtain the [CLS] vector, and sigmoid is used as the activation to map it to per-label probabilities.
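
As a toy illustration of why sigmoid fits here (each label is scored independently, so one sentence can carry several relations), a minimal sketch of turning per-label probabilities into predicted relations; the 0.5 threshold matches the prediction code shown later, the rest is illustrative:

label_list = ['treat', 'cause', 'unlabel']

def predict_relations(probabilities, threshold=0.5):
    # Unlike softmax, sigmoid scores do not compete, so several labels can exceed the threshold.
    return [label for label, p in zip(label_list, probabilities) if p > threshold]

print(predict_relations([0.59, 0.51, 0.49]))  # -> ['treat', 'cause']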

After training, the trained model is saved under output in the predicate_classification_model folder.

run_sequnce_labeling.py

This model also fine-tunes the Google BERT data-processing module and downstream tasks; its job is to label entities and relations.

Before training, we again preprocess the data into the model's input format:

python bin/subject_object_labeling/subject_object_labeling.py

After running the preprocessing script, a sequence_labeling_data folder is created under subject_object_labeling to hold the processed data, split into test, train, and valid sets. The train and valid folders contain bert_tokener_error_log.txt, text.txt, token_in.txt, token_in_not_UNK.txt, and token_label_and_one_predicate_out.txt.

  • bert_tokener_error_log.txt stores records that failed because of tokenization errors
  • text.txt stores the input sentences; if a sentence contains multiple relations, it is repeated once per relation
SAMD11 is used to treat diabetes, but it may lead to auto-immunes diseases
SAMD11 is used to treat diabetes, but it may lead to auto-immunes diseases
CD105 is used to treat neurodegenerative
CD34 is used to treat cardiovascular
Gata4 is used to treat auto-immunes diseases
  • token_in.txt stores the tokenized sentences and their corresponding relation
SA ##MD ##11 is used to treat diabetes , but it may lead to my ##oc ##ard ##ial in ##far ##ction		treat
SA ##MD ##11 is used to treat diabetes , but it may lead to my ##oc ##ard ##ial in ##far ##ction	cause
CD ##10 ##5 is used to treat ne ##uro ##de ##gene ##rative	treat
CD ##34 is used to treat card ##iovascular	treat
G ##ata ##4 is used to treat auto - immune ##s diseases	treat
  • token_in_not_UNK.txt likewise stores the tokenized sentences and their corresponding relation
SA ##MD ##11 is used to treat diabetes , but it may lead to my ##oc ##ard ##ial in ##far ##ction		treat
SA ##MD ##11 is used to treat diabetes , but it may lead to my ##oc ##ard ##ial in ##far ##ction	cause
CD ##10 ##5 is used to treat ne ##uro ##de ##gene ##rative	treat
CD ##34 is used to treat card ##iovascular	treat
G ##ata ##4 is used to treat auto - immune ##s diseases	treat
  • token_label_and_one_predicate_out.txt stores the token labels for each sentence and its relation
B-GENE I-GENE I-GENE O O O O B-DIEASE	treat
B-OBJ I-OBJ I-OBJ O O O O B-DIEASE I-DIEASE I-DIEASE I-DIEASE I-DIEASE	treat
B-GENE I-GENE O O O O B-DIEASE I-DIEASE	treat
B-OBJ I-OBJ I-OBJ O O O O B-DIEASE I-DIEASE I-DIEASE I-DIEASE I-DIEASE	treat
B-GENE I-GENE I-GENE I-GENE O O O O B-DIEASE I-DIEASE I-DIEASE I-DIEASE I-DIEASE I-DIEASE I-DIEASE	treat
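
For illustration, a minimal sketch of how an entity span from spo_list could be projected onto BIO tags over the WordPiece tokens; this is a simplification of what subject_object_labeling.py does (the real script also handles tokenization mismatches, which it logs to bert_tokener_error_log.txt):

def bio_tag(tokens, entity_tokens, entity_type):
    # Start with all-"O" tags, then mark the first matching entity span with B-/I- labels.
    labels = ["O"] * len(tokens)
    for start in range(len(tokens) - len(entity_tokens) + 1):
        if tokens[start:start + len(entity_tokens)] == entity_tokens:
            labels[start] = "B-" + entity_type
            for i in range(start + 1, start + len(entity_tokens)):
                labels[i] = "I-" + entity_type
            break
    return labels

tokens = "SA ##MD ##11 is used to treat diabetes".split()
print(bio_tag(tokens, ["SA", "##MD", "##11"], "GENE"))
# -> ['B-GENE', 'I-GENE', 'I-GENE', 'O', 'O', 'O', 'O', 'O']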

With the processed data in place, we can train the entity and relation labeling model run_sequnce_labeling.py with the following parameters:

--task_name=SKE_2019
--do_train=true
--do_eval=false
--data_dir=bin/subject_object_labeling/sequence_labeling_data
--vocab_file=pretrained_model/biobert_v1.1_pubmed/vocab.txt
--bert_config_file=pretrained_model/biobert_v1.1_pubmed/bert_config.json
--init_checkpoint=pretrained_model/biobert_v1.1_pubmed/model.ckpt-1000000
--max_seq_length=128
--train_batch_size=32
--learning_rate=2e-5
--num_train_epochs=9.0
--output_dir=./output/sequnce_labeling_model/epochs9/

This model fine-tunes the Google BERT data preprocessing module and downstream tasks to perform entity and relation labeling:

  • Data processing module
    # Define the token-level labels
    def get_token_labels(self):
        BIO_token_labels = ["[Padding]", "[category]", "[##WordPiece]", "[CLS]", "[SEP]", "B-GENE", "I-GENE", "B-DIEASE","I-DIEASE", "O",'B-SUB','I-SUB','B-OBJ','I-OBJ']  # id 0 --> [Paddding]
        return BIO_token_labels
    def get_predicate_labels(self):
        return ['treat', 'cause']

Because this model predicts entities and relations at the same time, two label sets are defined: get_token_labels provides the entity (token-level) labels and get_predicate_labels provides the relation labels.

def convert_single_example(ex_index, example, token_label_list, predicate_label_list, max_seq_length,
                           tokenizer):
    """Converts a single `InputExample` into a single `InputFeatures`."""
    if isinstance(example, PaddingInputExample):
        return InputFeatures(
            input_ids=[0] * max_seq_length,
            input_mask=[0] * max_seq_length,
            segment_ids=[0] * max_seq_length,
            token_label_ids=[0] * max_seq_length,
            predicate_label_id = [0],
            is_real_example=False)

    token_label_map = {}  # map each token label to an integer id
    for (i, label) in enumerate(token_label_list):
        token_label_map[label] = i

    predicate_label_map = {}  # map each relation label to an integer id
    for (i, label) in enumerate(predicate_label_list):
        predicate_label_map[label] = i

    text_token = example.text_token.split("\t")[0].split(" ")  # tokens of the input sentence
    if example.token_label is not None:
        token_label = example.token_label.split("\t")[0].split(" ")  # token-level labels aligned with the tokens
    else:
        token_label = ["O"] * len(text_token)
    assert len(text_token) == len(token_label)

    text_predicate = example.text_token.split("\t")[1]  # relation label attached to the sentence
    if example.token_label is not None:
        token_predicate = example.token_label.split("\t")[1]  # relation label from the label file
    else:
        token_predicate = text_predicate
    assert text_predicate == token_predicate

    tokens_b = [text_predicate] * len(text_token)  # repeat the relation label so text_b is as long as text_a
    predicate_id = predicate_label_map[text_predicate]  # map the relation label to its integer id
    _truncate_seq_pair(text_token, tokens_b, max_seq_length - 3)
    tokens = []
    token_label_ids = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(0)
    token_label_ids.append(token_label_map["[CLS]"])

    for token, label in zip(text_token, token_label):
        tokens.append(token)
        segment_ids.append(0)
        token_label_ids.append(token_label_map[label])

    tokens.append("[SEP]")
    segment_ids.append(0)
    token_label_ids.append(token_label_map["[SEP]"])

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    #bert_tokenizer.convert_tokens_to_ids(["[SEP]"]) --->[102]
    bias = 1  # vocab ids 1-100 are unused, so relation ids are shifted into that range
    for token in tokens_b:
      # Append the relation's pseudo-token id, injecting the relation information into the input sequence.
      input_ids.append(predicate_id + bias)  # add bias to keep these ids distinct from the word vocabulary
      segment_ids.append(1)
      token_label_ids.append(token_label_map["[category]"])

    input_ids.append(tokenizer.convert_tokens_to_ids(["[SEP]"])[0])  # 102
    segment_ids.append(1)
    token_label_ids.append(token_label_map["[SEP]"])

    input_mask = [1] * len(input_ids)  # 1 for real tokens (including the relation segment), 0 for padding
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        token_label_ids.append(0)
        tokens.append("[Padding]")

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(token_label_ids) == max_seq_length

    if ex_index < 5:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid: %s" % (example.guid))
        tf.logging.info("tokens: %s" % " ".join(
            [tokenization.printable_text(x) for x in tokens]))
        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        tf.logging.info("token_label_ids: %s" % " ".join([str(x) for x in token_label_ids]))
        tf.logging.info("predicate_id: %s" % str(predicate_id))

    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        token_label_ids=token_label_ids,
        predicate_label_id=[predicate_id],
        is_real_example=True)
    return feature

Note that:
1. text_b is no longer None; it holds the sentence's relation label repeated to the same length as text_a, with token_label [category].
2. The relation's pseudo-token ids are appended to the input_ids sequence.

After preprocessing we obtain guid, tokens, input_ids, input_mask, segment_ids, token_label_ids, and predicate_id; for example, for the input "SA ##MD ##11 is used to treat diabetes" with relation "treat" and token labels "B-GENE I-GENE I-GENE O O O O B-DIEASE":

guid:'train-0'
tokens:[CLS] SA ##MD ##11 is used to treat diabetes [SEP] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding]
input_ids: 101 13411 18219 14541 1110 1215 1106 7299 17972 102 1 1 1 1 1 1 1 1 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
input_mask:1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
segment_ids: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
token_label_ids:3 5 6 6 9 9 9 9 7 4 1 1 1 1 1 1 1 1 4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
predicate_id:0
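
To make the relation segment in the ids above easier to read, here is a small reconstruction of how that part of input_ids comes about; the values follow the example (predicate_id 0 for treat, bias 1, eight sentence WordPieces), while the helper code itself is only illustrative:

predicate_id = 0   # 'treat' in get_predicate_labels()
bias = 1           # shifts relation ids into the unused 1-100 range of the vocab
sentence_ids = [13411, 18219, 14541, 1110, 1215, 1106, 7299, 17972]  # SA ##MD ##11 is used to treat diabetes

input_ids = [101] + sentence_ids + [102]                 # [CLS] sentence [SEP]
input_ids += [predicate_id + bias] * len(sentence_ids)   # relation pseudo-tokens, segment id 1
input_ids += [102]                                       # closing [SEP]
print(input_ids)
# -> [101, 13411, 18219, 14541, 1110, 1215, 1106, 7299, 17972, 102, 1, 1, 1, 1, 1, 1, 1, 1, 102]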
  • Model downstream task
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 token_label_ids, predicate_label_id, num_token_labels, num_predicate_labels,
                 use_one_hot_embeddings):
    """Creates a classification model."""
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)
    # Relation prediction head
    predicate_output_layer = model.get_pooled_output()
    intent_hidden_size = predicate_output_layer.shape[-1].value
    predicate_output_weights = tf.get_variable(
        "predicate_output_weights", [num_predicate_labels, intent_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    predicate_output_bias = tf.get_variable(
        "predicate_output_bias", [num_predicate_labels], initializer=tf.zeros_initializer())

    with tf.variable_scope("predicate_loss"):
        if is_training:
            # I.e., 0.1 dropout
            predicate_output_layer = tf.nn.dropout(predicate_output_layer, keep_prob=0.9)

        predicate_logits = tf.matmul(predicate_output_layer, predicate_output_weights, transpose_b=True)
        predicate_logits = tf.nn.bias_add(predicate_logits, predicate_output_bias)
        predicate_probabilities = tf.nn.softmax(predicate_logits, axis=-1)
        predicate_prediction = tf.argmax(predicate_probabilities, axis=-1, output_type=tf.int32)
        predicate_labels = tf.one_hot(predicate_label_id, depth=num_predicate_labels, dtype=tf.float32)
        predicate_per_example_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=predicate_logits, labels=predicate_labels), -1)
        predicate_loss = tf.reduce_mean(predicate_per_example_loss)
     
    # Entity (token) labeling head
    token_label_output_layer = model.get_sequence_output()

    token_label_hidden_size = token_label_output_layer.shape[-1].value

    token_label_output_weight = tf.get_variable(
        "token_label_output_weights", [num_token_labels, token_label_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02)
    )
    token_label_output_bias = tf.get_variable(
        "token_label_output_bias", [num_token_labels], initializer=tf.zeros_initializer()
    )
    with tf.variable_scope("token_label_loss"):
        if is_training:
            token_label_output_layer = tf.nn.dropout(token_label_output_layer, keep_prob=0.9)
        token_label_output_layer = tf.reshape(token_label_output_layer, [-1, token_label_hidden_size])
        token_label_logits = tf.matmul(token_label_output_layer, token_label_output_weight, transpose_b=True)
        token_label_logits = tf.nn.bias_add(token_label_logits, token_label_output_bias)

        token_label_logits = tf.reshape(token_label_logits, [-1, FLAGS.max_seq_length, num_token_labels])
        token_label_log_probs = tf.nn.log_softmax(token_label_logits, axis=-1)
        token_label_one_hot_labels = tf.one_hot(token_label_ids, depth=num_token_labels, dtype=tf.float32)
        token_label_per_example_loss = -tf.reduce_sum(token_label_one_hot_labels * token_label_log_probs, axis=-1)
        token_label_loss = tf.reduce_sum(token_label_per_example_loss)
        token_label_probabilities = tf.nn.softmax(token_label_logits, axis=-1)
        token_label_predictions = tf.argmax(token_label_probabilities, axis=-1)
        # return (token_label_loss, token_label_per_example_loss, token_label_logits, token_label_predict)
    # Combined model loss
    loss = 0.5 * predicate_loss + token_label_loss
    return (loss,
            predicate_loss, predicate_per_example_loss, predicate_probabilities, predicate_prediction,
            token_label_loss, token_label_per_example_loss, token_label_logits, token_label_predictions)

Because this model must label both entities and relations, it defines two output heads, predicate_output_layer and token_label_output_layer, with two loss scopes, predicate_loss and token_label_loss. predicate_output_layer predicts the relation, a multi-label classification problem, so it only needs get_pooled_output() to obtain the [CLS] vector; token_label_output_layer labels the entities in the sentence, so it uses get_sequence_output() to obtain the vectors for the whole token sequence.

Also note that the model's overall loss combines the losses of the two heads: loss = 0.5 * predicate_loss + token_label_loss.

After training, the trained model is saved under output in the sequnce_labeling_model folder.

Model Prediction

Prediction follows the same two-stage pipeline as training: the run_predicate_classification model first predicts the relations in each sentence of the test set, and the run_sequnce_labeling model then labels the entities and relations in those sentences.

run_predicate_classification.py

We use the trained model to predict relations for the test file, with the following parameters:

--task_name=SKE_2019
--do_predict=true
--data_dir=bin/predicate_classifiction/classification_data
--vocab_file=pretrained_model/biobert_v1.1_pubmed/vocab.txt
--bert_config_file=pretrained_model/biobert_v1.1_pubmed/bert_config.json
--init_checkpoint=output/predicate_classification_model/epochs6/model.ckpt-0
--max_seq_length=128
--output_dir=./output/predicate_infer_out/epochs6/ckpt0

Note that in the data preprocessing module, the test set has no relation labels, so a placeholder unlabel label is added to keep the input format consistent with training:

def _create_example(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
        guid = "%s-%s" % (set_type, i)
        if set_type == "test":
            text_str = line
            predicate_label_str = 'unlabel'
        else:
            text_str = line[0]
            predicate_label_str = line[1]
        examples.append(
            InputExample(guid=guid, text_a=text_str, text_b=None, label=predicate_label_str))
    return examples

After preprocessing we obtain guid, tokens, input_ids, input_mask, segment_ids, and label_ids; for example, for the input "SA ##MD ##11 is used to treat diabetes":

guid:test-0
tokens:['[CLS]', 'SA', '##MD', '##11', 'is', 'used', 'to', 'treat', 'diabetes', '[SEP]']
input_ids:[101, 13411, 18219, 14541, 1110, 1215, 1106, 7299, 17972, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
input_mask:[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
segment_ids:[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
label_ids:[0, 1]

The model and its downstream task were covered in the training section; next, let's look at the relation prediction output:

tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)
        predict_drop_remainder = True if FLAGS.use_tpu else False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            label_length=label_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)
        result = estimator.predict(input_fn=predict_input_fn)
        output_score_value_file = os.path.join(FLAGS.output_dir, "predicate_score_value.txt")
        output_predicate_predict_file = os.path.join(FLAGS.output_dir, "predicate_predict.txt")
        with tf.gfile.GFile(output_score_value_file, "w") as score_value_writer:
            with tf.gfile.GFile(output_predicate_predict_file, "w") as predicate_predict_writer:
                num_written_lines = 0
                tf.logging.info("***** Predict results *****")
                for (i, prediction) in enumerate(result):
                    probabilities = prediction["probabilities"]
                    if i >= num_actual_predict_examples:
                        break
                    output_line_score_value = " ".join(
                        str(class_probability)
                        for class_probability in probabilities) + "\n"
                    predicate_predict = []
                    for idx, class_probability in enumerate(probabilities):
                        if class_probability > 0.5:
                            predicate_predict.append(label_list[idx])
                    output_line_predicate_predict = " ".join(predicate_predict) + "\n"
                    predicate_predict_writer.write(output_line_predicate_predict)
                    score_value_writer.write(output_line_score_value)
                    num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples

In the prediction code we define two output files, predicate_score_value.txt and predicate_predict.txt: predicate_score_value.txt stores the model's predicted probability of every relation for each sentence in the test set, while predicate_predict.txt stores, for each sentence, the relation labels whose probability exceeds 0.5.

After prediction, a predicate_infer_out\epochs6\ckpt0 folder is created under output, containing three files: predicate_predict, predicate_score_value, and predict:

  • predicate_predict
treat cause
treat unlabel
treat
treat
treat unlabel
  • predicate_score_value
0.5944473 0.5069371 0.498385
0.5740756  0.498385 0.5615229
0.5615229 0.47858068 0.47900787
0.5729883 0.49133754 0.47858068
0.6151916 0.5069371 0.4920553

With the relation predictions for the test set in hand, the next step is to label the entities and relations with the run_sequnce_labeling model. Before that, the output of the run_predicate_classification model needs to be preprocessed:

python bin/predicate_classifiction/prepare_data_for_labeling_infer.py

The purpose of this preprocessing step is to pair each input test sentence with the relations predicted by the run_predicate_classification model: if a sentence is assigned several relations, it is repeated so that each copy corresponds to exactly one relation. Note that sentences predicted as unlabel are dropped, meaning no relation was found in them.
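
For illustration, a minimal sketch of that expansion step; it is a simplification of what prepare_data_for_labeling_infer.py does, and the function name below is hypothetical:

def expand_by_predicted_relations(sentences, predicted_relations):
    # One output pair per (sentence, relation); 'unlabel' predictions are dropped entirely.
    pairs = []
    for text, relations in zip(sentences, predicted_relations):
        for relation in relations.split():
            if relation != "unlabel":
                pairs.append((text, relation))
    return pairs

sentences = ["SAMD11 is used to treat diabetes, but it may lead to auto-immunes diseases"]
predicted = ["treat cause"]   # one line of predicate_predict.txt
print(expand_by_predicted_relations(sentences, predicted))
# -> [('SAMD11 is used to ...', 'treat'), ('SAMD11 is used to ...', 'cause')]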

After preprocessing, a test folder is created under bin/subject_object_labeling/sequence_labeling_data, containing text_and_one_predicate, token_in_and_one_predicate, and token_in_not_UNK_and_one_predicate files.
text_and_one_predicate.txt stores the test sentences and the relations predicted for them:

SAMD11 is used to treat diabetes, but it may lead to auto-immunes diseases	cause
SAMD11 is used to treat diabetes, but it may lead to auto-immunes diseases	cause
CD105 is used to treat neurodegenerative	treat
CD34 is used to treat cardiovascular	treat
Gata4 is used to treat auto-immunes diseases	treat
FAM41C is used to treat myocardial infarction	treat

token_in_and_one_predicate.txt stores the tokenized test sentences and their corresponding relation:

SA ##MD ##11 is used to treat diabetes , but it may lead to my ##oc ##ard ##ial in ##far ##ction		treat
SA ##MD ##11 is used to treat diabetes , but it may lead to my ##oc ##ard ##ial in ##far ##ction	cause
CD ##10 ##5 is used to treat ne ##uro ##de ##gene ##rative	treat
CD ##34 is used to treat card ##iovascular	treat
G ##ata ##4 is used to treat auto - immune ##s diseases	treat
FA ##M ##41 ##C is used to treat my ##oc ##ard ##ial in ##far ##ction	treat

token_in_not_UNK_and_one_predicate.txt likewise stores the tokenized test sentences and their relation labels:

SA ##MD ##11 is used to treat diabetes , but it may lead to my ##oc ##ard ##ial in ##far ##ction		treat
SA ##MD ##11 is used to treat diabetes , but it may lead to my ##oc ##ard ##ial in ##far ##ction	cause
CD ##10 ##5 is used to treat ne ##uro ##de ##gene ##rative	treat
CD ##34 is used to treat card ##iovascular	treat
G ##ata ##4 is used to treat auto - immune ##s diseases	treat
FA ##M ##41 ##C is used to treat my ##oc ##ard ##ial in ##far ##ction	treat

Notice that the first two lines are actually the first test sentence repeated twice: the run_predicate_classification model predicted two relations for it, so the sentence is duplicated so that the run_sequnce_labeling model extracts a single relation from each copy.

run_sequnce_labeling.py

We use the trained model to extract entities and relations from the preprocessed data, with the following parameters:

--task_name=SKE_2019
--do_predict=true
--data_dir=bin/subject_object_labeling/sequence_labeling_data
--vocab_file=pretrained_model/biobert_v1.1_pubmed/vocab.txt
--bert_config_file=pretrained_model/biobert_v1.1_pubmed/bert_config.json
--init_checkpoint=output/sequnce_labeling_model/epochs9/model.ckpt-0
--max_seq_length=128
--output_dir=./output/sequnce_infer_out/epochs9/ckpt0

Because the test data has no token_labels-level entity labels like the training data does, the preprocessing step labels every token as "O". As in training, "text_b" is the relation label expanded to the length of "text_a", and its ids are appended after the actual input sequence to form "input_ids".

After preprocessing we obtain guid, tokens, input_ids, input_mask, segment_ids, token_label_ids, and predicate_id; for example, for the input "SA ##MD ##11 is used to treat diabetes" with relation "treat":

guid:test-0
tokens:[CLS] SA ##MD ##11 is used to treat diabetes [SEP] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding] [Padding]
input_ids:101 13411 18219 14541 1110 1215 1106 7299 17972 102 1 1 1 1 1 1 1 1 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
input_mask:1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
segment_ids:0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
token_label_ids:3 9 9 9 9 9 9 9 9 4 1 1 1 1 1 1 1 1 4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
predicate_id:0

The model and its downstream tasks were covered in the training section; next, let's look at the entity and relation prediction output:

tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = True if FLAGS.use_tpu else False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        result = estimator.predict(input_fn=predict_input_fn)
        token_label_output_predict_file = os.path.join(FLAGS.output_dir, "token_label_predictions.txt")
        predicate_output_predict_file = os.path.join(FLAGS.output_dir, "predicate_predict.txt")
        predicate_output_probabilities_file = os.path.join(FLAGS.output_dir, "predicate_probabilities.txt")
        with open(token_label_output_predict_file, "w", encoding='utf-8') as token_label_writer:
            with open(predicate_output_predict_file, "w", encoding='utf-8') as predicate_predict_writer:
                with open(predicate_output_probabilities_file, "w", encoding='utf-8') as predicate_probabilities_writer:
                    num_written_lines = 0
                    tf.logging.info("***** token_label predict and predicate labeling results *****")
                    for (i, prediction) in enumerate(result):
                        token_label_prediction = prediction["token_label_predictions"]
                        predicate_probabilities = prediction["predicate_probabilities"]
                        predicate_prediction = prediction["predicate_prediction"]
                        if i >= num_actual_predict_examples:
                            break
                        token_label_output_line = " ".join(token_label_id2label[id] for id in token_label_prediction) + "\n"
                        token_label_writer.write(token_label_output_line)
                        predicate_predict_line = predicate_label_id2label[predicate_prediction]
                        predicate_predict_writer.write(predicate_predict_line + "\n")
                        predicate_probabilities_line = " ".join(str(sigmoid_logit) for sigmoid_logit in predicate_probabilities) + "\n"
                        predicate_probabilities_writer.write(predicate_probabilities_line)
                        num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples

After the run, a sequnce_infer_out\epochs9\ckpt0 folder is created under output, containing three files: predicate_predict, predicate_probabilities, and token_label_predictions. predicate_predict.txt and predicate_probabilities.txt hold the relation predictions for the test sentences, while token_label_predictions.txt holds the entity labels for the input sentences.

predicate_probabilities.txt stores, for each sentence, the probability assigned to every relation label (three labels in total, hence three values per line):

0.5944473 0.5069371 0.498385
0.5740756  0.498385 0.5615229
0.5615229 0.47858068 0.47900787
0.5729883 0.49133754 0.47858068
0.6151916 0.5069371 0.4920553

predicate_predict.txt stores, for each sentence, the label with the highest probability:

treat
cause
treat
treat
treat
treat

token_label_predictions.txt stores the entity labels predicted for the input test sentences:

B-GENE I-GENE I-GENE O O O O B-DIEASE
B-OBJ I-OBJ I-OBJ O O O O B-DIEASE I-DIEASE I-DIEASE I-DIEASE I-DIEASE	
B-GENE I-GENE O O O O B-DIEASE I-DIEASE
B-OBJ I-OBJ I-OBJ O O O O B-DIEASE I-DIEASE I-DIEASE I-DIEASE I-DIEASE	
B-GENE I-GENE I-GENE I-GENE O O O O B-DIEASE I-DIEASE I-DIEASE I-DIEASE I-DIEASE I-DIEASE I-DIEASE

At this point we have labeled the test set with each sentence's relation and the entities involved in it. Finally, running produce_submit_json_file.py combines the model outputs into the final result:

python produce_submit_json_file.py

When the script finishes, a final_text_spo_list_result folder is created under output holding the final result, keep_empty_spo_list_subject_predicate_object_predict_output.txt, which records the entities and relations for the prediction set:

{ "text": "SAMD11 is used to treat diabetes", "spo_list": [{"predicate": "treat", "object_type": "GENE", "subject_type": "DISEASE", "object": "SAMD11", "subject": "diabetes"}]}
{ "text": "CD105 is used to treat neurodegenerative", "spo_list": [{"predicate": "treat", "object_type": "GENE", "subject_type": "DISEASE", "object": "CD105", "subject": "neurodegenerative"}]}
{ "text": "CD34 is used to treat cardiovascular", "spo_list": [{"predicate": "treat", "object_type": "GENE", "subject_type": "DISEASE", "object": "CD34", "subject": "cardiovascular"}]}
{ "text": "Gata4 is used to treat auto-immunes diseases", "spo_list": [{"predicate": "treat", "object_type": "GENE", "subject_type": "DISEASE", "object": "Gata4", "subject": "auto-immunes diseases"}]}
{ "text": "FAM41C is used to treat myocardial infarction", "spo_list": [{"predicate": "treat", "object_type": "GENE", "subject_type": "DISEASE", "object": "FAM41C", "subject": "myocardial infarction"}]}

Summary

The whole project uses two models, run_predicate_classification.py and run_sequnce_labeling.py, both fine-tuned on top of the Google BERT model. The first model makes an initial prediction of the relations in each sentence, turning a sentence that contains several relations into one (sentence, relation) pair per relation; the second model performs entity and relation extraction through two downstream tasks that predict entities and the relation respectively. A final data-processing step then combines the model outputs into the required format.

The project uses a trained BioBERT model as the initial model and fine-tunes it within the Entity-Relation-Extraction project to apply entity-relation extraction to the medical domain. Project model download: Entity-Relation-Extraction(Medical).zip.
