NLP beginner here. This dataset is mirrored in many places online, so you can find it and download it locally; this post stitches together material from several bloggers.

```python
path = r'<your path>\datasets\AG_NEWS.data\datasets\AG_NEWS'  # your own location
import torch
from torchtext.datasets import AG_NEWS
# train_iter = AG_NEWS(root=path, split='train')

import pandas as pd

def load_data(csv_file):
    # pandas treats the first row as a header by default, so pass header=None
    # to keep it as data
    df = pd.read_csv(csv_file, header=None)
    data = []
    # iterate row by row: _ is the row index, row is the content
    for _, row in df.iterrows():
        label = row[0]
        context = row[1] + row[2]  # merge title and description into one text
        data.append((label, context))
    return data

train_iter = train_dataset = load_data(r"D:\study\dataset\AG News\train.csv")
test_iter = test_dataset = load_data(r"D:\study\dataset\AG News\test.csv")

from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab as build_vocab  # renamed so the factory isn't shadowed below

tokenizer = get_tokenizer('basic_english')
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))
# I added this loop myself: prediction kept raising errors until the
# test-set tokens were in the vocabulary (see the note after this block)
for (label, line) in test_iter:
    counter.update(tokenizer(line))
vocab = build_vocab(counter, min_freq=1)

text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: int(x) - 1  # CSV labels are 1..4; shift to 0..3
```
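A note on the "prediction kept erroring" fix above: building the vocabulary only from the training split means any token that appears solely at prediction time makes `text_pipeline` fail on lookup. Folding the test set into the counter hides that, but leaks test tokens into the vocabulary. An arguably cleaner fix is an `<unk>` fallback. Here is a minimal sketch, assuming a recent torchtext (0.11+) whose `Vocab` objects support `insert_token` and `set_default_index` (the toy counter is made up for illustration):

```python
from collections import Counter
from torchtext.vocab import vocab as build_vocab

counter = Counter({"hello": 3, "world": 2})   # toy frequencies
v = build_vocab(counter, min_freq=1)
v.insert_token("<unk>", 0)        # add an unknown-token entry at index 0
v.set_default_index(v["<unk>"])   # OOV tokens now map here instead of erroring
print(v["hello"])                 # a regular index
print(v["never_seen_token"])      # 0, the <unk> index
```

With this in place, `text_pipeline` works on arbitrary input text and the second counter loop over `test_iter` becomes unnecessary.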
The rest of the script: batching with offsets, the EmbeddingBag classifier, and the train/evaluate/predict loop.

```python
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    # cumulative start position of each sample; the last length is dropped
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

num_class = 4
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

import time

def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()
    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches, accuracy {:8.3f}'.format(
                epoch, idx, len(dataloader), total_acc / total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        # move the input tensors to the same device as the model
        output = model(text.to(device), torch.tensor([0]).to(device))
        return output.argmax(1).item() + 1

from torch.utils.data.dataset import random_split

if __name__ == '__main__':
    EPOCHS = 10
    LR = 5
    BATCH_SIZE = 64

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    total_accu = None
    # train_iter, test_iter = AG_NEWS(root=path)
    train_dataset = load_data(r"D:\study\dataset\AG News\train.csv")
    test_dataset = load_data(r"D:\study\dataset\AG News\test.csv")
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = random_split(train_dataset,
                                              [num_train, len(train_dataset) - num_train])

    # shuffle=True reshuffles the data every epoch
    train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                                  shuffle=True, collate_fn=collate_batch)
    valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                                  shuffle=True, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                 shuffle=True, collate_fn=collate_batch)

    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader)
        accu_val = evaluate(valid_dataloader)
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))
        print('-' * 59)

    print('Checking the results of test dataset.')
    accu_test = evaluate(test_dataloader)
    print('test accuracy {:8.3f}'.format(accu_test))

    torch.save(model.state_dict(), r'D:\study\dataset\AG News\model_TextClassification.pth')

    # Prediction part (uncomment to run inference with the saved weights):
    # model.load_state_dict(torch.load(r'D:\study\dataset\AG News\model_TextClassification.pth'))
    # ag_news_label = {1: "World",
    #                  2: "Sports",
    #                  3: "Business",
    #                  4: "Sci/Tec"}
    # ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind."
    # ex_text_str = 'Beijing of Automation, Beijing Institute of Technology'
    # model = model.to("cpu")
    # print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])
```
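To make the `offsets` bookkeeping in `collate_batch` concrete: `nn.EmbeddingBag` takes the whole batch as one flattened 1-D tensor of token ids plus an `offsets` tensor marking where each sample starts, which is why the code takes a cumulative sum over the lengths and drops the last entry. A minimal standalone sketch (the ids and sizes here are made up for illustration):

```python
import torch
from torch import nn

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4)  # toy vocab of 10 ids
text = torch.tensor([1, 2, 3, 4, 5, 6])  # two samples flattened: [1,2,3] + [4,5,6]
offsets = torch.tensor([0, 3])           # sample 0 starts at index 0, sample 1 at 3
out = bag(text, offsets)                 # default mode='mean': one pooled vector per sample
print(out.shape)                         # torch.Size([2, 4])
```

This is also why `predict` passes `torch.tensor([0])` as the offsets: a batch of one sample starts at index 0.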
References: the original posts I adapted (my version differs somewhat; when I ran it, even with the dataset downloaded locally it still hit that download issue~):
- 【PyTorch】7 文本分类TorchText实战——AG_NEWS四类别新闻分类 - CSDN博客
- Torchtext下的AG_NEWS数据集进行分类(官方文档代码) - CSDN博客