当前位置:   article > 正文

Bert-Chinese-Text-Classification-Pytorch项目测试函数编写

bert-chinese-text-classification-pytorch

 问题:

在这个git项目给了训练函数但是未提供测试函数,然后我根据自己的需求改了一下输入的模式

解决方案:

excle表格的形式输入大批量数据(可以根据自己的需要进行更改输入方式)

文本内容是由标题字段和正文字段组成的

输出是直接print,如果想要存入excle表格可以转成dataframe格式然后自己存到需要的位置

如果有需要可以自行更改

  1. # coding: UTF-8
  2. import torch
  3. import pandas as pd
  4. import tqdm
  5. import time
  6. from utils import build_iterator, get_time_dif
  7. from models.bert import Model
  8. from pytorch_pretrained import BertTokenizer
  9. # 配置类
  10. class Config(object):
  11. """配置参数"""
  12. def __init__(self, dataset, all_class):
  13. self.model_name = 'top30bert'
  14. self.data_path = dataset + '/sample_data.xlsx' # 预测集
  15. self.class_list = all_class # 类别名单
  16. self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt' # 模型训练结果
  17. self.device = torch.device('cuda') # if torch.cuda.is_available() else 'cpu'
  18. self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练
  19. self.num_classes = len(self.class_list) # 类别数
  20. self.num_epochs = 3 # epoch数
  21. self.batch_size = 50 # mini-batch大小 原始128
  22. self.pad_size = 32 # 每句话处理成的长度(短填长切)
  23. self.learning_rate = 5e-5 # 学习率
  24. self.bert_path = './bert_pretrain'
  25. self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
  26. self.hidden_size = 768
  27. # 加载数据函数
  28. PAD, CLS = '[PAD]', '[CLS]' # padding符号, bert中综合信息符号
  29. def load_dataset(path, pad_size=32):
  30. contents = []
  31. print('data_path:', path)
  32. data = pd.read_excel(path)
  33. for i, row in tqdm.tqdm(data.iterrows()):
  34. # label = class2num[row['FOUR_TYPE_NAME']]
  35. content = str(row['CONTENT_TEXT']) + str(row['TITLE'])
  36. # content, label = lin.split('\t') # 从tab分开出内容和标签
  37. token = config.tokenizer.tokenize(content)
  38. token = [CLS] + token
  39. seq_len = len(token)
  40. mask = []
  41. token_ids = config.tokenizer.convert_tokens_to_ids(token)
  42. if pad_size:
  43. if len(token) < pad_size:
  44. mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
  45. token_ids += ([0] * (pad_size - len(token)))
  46. else:
  47. mask = [1] * pad_size
  48. token_ids = token_ids[:pad_size]
  49. seq_len = pad_size
  50. contents.append((token_ids, 0, seq_len, mask)) # contents.append((token_ids, int(label), seq_len, mask))
  51. return contents
  52. # 加载配置
  53. classfile = pd.read_excel('data/class.xlsx')
  54. all_class = [value for i, value in enumerate(classfile.iloc[:,0].tolist())] # 所有类别名称列表
  55. dataset = 'data'
  56. config = Config(dataset, all_class)
  57. # 加载数据,预处理
  58. print("Loading data...")
  59. start_time = time.time()
  60. data = load_dataset(config.data_path, config.pad_size) # list
  61. # print("data",data)
  62. print("data_len", len(data))
  63. data_iter = build_iterator(data, config) # utils.DatasetIterater
  64. # print("data_iter",type(data_iter))
  65. time_dif = get_time_dif(start_time)
  66. print("Loading data Time usage:", time_dif)
  67. # 创建模型
  68. model = Model(config).to(config.device)
  69. model.load_state_dict(torch.load(config.save_path))
  70. model.eval()
  71. # 开始预测
  72. predictions = []
  73. with torch.no_grad():
  74. for texts, _ in data_iter:
  75. # print('texts:', texts) # text为tensor
  76. outputs = model(texts)
  77. # print(outputs.size())
  78. predic = torch.max(outputs, 1)[1].cpu().numpy()
  79. predicted_classes = [all_class[idx] for idx in predic]
  80. predictions.extend(predicted_classes) # 将预测结果添加到predictions列表中
  81. print(len(predictions))
  82. for i in predictions:
  83. print(i)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/834798
推荐阅读
相关标签
  

闽ICP备14008679号