当前位置:   article > 正文

【实体关系抽取】之——OneRel代码学习笔记(一)_onerel 模型代码实战

onerel 模型代码实战

train.py

import config
import framework
import argparse
import models
import os
import torch
import numpy as np
import random

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# 在结果中使用随机种子,以实现结果的可复现性
seed = 2179
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# 使用确定性算法,确保结果的可复现性
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# 使用Python标准库中的argparse模块,用于解析命令行参数
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='OneRel', help='name of the model')
parser.add_argument('--lr', type=float, default=1e-5)
parser.add_argument('--dropout_prob', type=float, default=0.2)
parser.add_argument('--entity_pair_dropout', type=float, default=0.1)
parser.add_argument('--multi_gpu', type=bool, default=False)
parser.add_argument('--dataset', type=str, default='DUIE')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--max_epoch', type=int, default=40)
parser.add_argument('--test_epoch', type=int, default=1)
parser.add_argument('--train_prefix', type=str, default='train_triples')
parser.add_argument('--dev_prefix', type=str, default='dev_triples')
parser.add_argument('--test_prefix', type=str, default='test_triples')
parser.add_argument('--max_len', type=int, default=100)
parser.add_argument('--bert_max_len', type=int, default=200)
parser.add_argument('--rel_num', type=int, default=48)
parser.add_argument('--period', type=int, default=100)
parser.add_argument('--debug', type=bool, default=False)
# 解析命令行参数并将结果存储在args变量中
args = parser.parse_args()

con = config.Config(args)

fw = framework.Framework(con)

model = {
    'OneRel': models.RelModel
}

fw.train(model[args.model_name])

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53

补充知识:

Dropout率:

Dropout是一种在神经网络中常用的正则化技术。
Dropout率指的是在训练过程中随机丢弃神经网络中的一部分单元(神经元)的比例。
(1)

  • 在神经网络中,每个神经元都会以一定的权重连接到下一层的神经元或输出层。Dropout通过在训练过程中以一定的概率丢弃一些神经元,即将它们的输出置为零,来减少模型中神经元之间的依赖关系。
  • 通过随机丢弃神经元,网络训练过程中的不同迭代步骤中会使用不同的神经元子集,这相当于训练了很多不同的神经网络,并将它们的预测结果进行平均或集成,从而提高模型的泛化能力和鲁棒性。
    (2)
    Dropout可以减少过拟合现象,使得模型更加健壮和可泛化。通过丢弃神经元,模型不仅可以学习到每个独立的神经元特征,还可以学习到神经元之间的组合特征,使得模型对于输入数据的变化和噪声具有更好的适应性。
    (3)
    Dropout率是一个控制丢弃概率的参数,通常取值范围在0到1之间。例如,一个Dropout率为0.2表示在训练过程中每次迭代中有20%的神经元会被随机丢弃。
    一般来说,较小的Dropout率可以有效地减少过拟合,而较大的Dropout率可能会导致模型欠拟合。
    (4)
    在测试和预测阶段,Dropout是关闭的,所有神经元都被保留,并且每个神经元的输出会乘以(1 - Dropout率),以保持模型的期望输出一致。这可以保证模型在测试和预测时具有较好的效果。

con = config.Config(args)

这里调用了config.Config

class Config(object):
    def __init__(self, args):
        self.args = args

        self.multi_gpu = args.multi_gpu
        self.learning_rate = args.lr
        self.batch_size = args.batch_size
        self.max_epoch = args.max_epoch
        self.max_len = args.max_len
        self.rel_num = args.rel_num
        self.bert_max_len = args.bert_max_len
        # 预训练的Bert模型(如bert-base-chinese)通常具有768维的隐藏状态输出,这里bert_dim被设置为768是为了与预训练的Bert模型保持一致。
        self.bert_dim = 768
        self.tag_size = 4
        self.dropout_prob = args.dropout_prob
        self.entity_pair_dropout = args.entity_pair_dropout

        # dataset
        self.dataset = args.dataset

        # path and name
        self.data_path = './data/' + self.dataset
        self.checkpoint_dir = './checkpoint/' + self.dataset
        self.log_dir = './log/' + self.dataset
        self.result_dir = './result/' + self.dataset
        self.train_prefix = args.train_prefix
        self.dev_prefix = args.dev_prefix
        self.test_prefix = args.test_prefix


        self.model_save_name = args.model_name + '_DATASET_' + self.dataset + "_LR_" + str(self.learning_rate) + "_BS_" + str(self.batch_size) + "Max_len" + str(self.max_len) + "Bert_ML" + str(self.bert_max_len) + "DP_" + str(self.dropout_prob) + "EDP_" + str(self.entity_pair_dropout)
        self.log_save_name = 'LOG_' + args.model_name + '_DATASET_' + self.dataset + "_LR_" + str(self.learning_rate) + "_BS_" + str(self.batch_size) + "Max_len" + str(self.max_len) + "Bert_ML" + str(self.bert_max_len) + "DP_" + str(self.dropout_prob) + "EDP_" + str(self.entity_pair_dropout)
        self.result_save_name = 'RESULT_' + args.model_name + '_DATASET_' + self.dataset + "_LR_" + str(self.learning_rate) + "_BS_" + str(self.batch_size) + "Max_len" + str(self.max_len) + "Bert_ML" + str(self.bert_max_len)+ "DP_" + str(self.dropout_prob) + "EDP_" + str(self.entity_pair_dropout) + ".json"
        

        # log setting
        self.period = args.period
        self.test_epoch = args.test_epoch

        # debug 这里debug置为false,不用管啦
        self.debug = args.debug
        if self.debug:
            self.dev_prefix = self.train_prefix
            self.test_prefix = self.train_prefix

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45

补充知识

在实体关系抽取任务中,self.tag_size = 4,通常是因为任务要求对每个实体对进行分类,将其划分到以下四个类别中的一个:
1.实体对不存在关系(No Relation)。
2.实体对存在关系,但具体关系未知(Other Relation):表示两个实体之间存在关系,但具体的关系类型未被提前定义或者需要进一步细化。
3.实体对存在关系,且具体关系为预定义的某一种类型(Relation Type 1):表示两个实体之间存在某一种特定的关系类型。
4.实体对存在关系,且具体关系为预定义的另一种类型(Relation Type 2):表示两个实体之间存在另一种特定的关系类型。

fw = framework.Framework(con)

    def __init__(self, config):
        self.config = config
        # 创建了一个交叉熵损失函数并将其赋值给self.loss_function实例变量
        self.loss_function = nn.CrossEntropyLoss(reduction='none')
  • 1
  • 2
  • 3
  • 4

调用framework,对Framework类进行初始化

补充知识

nn.CrossEntropyLoss

很清楚的一个讲解

推荐阅读
相关标签