赞
踩
这个 git 项目给出了训练函数,但没有提供测试(预测)函数,于是我根据自己的需求改了一下输入的方式
以 Excel 表格的形式输入大批量数据(可以根据自己的需要更改输入方式)
文本内容是由标题字段和正文字段组成的
输出是直接 print,如果想要存入 Excel 表格,可以先转成 DataFrame 格式,再自己存到需要的位置
如果有需要可以自行更改
- # coding: UTF-8
- import torch
- import pandas as pd
- import tqdm
- import time
- from utils import build_iterator, get_time_dif
- from models.bert import Model
- from pytorch_pretrained import BertTokenizer
-
-
-
# Configuration
class Config(object):
    """Settings for running the fine-tuned BERT classifier in prediction mode.

    Args:
        dataset: Directory that holds the input spreadsheet and the
            ``saved_dict/`` checkpoint folder.
        all_class: List of class names; its length sets ``num_classes``.
    """

    def __init__(self, dataset, all_class):
        self.model_name = 'top30bert'
        self.data_path = dataset + '/sample_data.xlsx'  # spreadsheet to predict on
        self.class_list = all_class  # class names
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # trained weights
        # Fall back to CPU so the script still runs on machines without CUDA.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.require_improvement = 1000  # stop training early after 1000 batches without improvement
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 50  # mini-batch size (originally 128)
        self.pad_size = 32  # every text is padded or truncated to this many tokens
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = './bert_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768
-
# Data-loading helpers
PAD, CLS = '[PAD]', '[CLS]'  # padding token, BERT's sequence-summary token


def _pad_or_truncate(token_ids, pad_size):
    """Return ``(token_ids, mask, seq_len)`` with the ids fitted to ``pad_size``."""
    seq_len = len(token_ids)
    if not pad_size:
        # No fixed length requested: keep the ids and mark every token as real.
        return token_ids, [1] * seq_len, seq_len
    if seq_len < pad_size:
        padding = pad_size - seq_len
        # 0 is the padding id (presumably the id of '[PAD]' in the BERT vocab).
        return token_ids + [0] * padding, [1] * seq_len + [0] * padding, seq_len
    return token_ids[:pad_size], [1] * pad_size, pad_size


def load_dataset(path, pad_size=32, tokenizer=None):
    """Read an Excel sheet and encode each row for the classifier.

    Each row's ``CONTENT_TEXT`` and ``TITLE`` cells are concatenated (in that
    order), tokenized, prefixed with ``[CLS]`` and padded or truncated to
    ``pad_size`` token ids.

    Args:
        path: Path of the ``.xlsx`` file; it must contain ``CONTENT_TEXT``
            and ``TITLE`` columns.
        pad_size: Fixed length of every encoded sequence. A falsy value
            disables padding and truncation.
        tokenizer: Object providing ``tokenize`` and
            ``convert_tokens_to_ids``. Defaults to the module-level
            ``config.tokenizer``.

    Returns:
        A list of ``(token_ids, label, seq_len, mask)`` tuples. ``label`` is
        always the placeholder ``0`` because prediction data has no labels.
    """
    if tokenizer is None:
        # Backward-compatible default: the module-level ``config`` object.
        tokenizer = config.tokenizer

    print('data_path:', path)
    data = pd.read_excel(path)

    contents = []
    for _, row in tqdm.tqdm(data.iterrows(), total=len(data)):
        # NOTE(review): an empty cell presumably arrives as NaN, which str()
        # renders as 'nan' -- confirm this matches the training preprocessing.
        content = str(row['CONTENT_TEXT']) + str(row['TITLE'])
        token = [CLS] + tokenizer.tokenize(content)
        token_ids = tokenizer.convert_tokens_to_ids(token)
        token_ids, mask, seq_len = _pad_or_truncate(token_ids, pad_size)
        contents.append((token_ids, 0, seq_len, mask))
    return contents
-
# Load configuration
classfile = pd.read_excel('data/class.xlsx')
# Class names from the first column; position i names predicted class index i.
all_class = classfile.iloc[:, 0].tolist()
dataset = 'data'
config = Config(dataset, all_class)


# Load and preprocess the data
print("Loading data...")
start_time = time.time()
data = load_dataset(config.data_path, config.pad_size)  # list of (token_ids, label, seq_len, mask)
print("data_len", len(data))
data_iter = build_iterator(data, config)
time_dif = get_time_dif(start_time)
print("Loading data Time usage:", time_dif)

# Build the model and restore the trained weights
model = Model(config).to(config.device)
# map_location remaps the checkpoint onto the configured device, so weights
# saved on a GPU can still be loaded when running on CPU.
model.load_state_dict(torch.load(config.save_path, map_location=config.device))
model.eval()

# Run prediction
predictions = []
with torch.no_grad():
    for texts, _ in data_iter:
        outputs = model(texts)
        # Index of the highest score along dim 1 is the predicted class.
        predic = torch.max(outputs, 1)[1].cpu().numpy()
        predictions.extend(all_class[idx] for idx in predic)

print(len(predictions))
for predicted_class in predictions:
    print(predicted_class)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。