当前位置:   article > 正文

Pytorch版本的Ernie Health+BiLSTM+CRF模型源码详解_ernie-bilstm

ernie-bilstm

Pytorch版本的Ernie Health+BiLSTM+CRF模型源码详解

一、主函数

(1)定义训练函数

# 一个星号*的作用是将tuple或者list中的元素进行unpack,分开传入,作为多个参数;两个星号**的作用是把dict类型的数据作为参数传入。
def train(**kwargs) :
    config = Config() # 调用配置文件
    config.update(**kwargs)
    print('当前的参数配置为:\n',config)
    if config.use_cuda :
        torch.cuda.set_device(config.gpu)
    print('loading corpus ·················· ')
    vocab = load_vocab(config.vocab) # {tag.index} # 加载词汇表
    label_dic = load_vocab(config.label_file) # 加载标签 将所有可能出现的标签在词汇表中找到对应的索引
    id2tag = {label_dic[tag]: tag for tag in label_dic.keys()}
    tagset_size = len(label_dic) # 计算所有可能出现的标签数目
    train_data = read_corpus(config.train_file, max_length=config.maxlength, label_dic=label_dic, vocab=vocab)
    dev_data = read_corpus(config.dev_file, max_length=config.maxlength, label_dic=label_dic, vocab=vocab)
    test_data = read_corpus(config.test_file, max_length=config.maxlength, label_dic=label_dic, vocab=vocab)

    # 构建训练集数据
    train_dataset = build_dataset(train_data)
    # 从数据库中每次抽出batch_size个样本 shuffle:在每个epoch开始的时候,对数据进行重新打乱
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size)
    # 构建验证集数据
    dev_dataset = build_dataset(dev_data)
    dev_loader = DataLoader(dev_dataset, shuffle=True, batch_size=config.batch_size)
    # 构建测试集数据
    test_dataset = build_dataset(test_data)
    # shuffle:在每个epoch开始的时候,对数据进行重新打乱
    test_loader = DataLoader(test_dataset, shuffle=True, batch_size=config.batch_size)

    # 构建模型
    model = Ernie_Lstm_Crf(config.ernie_path, tagset_size, config.ernie_embedding, \
                           config.rnn_hidden, config.rnn_layer, dropout_ratio=config.dropout_ratio, \
                           dropout1=config.dropout1, use_cuda=config.use_cuda)

    if config.load_model:
        assert config.load_path is not None
        model = load_model(model, name=config.load_path)
    # if config.use_cuda() :
    model.cuda()
    optimizer = getattr(optim, config.optim)
    optimizer = optimizer(model.parameters(), lr=config.lr, weight_decay = config.weight_decay)
    model.train() # 设置成训练模型
    step = 0 # 记录训练迭代次数
    eval_loss = float('inf') # 初始化损失为无穷大
    last_improved = 0 # 记录上一次更新的step值
    flag = False
    for epoch in range(config.base_epoch) :
        for i, batch in enumerate(train_loader) :
            step += 1
            model.zero_grad()
            inputs, masks, tags = batch
            inputs, masks, tags = Variable(inputs), Variable(masks), Variable(tags)
            # if config.use_cuda():
            inputs, masks, tags = inputs.cuda(), masks.cuda(), tags.cuda()
            feats = model(inputs, masks)
            loss = model.loss(feats, masks.byte(), tags)
            loss.backward()
            optimizer.step()

            # 5次迭代输出一次结果
            if step % 5 == 0 :
                print('step:{} | epoch:{} | loss:{}'.format(step, epoch, loss.item()))
            # 50次迭代保存一次模型
            if step % 50 == 0 :
                f1, dev_loss = dev(model, dev_loader, config, id2tag, test=False)
                if dev_loss < eval_loss : # 验证集损失小,模型效果好保存
                    eval_loss = dev_loss
                    save_model(model, epoch)
                    last_improved = step # 记录最后一次迭代的次数
                    improve = '*'
                else:
                    improve = ''
                print('eval epoch: {} | f1_score: {} | loss: {} | {}'.format(epoch, f1, dev_loss, improve))
            if step - last_improved > config.require_improvement : # 提前结束训练
                print('No optimization for a long time, auto-stopping...')
                flag = True
                break
        if flag:
            break
    test(model, test_loader, config, id2tag)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79

加载词汇表

vocab = load_vocab(config.vocab) # {tag.index} 加载词汇表
  • 1

