赞
踩
import config import framework import argparse import models import os import torch import numpy as np import random os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 在结果中使用随机种子,以实现结果的可复现性 seed = 2179 torch.manual_seed(seed) torch.cuda.manual_seed(seed) np.random.seed(seed) random.seed(seed) # 使用确定性算法,确保结果的可复现性 torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # 使用Python标准库中的argparse模块,用于解析命令行参数 parser = argparse.ArgumentParser() parser.add_argument('--model_name', type=str, default='OneRel', help='name of the model') parser.add_argument('--lr', type=float, default=1e-5) parser.add_argument('--dropout_prob', type=float, default=0.2) parser.add_argument('--entity_pair_dropout', type=float, default=0.1) parser.add_argument('--multi_gpu', type=bool, default=False) parser.add_argument('--dataset', type=str, default='DUIE') parser.add_argument('--batch_size', type=int, default=4) parser.add_argument('--max_epoch', type=int, default=40) parser.add_argument('--test_epoch', type=int, default=1) parser.add_argument('--train_prefix', type=str, default='train_triples') parser.add_argument('--dev_prefix', type=str, default='dev_triples') parser.add_argument('--test_prefix', type=str, default='test_triples') parser.add_argument('--max_len', type=int, default=100) parser.add_argument('--bert_max_len', type=int, default=200) parser.add_argument('--rel_num', type=int, default=48) parser.add_argument('--period', type=int, default=100) parser.add_argument('--debug', type=bool, default=False) # 解析命令行参数并将结果存储在args变量中 args = parser.parse_args() con = config.Config(args) fw = framework.Framework(con) model = { 'OneRel': models.RelModel } fw.train(model[args.model_name])
Dropout是一种在神经网络中常用的正则化技术。
Dropout率指的是在训练过程中随机丢弃神经网络中的一部分单元(神经元)的比例。
(1)
这里调用了config.Config
class Config(object): def __init__(self, args): self.args = args self.multi_gpu = args.multi_gpu self.learning_rate = args.lr self.batch_size = args.batch_size self.max_epoch = args.max_epoch self.max_len = args.max_len self.rel_num = args.rel_num self.bert_max_len = args.bert_max_len # 预训练的Bert模型(如bert-base-chinese)通常具有768维的隐藏状态输出,这里bert_dim被设置为768是为了与预训练的Bert模型保持一致。 self.bert_dim = 768 self.tag_size = 4 self.dropout_prob = args.dropout_prob self.entity_pair_dropout = args.entity_pair_dropout # dataset self.dataset = args.dataset # path and name self.data_path = './data/' + self.dataset self.checkpoint_dir = './checkpoint/' + self.dataset self.log_dir = './log/' + self.dataset self.result_dir = './result/' + self.dataset self.train_prefix = args.train_prefix self.dev_prefix = args.dev_prefix self.test_prefix = args.test_prefix self.model_save_name = args.model_name + '_DATASET_' + self.dataset + "_LR_" + str(self.learning_rate) + "_BS_" + str(self.batch_size) + "Max_len" + str(self.max_len) + "Bert_ML" + str(self.bert_max_len) + "DP_" + str(self.dropout_prob) + "EDP_" + str(self.entity_pair_dropout) self.log_save_name = 'LOG_' + args.model_name + '_DATASET_' + self.dataset + "_LR_" + str(self.learning_rate) + "_BS_" + str(self.batch_size) + "Max_len" + str(self.max_len) + "Bert_ML" + str(self.bert_max_len) + "DP_" + str(self.dropout_prob) + "EDP_" + str(self.entity_pair_dropout) self.result_save_name = 'RESULT_' + args.model_name + '_DATASET_' + self.dataset + "_LR_" + str(self.learning_rate) + "_BS_" + str(self.batch_size) + "Max_len" + str(self.max_len) + "Bert_ML" + str(self.bert_max_len)+ "DP_" + str(self.dropout_prob) + "EDP_" + str(self.entity_pair_dropout) + ".json" # log setting self.period = args.period self.test_epoch = args.test_epoch # debug 这里debug置为false,不用管啦 self.debug = args.debug if self.debug: self.dev_prefix = self.train_prefix self.test_prefix = self.train_prefix
在实体关系抽取任务中,self.tag_size = 4,通常是因为任务要求对每个实体对进行分类,将其划分到以下四个类别中的一个:
1.实体对不存在关系(No Relation)。
2.实体对存在关系,但具体关系未知(Other Relation):表示两个实体之间存在关系,但具体的关系类型未被提前定义或者需要进一步细化。
3.实体对存在关系,且具体关系为预定义的某一种类型(Relation Type 1):表示两个实体之间存在某一种特定的关系类型。
4.实体对存在关系,且具体关系为预定义的另一种类型(Relation Type 2):表示两个实体之间存在另一种特定的关系类型。
def __init__(self, config):
self.config = config
# 创建了一个交叉熵损失函数并将其赋值给self.loss_function实例变量
self.loss_function = nn.CrossEntropyLoss(reduction='none')
调用framework,对Framework类进行初始化
很清楚的一个讲解
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。