Fine-Tuning the ALBERT Model

The code below fine-tunes a Chinese ALBERT checkpoint (albert_chinese_base) for text classification: a Config class gathers the data paths and hyperparameters, and a Model class puts a linear classification head on top of the pretrained encoder.
import torch
import torch.nn as nn
import os
from transformers import AlbertModel, BertTokenizer, AlbertConfig


class Config(object):
    def __init__(self, dataset):
        self.model_name = "albert"
        self.data_path = "./albert/data/data/"
        self.train_path = self.data_path + "train.txt"  # training set
        self.dev_path = self.data_path + "dev.txt"      # validation set
        self.test_path = self.data_path + "test.txt"    # test set
        self.class_list = [
            x.strip() for x in open(self.data_path + "class.txt", encoding="utf-8").readlines()
        ]  # list of class names
        self.save_path = "./albert/src/saved_dic"
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)  # makedirs also creates missing parents; mkdir would fail here
        self.save_path += "/" + self.model_name + ".pt"  # where the fine-tuned weights are saved
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device
        # self.require_improvement = 1000  # stop early if no improvement after 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 5        # number of epochs
        self.batch_size = 256      # mini-batch size
        self.pad_size = 32         # every sentence is padded or truncated to this length
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = "/home/ec2-user/toutiao/albert/data/albert_chinese_base/"
        # albert_chinese_base ships a BERT-style vocab file, so it is loaded with
        # BertTokenizer rather than AlbertTokenizer
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.bert_config = AlbertConfig.from_pretrained(self.bert_path + "config.json")
        self.hidden_size = 768
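
Model.forward below indexes its input as x[0] (token ids) and x[2] (attention mask), so each batch is expected to be a tuple. Here is a minimal preprocessing sketch under that assumption, using the tokenizer from Config; the example sentence and the seq_len middle element are illustrative guesses, not from the original, and the padding/truncation keywords assume transformers >= 3.

# Hedged sketch: encode one sentence to exactly pad_size tokens.
config = Config("./albert/data/data/")
text = "今天股市大涨"  # hypothetical example sentence
enc = config.tokenizer.encode_plus(
    text,
    max_length=config.pad_size,
    padding="max_length",  # pad short sentences up to pad_size
    truncation=True,       # cut long sentences down to pad_size
)
token_ids = torch.tensor([enc["input_ids"]], device=config.device)
mask = torch.tensor([enc["attention_mask"]], device=config.device)
seq_len = int(mask.sum())       # assumed meaning of the unused x[1] slot
x = (token_ids, seq_len, mask)  # matches the x[0]/x[2] indexing in forward()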
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.albert = AlbertModel.from_pretrained(config.bert_path, config=config.bert_config)
        for name, param in self.albert.named_parameters():
            param.requires_grad = True  # fine-tune all ALBERT parameters
            # print(name)  # uncomment to list the trainable parameter names
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        context = x[0]  # the input sentence as token ids
        mask = x[2]  # mask over the padding, same size as the sentence;
                     # padding positions are 0, e.g. [1, 1, 1, 1, 0, 0]
        # return_dict=False keeps the (sequence_output, pooled_output) tuple that
        # this unpacking expects; transformers >= 4 returns a ModelOutput by default
        _, pooled = self.albert(context, attention_mask=mask, return_dict=False)
        out = self.fc(pooled)
        return out
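
To show how the pieces fit together, here is a hedged end-to-end sketch of one forward pass and one optimizer step with the hyperparameters from Config; the cross-entropy loss and Adam optimizer are common choices for this setup, not taken from the original post.

model = Model(config).to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss()

logits = model(x)                                 # x built in the sketch above
labels = torch.tensor([0], device=config.device)  # hypothetical gold label
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(logits.shape)  # torch.Size([1, num_classes])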
