
BERT fine-tuning tricks: ① freezing some layer parameters, ② weight decay (L2 regularization), ③ warmup_proportion, ④ bias correction


As an NLPer, you probably reach for BERT all the time. But BERT has a lot of tunable knobs and a lot of tricks: weight decay, layer re-initialization, freezing parameters, optimizing only some of the layers, and so on. Every time I fine-tune I end up wondering which combination makes BERT train both fast and well, and whether there is a rough rule of thumb for getting fast, good, accurate results. So I dug into it, ran some experiments, and wrote up this post.

1. Use bias correction: faster convergence, better results

This trick comes from the paper Revisiting Few-sample BERT Fine-tuning. The paper points out that the original Adam optimizer includes bias correction in its update rule, whereas BertAdam, the optimizer implemented in the official BERT code, drops the bias correction.

In code, that is the BertAdam usage below, which performs no bias correction. Note that pytorch_pretrained_bert is Hugging Face's package from 2019 and has since been deprecated.

```python
from pytorch_pretrained_bert.optimization import BertAdam

optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=2e-05,
                     warmup=0.1,
                     t_total=2000)
```

The current transformers library has corrected this and made the behaviour configurable:

```python
import transformers

optimizer = transformers.AdamW(model_parameters, lr=lr, correct_bias=True)
```

So I ran a text-classification experiment on THUCNews with and without correct_bias. The impact was small, and the metric even dropped slightly. The paper, however, focuses on few-sample settings, so this may simply be a question of how well the trick fits your scenario; try it on your own data.
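For intuition, here is a minimal sketch of a single Adam-style update with and without the bias-correction factors. This is illustration only, not the Hugging Face implementation; the function and variable names are made up.

```python
import math

def adam_step(param, grad, m, v, step, lr=2e-5,
              beta1=0.9, beta2=0.999, eps=1e-8, correct_bias=True):
    """One Adam-style update on a scalar parameter (illustration only)."""
    m = beta1 * m + (1 - beta1) * grad            # biased first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad     # biased second-moment estimate
    step_size = lr
    if correct_bias:
        # compensate for m and v being initialized at zero; without this
        # the very first updates are effectively several times larger
        step_size = lr * math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step)
    param = param - step_size * m / (math.sqrt(v) + eps)
    return param, m, v
```

Since m underestimates its true scale by a factor of (1 - beta1^t) while sqrt(v) underestimates by sqrt(1 - beta2^t), skipping the correction inflates the earliest updates, which removes part of the implicit warmup and is roughly the instability the paper associates with BertAdam in few-sample fine-tuning.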

2. Weight re-initialization

When fine-tuning BERT, we usually initialize the downstream model's parameters directly from the pretrained BERT weights, so as to make full use of the linguistic knowledge BERT learned during pretraining and transfer it to the downstream task.

Take bert-base as an example: it is a stack of 12 transformer blocks. Should we keep all of the pretrained weights, only part of them, and if only part, which layers are most useful for the downstream task? Several papers have looked at this question. In short, the bottom layers (closer to the input) learn general linguistic information such as lexical and syntactic knowledge, while the top layers (closer to the output) tend to learn knowledge closer to the training objective, which for pretraining means the masked word prediction and next sentence prediction tasks.

Based on this, when fine-tuning you can keep the bottom BERT weights and randomly re-initialize the top layers (the last 1-6 layers), letting those parameters be relearned on your task. Revisiting Few-sample BERT Fine-tuning ran this experiment as well: re-initializing some of the top layers gave a clear improvement on part of the tasks.

So I also tried this on the text-classification task. Convergence is visibly faster during training, but the final metrics barely change. All of the experiment code is in the repository linked at the end of this post if you want to dig in and discuss; a minimal sketch of the re-initialization step is shown below.
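As a minimal sketch of that step (assuming a Hugging Face BertModel instance called `bert`; `reinit_top_layers` is a made-up helper name, and the full version used in the experiments appears in train.py below):

```python
import torch.nn as nn

def reinit_top_layers(bert, n):
    """Randomly re-initialize the last n encoder layers of a Hugging Face BertModel."""
    for layer in bert.encoder.layer[-n:]:
        for module in layer.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=bert.config.initializer_range)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=bert.config.initializer_range)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
```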

3. Weight decay (L2 regularization)

The official BERT code exempts the bias terms and the LayerNorm.bias / LayerNorm.weight parameters from regularization, so BERT training usually follows the same convention, as in the snippet below.

```python
param_optimizer = list(multi_classification_model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=config.lr, correct_bias=not config.bertadam)
```

Practice is the best test: I tried it on the text-classification task and the difference between using it and not was negligible. You can compare the two in your own training runs; the extra code cost is tiny. A quick way to inspect the parameter grouping is sketched below.
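If you want to double-check which parameters end up in which group (a quick sketch, assuming `model` is your fine-tuning model and `no_decay` is the list above):

```python
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
decayed = [n for n, _ in model.named_parameters() if not any(nd in n for nd in no_decay)]
exempt = [n for n, _ in model.named_parameters() if any(nd in n for nd in no_decay)]
print(f"{len(decayed)} tensors with weight decay, {len(exempt)} exempt (biases and LayerNorm)")
```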

4. Freezing some layers (frozen parameters)

Freezing parameters is often used when training large models: for models with many parameters, freezing part of them reduces the per-step gradient computation and speeds up training while barely affecting accuracy. It is also a common measure when fine-tuning BERT. Usually the first few BERT layers are frozen, since papers studying BERT's structure show that freezing them has little effect on final performance. This is similar to deep networks in vision, where the early layers learn generic, broadly shared knowledge (basic lines, points and shapes) that is much the same across tasks. There are two main ways to freeze parameters:

```python
import torch
import torch.nn as nn

# Method 1: set requires_grad = False
for param in model.parameters():
    param.requires_grad = False

# Method 2: torch.no_grad()
class net(nn.Module):
    def __init__(self):
        ...

    def forward(self, x):
        with torch.no_grad():  # parameters used under no_grad are not updated
            x = self.layer(x)
        ...
        x = self.fc(x)
        return x
```
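After freezing, it is worth sanity-checking how many parameters are actually still trainable (a short sketch, assuming `model` is the fine-tuning model):

```python
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)")
```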

train.py

```python
# code reference: https://github.com/asappresearch/revisit-bert-finetuning
import os
import sys
import time
import argparse
import logging

import numpy as np
from tqdm import tqdm
from sklearn import metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BertModel, AlbertModel, BertConfig, BertTokenizer

from dataloader import TextDataset, BatchTextCall
from model import MultiClass
from utils import load_config

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(filename)s:%(lineno)d:%(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def choose_bert_type(path, bert_type="tiny_albert"):
    """
    choose bert type for chinese, tiny_albert or macbert(bert)
    return: tokenizer, model
    """
    if bert_type == "albert":
        model_config = BertConfig.from_pretrained(path)
        model = AlbertModel.from_pretrained(path, config=model_config)
    elif bert_type == "bert" or bert_type == "roberta":
        model_config = BertConfig.from_pretrained(path)
        model = BertModel.from_pretrained(path, config=model_config)
    else:
        model_config, model = None, None
        print("ERROR, not choose model!")
    return model_config, model


def evaluation(model, test_dataloader, loss_func, label2ind_dict, save_path, valid_or_test="test"):
    # model.load_state_dict(torch.load(save_path))
    model.eval()

    total_loss = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)

    for ind, (token, segment, mask, label) in enumerate(test_dataloader):
        token = token.cuda()
        segment = segment.cuda()
        mask = mask.cuda()
        label = label.cuda()

        out = model(token, segment, mask)
        loss = loss_func(out, label)
        total_loss += loss.detach().item()

        label = label.data.cpu().numpy()
        predic = torch.max(out.data, 1)[1].cpu().numpy()
        labels_all = np.append(labels_all, label)
        predict_all = np.append(predict_all, predic)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if valid_or_test == "test":
        report = metrics.classification_report(labels_all, predict_all, target_names=label2ind_dict.keys(), digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, total_loss / len(test_dataloader), report, confusion
    return acc, total_loss / len(test_dataloader)


def train(config):
    label2ind_dict = {'体育': 0, '娱乐': 1, '家居': 2, '房产': 3, '教育': 4, '时尚': 5, '时政': 6, '游戏': 7, '科技': 8, '财经': 9}

    os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    torch.backends.cudnn.benchmark = True

    # load_data(os.path.join(data_dir, "cnews.train.txt"), label_dict)
    tokenizer = BertTokenizer.from_pretrained(config.pretrained_path)
    train_dataset_call = BatchTextCall(tokenizer, max_len=config.sent_max_len)

    train_dataset = TextDataset(os.path.join(config.data_dir, "train.txt"))
    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=10,
                                  collate_fn=train_dataset_call)

    valid_dataset = TextDataset(os.path.join(config.data_dir, "dev.txt"))
    valid_dataloader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=True, num_workers=10,
                                  collate_fn=train_dataset_call)

    test_dataset = TextDataset(os.path.join(config.data_dir, "test.txt"))
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True, num_workers=10,
                                 collate_fn=train_dataset_call)

    model_config, bert_encode_model = choose_bert_type(config.pretrained_path, bert_type=config.bert_type)
    multi_classification_model = MultiClass(bert_encode_model, model_config,
                                            num_classes=10, pooling_type=config.pooling_type)
    multi_classification_model.cuda()
    # multi_classification_model.load_state_dict(torch.load(config.save_path))

    if config.weight_decay:
        param_optimizer = list(multi_classification_model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer
                        if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=config.lr, correct_bias=not config.bertadam)
    else:
        optimizer = transformers.AdamW(multi_classification_model.parameters(), lr=config.lr, betas=(0.9, 0.999),
                                       eps=1e-08, weight_decay=0.01, correct_bias=not config.bertadam)

    num_train_optimization_steps = len(train_dataloader) * config.epoch
    # when warmup_proportion == 0 this simply yields 0 warmup steps
    scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
                                                             int(num_train_optimization_steps * config.warmup_proportion),
                                                             num_train_optimization_steps)
    loss_func = F.cross_entropy

    # reinit pooler-layer
    if config.reinit_pooler:
        if config.bert_type in ["bert", "roberta", "albert"]:
            logger.info(f"reinit pooler layer of {config.bert_type}")
            encoder_temp = getattr(multi_classification_model, config.bert_type)
            encoder_temp.pooler.dense.weight.data.normal_(mean=0.0, std=encoder_temp.config.initializer_range)
            encoder_temp.pooler.dense.bias.data.zero_()
            for p in encoder_temp.pooler.parameters():
                p.requires_grad = True
        else:
            raise NotImplementedError

    # reinit encoder layers
    if config.reinit_layers > 0:
        if config.bert_type in ["bert", "roberta", "albert"]:
            # assert config.reinit_pooler
            logger.info(f"reinit layers count of {str(config.reinit_layers)}")
            encoder_temp = getattr(multi_classification_model, config.bert_type)
            for layer in encoder_temp.encoder.layer[-config.reinit_layers:]:
                for module in layer.modules():
                    if isinstance(module, (nn.Linear, nn.Embedding)):
                        module.weight.data.normal_(mean=0.0, std=encoder_temp.config.initializer_range)
                    elif isinstance(module, nn.LayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)
                    if isinstance(module, nn.Linear) and module.bias is not None:
                        module.bias.data.zero_()
        else:
            raise NotImplementedError

    if config.freeze_layer_count:
        logger.info(f"frozen layers count of {str(config.freeze_layer_count)}")
        # We freeze here the embeddings of the model
        for param in multi_classification_model.bert.embeddings.parameters():
            param.requires_grad = False

        if config.freeze_layer_count != -1:
            # if freeze_layer_count == -1, we only freeze the embedding layer
            # otherwise we freeze the first `freeze_layer_count` encoder layers
            for layer in multi_classification_model.bert.encoder.layer[:config.freeze_layer_count]:
                for param in layer.parameters():
                    param.requires_grad = False

    loss_total, top_acc = [], 0
    for epoch in range(config.epoch):
        multi_classification_model.train()
        start_time = time.time()
        tqdm_bar = tqdm(train_dataloader, desc="Training epoch{epoch}".format(epoch=epoch))
        for i, (token, segment, mask, label) in enumerate(tqdm_bar):
            token = token.cuda()
            segment = segment.cuda()
            mask = mask.cuda()
            label = label.cuda()

            multi_classification_model.zero_grad()
            out = multi_classification_model(token, segment, mask)
            loss = loss_func(out, label)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            loss_total.append(loss.detach().item())
        logger.info("Epoch: %03d; loss = %.4f cost time %.4f" % (epoch, np.mean(loss_total), time.time() - start_time))

        acc, loss, report, confusion = evaluation(multi_classification_model,
                                                  test_dataloader, loss_func, label2ind_dict,
                                                  config.save_path)
        logger.info("Accuracy: %.4f Loss in test %.4f" % (acc, loss))
        if top_acc < acc:
            top_acc = acc
            # torch.save(multi_classification_model.state_dict(), config.save_path)
        logger.info(f"{report} \n {confusion}")
        time.sleep(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='bert finetune test')
    parser.add_argument("--data_dir", type=str, default="../data/THUCNews/news")
    parser.add_argument("--save_path", type=str, default="../ckpt/bert_classification")
    parser.add_argument("--pretrained_path", type=str, default="/data/Learn_Project/Backup_Data/bert_chinese",
                        help="pre-train model path")
    parser.add_argument("--bert_type", type=str, default="bert", help="bert or albert")
    parser.add_argument("--gpu", type=str, default='0')
    parser.add_argument("--epoch", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--warmup_proportion", type=float, default=0.1)
    parser.add_argument("--pooling_type", type=str, default="first-last-avg")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--sent_max_len", type=int, default=44)
    parser.add_argument("--do_lower_case", type=bool, default=True,
                        help="Set this flag true if you are using an uncased model.")
    parser.add_argument("--bertadam", type=int, default=0, help="If bertadam, then set correct_bias = False")
    parser.add_argument("--weight_decay", type=int, default=0, help="If weight_decay, set 1")
    parser.add_argument("--reinit_pooler", type=int, default=1, help="reinit pooler layer")
    parser.add_argument("--reinit_layers", type=int, default=6, help="reinit pooler layers count")
    parser.add_argument("--freeze_layer_count", type=int, default=6, help="freeze layers count")
    args = parser.parse_args()

    log_filename = f"test_bertadam{args.bertadam}_weight_decay{str(args.weight_decay)}" \
                   f"_reinit_pooler{str(args.reinit_pooler)}_reinit_layers{str(args.reinit_layers)}" \
                   f"_frozen_layers{str(args.freeze_layer_count)}_warmup_proportion{str(args.warmup_proportion)}"
    logger.addHandler(logging.FileHandler(os.path.join("./log", log_filename), 'w'))
    logger.info(args)

    train(args)
```
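With the argparse defaults above, a single run from the grid can be launched with something like `python train.py --bertadam 1 --weight_decay 1 --reinit_pooler 0 --reinit_layers 0 --freeze_layer_count 0 --warmup_proportion 0.1` (this corresponds to index 50, the best row in the result table at the end; --data_dir and --pretrained_path have to point at your own data and checkpoint).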

dataloader.py

```python
import os

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, AlbertModel, BertConfig, BertTokenizer
from transformers import BertForSequenceClassification, AutoModelForMaskedLM


def load_data(path):
    train = pd.read_csv(path, header=0, sep='\t', names=["text", "label"])
    print(train.shape)
    # valid = pd.read_csv(os.path.join(path, "cnews.val.txt"), header=None, sep='\t', names=["label", "text"])
    # test = pd.read_csv(os.path.join(path, "cnews.test.txt"), header=None, sep='\t', names=["label", "text"])
    texts = train.text.to_list()
    labels = train.label.map(int).to_list()
    # label_dic = dict(zip(train.label.unique(), range(len(train.label.unique()))))
    return texts, labels


class TextDataset(Dataset):
    def __init__(self, filepath):
        super(TextDataset, self).__init__()
        self.train, self.label = load_data(filepath)

    def __len__(self):
        return len(self.train)

    def __getitem__(self, item):
        text = self.train[item]
        label = self.label[item]
        return text, label


class BatchTextCall(object):
    """call function for tokenizing and getting batch text
    """

    def __init__(self, tokenizer, max_len=312):
        self.tokenizer = tokenizer
        self.max_len = max_len

    def text2id(self, batch_text):
        return self.tokenizer(batch_text, max_length=self.max_len,
                              truncation=True, padding='max_length', return_tensors='pt')

    def __call__(self, batch):
        batch_text = [item[0] for item in batch]
        batch_label = [item[1] for item in batch]
        source = self.text2id(batch_text)
        token = source.get('input_ids').squeeze(1)
        mask = source.get('attention_mask').squeeze(1)
        segment = source.get('token_type_ids').squeeze(1)
        label = torch.tensor(batch_label)
        return token, segment, mask, label


if __name__ == "__main__":
    data_dir = "/GitProject/Text-Classification/Chinese-Text-Classification/data/THUCNews/news_all"
    # pretrained_path = "/data/Learn_Project/Backup_Data/chinese-roberta-wwm-ext"
    pretrained_path = "/data/Learn_Project/Backup_Data/RoBERTa_zh_L12_PyTorch"
    label_dict = {'体育': 0, '娱乐': 1, '家居': 2, '房产': 3, '教育': 4, '时尚': 5, '时政': 6, '游戏': 7, '科技': 8, '财经': 9}

    # tokenizer, model = choose_bert_type(pretrained_path, bert_type="roberta")
    tokenizer = BertTokenizer.from_pretrained(pretrained_path)
    model_config = BertConfig.from_pretrained(pretrained_path)
    model = BertModel.from_pretrained(pretrained_path, config=model_config)
    # model = BertForSequenceClassification.from_pretrained(pretrained_path)
    # model = AutoModelForMaskedLM.from_pretrained(pretrained_path)

    text_dataset = TextDataset(os.path.join(data_dir, "test.txt"))
    text_dataset_call = BatchTextCall(tokenizer)
    text_dataloader = DataLoader(text_dataset, batch_size=2, shuffle=True, num_workers=2, collate_fn=text_dataset_call)

    for i, (token, segment, mask, label) in enumerate(text_dataloader):
        print(i, token, segment, mask, label)
        out = model(input_ids=token, attention_mask=mask, token_type_ids=segment)
        # loss, logits = model(token, mask, segment)[:2]
        print(out)
        print(out.last_hidden_state.shape)
        break
```

model.py

```python
import torch
from torch import nn

BertLayerNorm = torch.nn.LayerNorm


class MultiClass(nn.Module):
    """ text processed by bert model encode and get cls vector for multi classification
    """

    def __init__(self, bert_encode_model, model_config, num_classes=10, pooling_type='first-last-avg'):
        super(MultiClass, self).__init__()
        self.bert = bert_encode_model
        self.num_classes = num_classes
        self.fc = nn.Linear(model_config.hidden_size, num_classes)
        self.pooling = pooling_type
        self.dropout = nn.Dropout(model_config.hidden_dropout_prob)
        self.layer_norm = BertLayerNorm(model_config.hidden_size)

    def forward(self, batch_token, batch_segment, batch_attention_mask):
        out = self.bert(batch_token,
                        attention_mask=batch_attention_mask,
                        token_type_ids=batch_segment,
                        output_hidden_states=True)
        # print(out)
        if self.pooling == 'cls':
            out = out.last_hidden_state[:, 0, :]  # [batch, 768]
        elif self.pooling == 'pooler':
            out = out.pooler_output  # [batch, 768]
        elif self.pooling == 'last-avg':
            last = out.last_hidden_state.transpose(1, 2)  # [batch, 768, seqlen]
            out = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)  # [batch, 768]
        elif self.pooling == 'first-last-avg':
            first = out.hidden_states[1].transpose(1, 2)  # [batch, 768, seqlen]
            last = out.hidden_states[-1].transpose(1, 2)  # [batch, 768, seqlen]
            first_avg = torch.avg_pool1d(first, kernel_size=last.shape[-1]).squeeze(-1)  # [batch, 768]
            last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)  # [batch, 768]
            avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)  # [batch, 2, 768]
            out = torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)  # [batch, 768]
        else:
            raise ValueError("should define pooling type first!")

        out = self.layer_norm(out)
        out = self.dropout(out)
        out_fc = self.fc(out)
        return out_fc


if __name__ == '__main__':
    from transformers import BertConfig, BertModel

    path = "/data/Learn_Project/Backup_Data/bert_chinese"
    # MultiClass is a plain nn.Module, so build the BERT encoder first and pass it in
    model_config = BertConfig.from_pretrained(path)
    bert_encode_model = BertModel.from_pretrained(path, config=model_config)
    multi_classification_model = MultiClass(bert_encode_model, model_config, num_classes=10)
    if hasattr(multi_classification_model, 'bert'):
        print("-------------------------------------------------")
    else:
        print("**********************************************")
```

utils.py

```python
import yaml


class AttrDict(dict):
    """Attr dict: make value private
    """

    def __init__(self, d):
        self.dict = d

    def __getattr__(self, attr):
        value = self.dict[attr]
        if isinstance(value, dict):
            return AttrDict(value)
        else:
            return value

    def __str__(self):
        return str(self.dict)


def load_config(config_file):
    """Load config file"""
    with open(config_file) as f:
        if hasattr(yaml, 'FullLoader'):
            config = yaml.load(f, Loader=yaml.FullLoader)
        else:
            config = yaml.load(f)
        print(config)
    return AttrDict(config)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='text classification')
    parser.add_argument("-c", "--config", type=str, default="./config.yaml")
    args = parser.parse_args()
    config = load_config(args.config)
```

5. Warmup & lr decay

Warmup is a frequently used trick in BERT training: ramp the learning rate up during the early iterations, then use a smaller lr as training proceeds. The paper On Layer Normalization in the Transformer Architecture analyses this. In short, the authors find that at the very start of training the expected gradients near the output layer of the Transformer are very large, and warmup avoids drastic, unstable changes to the forward FC layers; without warmup, optimization can be very unstable. The effect is especially noticeable for deep networks and large batch sizes.

```python
num_train_optimization_steps = len(train_dataloader) * config.epoch
optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=config.lr)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
                                                         int(num_train_optimization_steps * 0.1),
                                                         num_train_optimization_steps)
```
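To make the schedule concrete, the multiplier that a linear warmup-and-decay scheduler applies to the base lr at each step looks roughly like this (a sketch of the schedule's shape, not the library source):

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """LR multiplier: ramps 0 -> 1 over warmup_steps, then decays 1 -> 0 by total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
```

For example, with 2000 total optimization steps and warmup_proportion = 0.1, the learning rate climbs linearly over the first 200 steps and then decays linearly to zero at step 2000.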

As for learning-rate decay, I previously wrote a post analysing whether an adaptive optimizer like Adam still needs learning-rate decay. Checked there against the literature and experiments, the conclusion was that adding it still gives a small improvement; see that post for details.

Finally, the conclusions

Combining the strategy settings above, I ran 64 experiments in total (1/0 means the option was used or not), with overall F1 ranging from 0.9281 to 0.9405. Overall, the different fine-tuning settings do not affect the final result very much; the gap is a bit over one point at most. For engineering applications the impact is minor, but if you are competing or chasing a leaderboard and have the resources, they are worth trying. Convergence speed, however, differs quite a lot across strategies: the runs with frozen parameters were noticeably faster to iterate.

Since the full table of 64 results would be too long, only the three best and the three worst configurations are listed below. The complete results and code are available at github.com/shuxinyin/Chinese-Text-Classification if you are interested.

| index | bertadam | weight_decay | reinit_pooler | reinit_layers | frozen_layers | warmup_proportion | result |
|-------|----------|--------------|---------------|---------------|---------------|-------------------|--------|
| 37    | 1        | 0            | 0             | 6             | 0             | 0.0               | 0.9287 |
| 45    | 1        | 0            | 1             | 6             | 0             | 0.0               | 0.9284 |
| 55    | 1        | 1            | 0             | 6             | 6             | 0.0               | 0.9281 |
| 35    | 1        | 0            | 0             | 0             | 6             | 0.0               | 0.9398 |
| 50    | 1        | 1            | 0             | 0             | 0             | 0.1               | 0.9405 |
| 58    | 1        | 1            | 1             | 0             | 0             | 0.1               | 0.9396 |

References:

- "Bert在fine-tune时训练的5种技巧" (5 training tricks for fine-tuning BERT), Zhihu
- "bert模型的微调,如何固定住BERT预训练模型参数,只训练下游任务的模型参数?" (How to freeze the pretrained BERT parameters and train only the downstream model), Zhihu
- GitHub - shuxinyin/Chinese-Text-Classification: Chinese text classification project including bert-classification, TextCNN, and so on.