{‘[PAD]’: 0,
‘[UNK]’: 1,
‘[CLS]’: 2,
‘[SEP]’: 3,
‘[MASK]’: 4,
‘,’: 5,
‘的’: 6,
‘是’: 7,
‘:’: 8,
‘有’: 9,
‘。’: 10,
‘不’: 11,
‘一’: 12,
‘了’: 13,
‘好’: 14,
‘可’: 15,
‘?’: 16,
‘医’: 17,
‘以’: 18,
‘我’: 19,
‘1’: 20,
‘要’: 21,
……

label_dic = load_vocab(config.label_file) # 加载标签 将所有可能出现的标签在词汇表中找到对应的索引
id2tag = {label_dic[tag]: tag for tag in label_dic.keys()} # 完成键与值互换
  • 1
  • 2

{0: ‘< pad>’,
1: ‘B-PER’,
2: ‘I-PER’,
3: ‘B-ORG’,
4: ‘I-ORG’,
5: ‘B-LOC’,
6: ‘I-LOC’,
7: ‘O’,
8: ‘< start>’,
9: ‘< eos>’}

def read_corpus(path, max_length, label_dic, vocab) :
    """
        :param path:数据文件路径
        :param max_length: 最大长度
        :param label_dic: 标签字典
        :return:
        """
    file = open(path, encoding='utf-8')
    content = file.readlines()
    file.close()
    result = []
    tokens = []
    label = []

    for line in content:
        # 读取一行
        if line != '\n':
            word, tag = line.strip('\n').split()
            tokens.append(word)
            label.append(tag)
        # 获得一句话
        else:
            if len(tokens) > max_length - 2:
                tokens = tokens[0:(max_length - 2)] # 截断超过最大长度的部分
                label = label[0:(max_length - 2)]
            tokens_f = ['[CLS]'] + tokens + ['[SEP]']  # 在token的前后分别拼接'[CLS]'和'[SEP]'
            label_f = ["<start>"] + label + ['<eos>']  # 在label的前后分别拼接'<start>'和'<eos>'

            # if '[UNK]' not in vocab:
            #     print('None')
            #     vocab['[UNK]'] = 0
            '''
            for i in tokens_f :
                if i in vocab:
                    int(vocab[i])
                else:
                    int(vocab['[UNK]']
            '''
            # 如果字符在词汇表中,则将对应在词汇表中的索引拼接在input_ids列表中,如果不在就拼接[UNK]的索引
            # [2,……,3]
            input_ids = [int(vocab[i]) if i in vocab else int(vocab['[UNK]']) for i in tokens_f]
            # [8,……,9]
            label_ids = [label_dic[i] for i in label_f]
            input_mask = [1] * len(input_ids) # 全部拼接1
            # 将input_ids、input_mask长度不足max_length的部分全部标记为0,label_ids长度不足max_length的部分全部标记为<pad>对应的索引
            # print(input_ids)
			# print(label_ids)
			# print(input_mask)
            while len(input_ids) < max_length:
                input_ids.append(0)
                input_mask.append(0)
                label_ids.append(label_dic['<pad>'])
            # 确保三者的长度都为max_length,否则报错
            assert len(input_ids) == max_length
            assert len(input_mask) == max_length
            assert len(label_ids) == max_length
			print(input_ids)
			print(label_ids)
			print(input_mask)
            feature = InputFeatures(input_id=input_ids, input_mask=input_mask, label_id=label_ids)
            result.append(feature)
            tokens = []
            label = []
    return result
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64

以People_Daily_dev.txt文件中的数据为例

dev_data = read_corpus('./dataset/People_daliy/People_Daily_dev.txt', 128, label_dic=label_dic, vocab=vocab)
  • 1

在这里插入图片描述
在这里插入图片描述
构建数据集

def build_dataset(data):
    """
    生成数据集
    """
    input_ids = torch.LongTensor([temp.input_id for temp in data])
    print('input_ids:', input_ids)
    print(input_ids.size())
    input_masks = torch.LongTensor([temp.input_mask for temp in data])
    print('input_masks:', input_masks)
    # doc_size 文本句子的数目
    print(input_masks.size()) # torch.Size([doc_size, 128])
    label_ids = torch.LongTensor([temp.label_id for temp in data])
    print('label_ids:', label_ids)
    print(label_ids.size())
    # 向TensorDataset中传入的一系列张量第一个维度大小一定相同 对给定的 tensor 数据,将他们包装成 dataset
    dataset = TensorDataset(input_ids, input_masks, label_ids)
    return dataset
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
import torch
from torch.utils.data import TensorDataset
dev_dataset = build_dataset(dev_data)
  • 1
  • 2
  • 3

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
定义优化器

optimizer = getattr(optim, config.optim) 
  • 1

等价于optimizer = getattr(optim, config.optim),但前者是动态赋值

# 优化器 采样梯度更新模型的可学习参数,使得损失减小
optimizer = optimizer(model.parameters(), lr=config.lr, weight_decay = config.weight_decay)
  • 1
  • 2

学习率(learning rate)lr 控制更新的步伐
在这里插入图片描述
L2正则化系数:weight_decay 权值衰减
正则化方法是减小方差的策略,常见的过拟合就会导致高方差,因此常用正则化降低方差来解决过拟合。正则化有L1正则化与L2正则化,通常就是损失函数加上正则项。

Obj = Cost + RegularizationTerm
  • 1

在这里插入图片描述
加入L2正则项后,目标函数为:
在这里插入图片描述
经过实验证明:随着训练轮数的增加,有权值衰减的模型的泛化能力越强。
训练模型

    for epoch in range(config.base_epoch) :
        for i, batch in enumerate(train_loader) :
            step += 1
            model.zero_grad()
            inputs, masks, tags = batch
            inputs, masks, tags = Variable(inputs), Variable(masks), Variable(tags)
            # if config.use_cuda():
            inputs, masks, tags = inputs.cuda(), masks.cuda(), tags.cuda()
            feats = model(inputs, masks)
            loss = model.loss(feats, masks.byte(), tags)
            loss.backward()
            optimizer.step()

            # 5次迭代输出一次结果
            if step % 5 == 0 :
                print('step:{} | epoch:{} | loss:{}'.format(step, epoch, loss.item()))
            # 50次迭代保存一次模型
            if step % 50 == 0 :
                f1, dev_loss = dev(model, dev_loader, config, id2tag, test=False)
                if dev_loss < eval_loss : # 验证集损失小,模型效果好保存
                    eval_loss = dev_loss
                    save_model(model, epoch)
                    last_improved = step # 记录最后一次迭代的次数
                    improve = '*'
                else:
                    improve = ''
                print('eval epoch: {} | f1_score: {} | loss: {} | {}'.format(epoch, f1, dev_loss, improve))
            if step - last_improved > config.require_improvement : # 提前结束训练
                print('No optimization for a long time, auto-stopping...')
                flag = True
                break
        if flag:
            break
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

将模型的所有参数梯度置为0:model.zero_grad()
将inputs,masks和tags这些数据,使用Variable类将数据封装成一个autograd变量包装,以便在计算图中跟踪梯度信息。inputs, masks, tags = Variable(inputs), Variable(masks), Variable(tags)

optimizer.step()torch.optim优化器的一个方法,用于更新模型参数。在计算完反向传播的梯度之后,可以使用optimizer.step()来执行一步优化,这会使模型现有的权重按预定义的某种规则(例如参数梯度下降或Adam等优化算法)进行微调。具体来说,步骤通常是:

1.通过使用loss.backward()方法,在相应的计算图上计算每个可训练的参数(权重和偏差)的梯度。
2.通过调用optimizer.step(),计算出给定优化器的下一个“步骤”。这将更新模型参数的值,与损失函数最小化的梯度方向相符。
3.通过循环迭代不同批次的数据,不断调用optimizer.step()方法更新模型参数,可以逐渐提高模型的准确性和泛化能力,以更好地完成特定的任务。

(2)定义验证函数

def dev(model, dev_loader, config, id2tag, test = False) :
    model.eval()
    eval_loss = 0
    true = []
    pred = []
    with torch.no_grad() :
        for i, batch in enumerate(dev_loader):
            if test : # 打印信息
                print('处理测试数据集第' + str(i * config.batch_size) + '至第' + str((i + 1) * config.batch_size) + '条······')
            inputs, masks, tags = batch
            inputs, masks, tags = Variable(inputs), Variable(masks), Variable(tags)
            if config.use_cuda :
                inputs, masks, tags = inputs.cuda(), masks.cuda(), tags.cuda()
            feats = model(inputs, masks)
            # 使用维特比算法解码
            best_path = model.crf.decode(feats,masks.byte())
            loss = model.loss(feats, masks.byte(), tags)
            eval_loss += loss.item()
            pred.extend([t for t in best_path])
            true.extend([[x for  x in t.tolist() if x != 0] for t in tags])
    true = [[id2tag[y] for y in x] for x in true]
    pred = [[id2tag[y] for y in x] for x in pred]
    f1 = f1_score(true, pred) # 计算两者之间的f1值
    if test :
        accuracy = accuracy_score(true, pred)
        precision = precision_score(true, pred)
        recall = recall_score(true, pred)
        report = classification_report(true, pred, 4)
        return accuracy, precision, recall, f1, eval_loss / len(dev_loader), report
    model.train()
    return f1, eval_loss / len(dev_loader)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

model.crf.decode()用于获取条件随机场(CRF)的预测标签序列的方法。将CRF模型的输出(标记分数或概率)作为输入,并返回最可能的标记序列。需要注意的是方法返回的是整数标记序列,而不是实际的标记名称或概率。通常需要使用一些对应关系来将其转换为所需的标记格式。
classification_report() 用于生成分类任务的模型性能报告的函数,输入两个参数:真实标签和预测标签。它会把真实标签和预测标签进行比较,计算出所有用于评估模型性能的指标。在该报告中,会生成一个表格,其中每一行都注意到一个类别,始终列出准确率、精确率、召回率和F1分数等指标。digits=4用于格式化输出浮点值的位数。
(3)定义测试函数

# 定义验证函数
def dev(model, dev_loader, config, id2tag, test = False) :
    model.eval()
    eval_loss = 0
    true = []
    pred = []
    with torch.no_grad() :
        for i, batch in enumerate(dev_loader):
            if test : # 打印信息
                print('处理测试数据集第' + str(i * config.batch_size) + '至第' + str((i + 1) * config.batch_size) + '条······')
            inputs, masks, tags = batch
            inputs, masks, tags = Variable(inputs), Variable(masks), Variable(tags)
            if config.use_cuda :
                inputs, masks, tags = inputs.cuda(), masks.cuda(), tags.cuda()
            feats = model(inputs, masks)
            # 使用维特比算法解码
            best_path = model.crf.decode(feats,masks.byte())
            loss = model.loss(feats, masks.byte(), tags)
            eval_loss += loss.item()
            pred.extend([t for t in best_path])
            true.extend([[x for  x in t.tolist() if x != 0] for t in tags])
    true = [[id2tag[y] for y in x] for x in true]
    pred = [[id2tag[y] for y in x] for x in pred]
    f1 = f1_score(true, pred) # 计算两者之间的f1值
    if test :
        accuracy = accuracy_score(true, pred)
        precision = precision_score(true, pred)
        recall = recall_score(true, pred)
        report = classification_report(true, pred, 4)
        return accuracy, precision, recall, f1, eval_loss / len(dev_loader), report
    model.train()
    return f1, eval_loss / len(dev_loader)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

二、定义模型

class Ernie_Lstm_Crf(nn.Module) :
    '''
    ernie_lstm_crf model
    args:
        ernie_config : 模型配置文件
        tagset_size:目标数据集标签的数量
        embedding_dim:编码维度
        hidden_dim:隐藏层维度
        rnn_layers:rnn层数
        dropout_ratio:dropout
        dropout1
        use_cuda 是否使用GPU
    '''
    def __init__(self, ernie_config, tagset_size, embedding_dim, hidden_dim, rnn_layers, dropout_ratio, dropout1, use_cuda=True):
        super(Ernie_Lstm_Crf, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # 加载Ernie [batch_size, max_len, hidden_size] [50, 128, 768]
        self.model = ErnieModel.from_pretrained(ernie_config).to(device)

        # 拼接LSTM 输入维度embedding_dim 768 输出维度hidden_dim 500 * 2
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            num_layers=rnn_layers, bidirectional=True,
                            dropout=dropout_ratio, batch_first=True)
        print(self.lstm) # [50,128,500*2]
        self.rnn_layers = rnn_layers
        print(self.rnn_layers) # [50,128,500]
        # 经过Dropout维度不变
        self.dropout1 = nn.Dropout(p = dropout1)
        print(self.dropout1) # [50,128,500]
        # batch_first默认False 此时 数据维度为:[seq_len, batch_size, hidden_size] ; 设置为True [batch_size, seq_len, hidden_size]
        self.crf = CRF(num_tags = tagset_size, batch_first = True)
        print(self.crf) # 输出[50,128,500,500] -> [50,128,500*2] 最后两个维度代表CRF模型中的状态转移矩阵
        # 输入维度 hidden_dim * 2 = 1000 , 输出维度 tagset_size 10
        self.liner = nn.Linear(hidden_dim * 2, tagset_size) # 经过线性层降维 1000-> 10
        print(self.liner)
        self.tagset_size = tagset_size
        print(self.tagset_size)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39

(1)Ernie后接BiLSTM
如果将来自ErnieHealth模型的张量(输入)传递给LSTM层,则张量的维度会在每个时间步上更改。具体来说,对于一个形状为[batch_size, max_len, hidden_size]的输入,其中batch_size是批处理大小(50),max_len是序列的最大长度(128),hidden_size是ErnieHealth模型的隐藏表示大小(768),它将进入LSTM层,并沿着时间步长(max_len)进行处理。

在时间步t中,LSTM将接收形状为[batch_size, 1, embedding_dim]的输入,其中embedding_dim等于ErnieHealth模型的隐藏层大小(768)。然后,LSTM会将该输入转换为隐藏状态和单元状态,这些状态都是形状为[num_layers*num_directions, batch_size, hidden_dim]的张量(hidden_dim=500)。

最后的输出由LSTM的所有时间步上的最后一层前向和后向隐藏状态连接而成,形状为[batch_size, max_len, num_directions * hidden_dim],其中num_directions是双向LSTM的方向数(2),因此输出的最终维度是hidden_dim乘以num_directions(2*500=1000)。
(2)Ernie接BiLSTM后再接RNN
如果一个LSTM模型的输出张量维度为 [50,128,1000] ,我们将其输入到一个RNN层中,那么输出的张量维度可以计算如下:

首先考虑RNN层的计算过程,它的输入形状通常为 [batch_size, sequence_length, input_size]。 RNN在每个时间步接收输入序列中的一个项,并在每个时间步返回当前时间步的输出以及隐藏状态,因此其输出形状也应该包含一个时间步。对于一个单向的RNN层,其输出的形状为 [batch_size, sequence_length, num_units]

根据上述描述,可以将LSTM的输出张量视为RNN层的输入,其中 batch size 为 50, sequence length 为 128, input size 为 1000。因此,得到的RNN层的输出张量维度为:[50, 128, num_units]

值得注意的是,输出张量的最后一个维度num_units是由RNN中 hidden_size=500(自定义)(或者说是num_units)这个超参数定义的,它指定了隐藏层的大小。
(3)Ernie接BiLSTM再接RNN后再接CRF
经过CRF处理后, 维度为[50,128,500]的张量将变成形状为[50,128,500,500] 的张量,其中,最后两个维度代表CRF模型中的状态转移矩阵。
(4)Ernie接BiLSTM再接RNN后再接CRF最后接线性层Liner
线性层降维,输入[50,128,500*2],输出[50,128,tagset_set],其中tagset_set=10.

三、配置文件

# coding=utf-8

class Config(object) :
    def __init__(self):
        # 数据集路径
        self.label_file = './dataset/tag/PeopleDaliy_tag.txt'
        self.train_file = './dataset/People_daliy/People_Daily_train.txt'
        self.dev_file = './dataset/People_daliy/People_Daily_dev.txt'
        self.test_file = './dataset/People_daliy/People_Daily_test.txt'
        self.vocab = './premodel/vocab.txt'
        self.ernie_path = './premodel/'
        self.maxlength = 128
        self.use_cuda = True
        self.gpu = 0
        self.batch_size = 50
        self.rnn_hidden = 500
        self.ernie_embedding = 768
        # dropout随机失活,在训练过程的前向传播中,让每个神经元以一定概率处于不激活的状态,以达到减少过拟合的效果
        self.dropout1 = 0.5
        self.dropout_ratio = 0.5
        self.rnn_layer = 1
        self.lr = 5e-5
        self.lr_decay = 0.00001
        self.weight_decay = 0.00005
        # 训练完成的模型保存路径
        self.checkpoint = 'result/'
        # 定义优化器
        self.optim = 'Adam'
        self.load_model = True
        self.load_path = 'PeopleDaily-9718'
        self.base_epoch = 10
        self.require_improvement = 1000 # 若1000次迭代损失并没有优化,则提前结束训练

    def update(self, **kwargs):
        for k ,v in kwargs.items() :
            setattr(self, k ,v)
    def __str__(self):
        return '\n'.join(['%s:%s' % item for item in self.__dict__.items()])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

运行结果

在这里插入图片描述
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/355736?site
推荐阅读
相关标签
  

闽ICP备14008679号