I've been studying NLP recently and picked the fairly comprehensive NLP course from 黑马程序员. But in one of its hands-on projects, the news topic classification project, I found that the course code contains a large number of errors, with logic mistakes in several places:
!!! Note that this requires torchtext version 0.4. The module was probably moved in a later release, so on any other version the line from torchtext.datasets.text_classification import ... is likely to fail !!!
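If you're not sure which torchtext you have installed, a quick sanity check like the one below (my own addition, not part of the course code) fails fast before the import error shows up:

import torchtext

# This script relies on the torchtext 0.4.x API; later releases reorganized
# torchtext.datasets, so bail out early on a version mismatch.
assert torchtext.__version__.startswith("0.4"), (
    "found torchtext %s, but this script needs 0.4.x "
    "(pip install torchtext==0.4.0)" % torchtext.__version__)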
To fix the problems above, I've put together a complete version of the code that runs correctly. If it helps, please leave a like or give me a follow~
The code first:
from torchtext.datasets.text_classification import *
from torchtext.datasets.text_classification import _csv_iterator, _create_data_from_iterator
import os
import time
from torch import optim
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer

N_GRAMS = 2
if not os.path.isdir('./data'):
    os.mkdir('./data')
BATCH_SIZE = 16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Build the train/test datasets from the AG_NEWS archive
def _setup_data_set(dataset_tar='./data/ag_news_csv.tar.gz',
                    n_grams=N_GRAMS, vocab=None, include_unk=False):
    extracted_files = extract_archive(dataset_tar)
    train_csv_path = ''
    test_csv_path = ''
    for file_name in extracted_files:
        if file_name.endswith('train.csv'):
            train_csv_path = file_name
        if file_name.endswith('test.csv'):
            test_csv_path = file_name
    if vocab is None:
        print("Building Vocab based on %s" % train_csv_path)
        # Build the vocabulary from the training set
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams=n_grams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    print('Vocab has %d entries' % len(vocab))
    print('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, n_grams, yield_cls=True), include_unk)
    print('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, n_grams, yield_cls=True), include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    # Return dataset instances
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))


train_data_set, test_data_set = _setup_data_set()


# Define the model: an EmbeddingBag followed by a single linear layer
class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        init_range = 0.5
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.fc.weight.data.uniform_(-init_range, init_range)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)


VOCAB_SIZE = len(train_data_set.get_vocab())
EMBED_DIM = 32
NUM_CLASS = len(train_data_set.get_labels())
# Instantiate the model
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)

N_EPOCH = 20
min_valid_loss = float('inf')
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=4.0)
scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

# Hold out 5% of the training set for validation
train_len = int(len(train_data_set) * 0.95)
sub_train_, sub_valid_ = random_split(
    train_data_set, [train_len, len(train_data_set) - train_len])


def generate_batch(batch):
    # Concatenate all samples into one 1-D tensor and record where each starts
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label


def train_function(sub_train_):
    loss_ = 0
    acc_ = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        # Keep the batch loss tensor separate from the running total
        loss = criterion(output, cls)
        loss_ += loss.item()
        loss.backward()
        optimizer.step()
        acc_ += (output.argmax(1) == cls).sum().item()
    # Decay the learning rate once per epoch
    scheduler.step()
    return loss_ / len(sub_train_), acc_ / len(sub_train_)


def test(data_):
    loss_ = 0
    acc_ = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            loss = criterion(output, cls)
            loss_ += loss.item()
            acc_ += (output.argmax(1) == cls).sum().item()
    return loss_ / len(data_), acc_ / len(data_)


for epoch in range(N_EPOCH):
    start_time = time.time()
    train_loss, train_acc = train_function(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)
    secs = int(time.time() - start_time)
    mins = secs // 60
    secs = secs % 60
    print('Epoch:%d' % (epoch + 1), "| time in %d minutes, %d seconds" % (mins, secs))
    print(f"\tLoss:{train_loss:.4f}(train)\t|\tAcc:{train_acc * 100:.1f}%(train)")
    print(f"\tLoss:{valid_loss:.4f}(valid)\t|\tAcc:{valid_acc * 100:.1f}%(valid)")


# Try the trained model on a fresh piece of text
ag_news_label = {
    1: "World",
    2: "Sports",
    3: "Business",
    4: "Sci/Tec"
}


def predict(text, model, vocab, ngrams):
    tokenizer = get_tokenizer("basic_english")
    with torch.no_grad():
        text = torch.tensor([vocab[token]
                             for token in ngrams_iterator(tokenizer(text), ngrams)])
        # A single sample, so its offset is simply 0
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1


ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
enduring the season’s worst weather conditions on Sunday at The \
Open on his way to a closing 75 at Royal Portrush, which \
considering the wind and the rain was a respectable showing. \
Thursday’s first round at the WGC-FedEx St. Jude Invitational \
was another story. With temperatures in the mid-80s and hardly any \
wind, the Spaniard was 13 strokes better in a flawless round. \
Thanks to his best putting performance on the PGA Tour, Rahm \
finished with an 8-under 62 for a three-stroke lead, which \
was even more impressive considering he’d never played the \
front nine at TPC Southwind."

vocab = train_data_set.get_vocab()
model = model.to("cpu")
print("This is a %s news" % ag_news_label[predict(ex_text_str, model, vocab, 2)])
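A note on the part that trips people up most, generate_batch: nn.EmbeddingBag does not take a padded 2-D batch. Instead, every sample's token ids are concatenated into one long 1-D tensor, and a second offsets tensor records where each sample begins. The toy example below (made-up token ids, independent of AG_NEWS) shows what the collate function actually produces:

import torch

# Three fake samples of (label, token-id tensor) with lengths 3, 2 and 4
batch = [(1, torch.tensor([4, 2, 9])),
         (3, torch.tensor([7, 7])),
         (2, torch.tensor([1, 5, 5, 8]))]

text = torch.cat([entry[1] for entry in batch])
offsets = torch.tensor([0] + [len(entry[1]) for entry in batch][:-1]).cumsum(dim=0)

print(text)     # tensor([4, 2, 9, 7, 7, 1, 5, 5, 8])
print(offsets)  # tensor([0, 3, 5]) -> sample i starts at text[offsets[i]]

predict uses the same convention for a single sample, which is why it passes torch.tensor([0]) as the offsets.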
Sample output:
Building Vocab based on ./data/ag_news_csv/train.csv
120000lines [00:06, 17621.91lines/s]
Vocab has 1308844 entries
Creating training data
7600lines [00:00, 8405.65lines/s]
Creating testing data
7600lines [00:00, 9790.11lines/s]
Epoch:1 | time in 0 minutes, 0 seconds
    Loss:0.0003(train) | Acc:40.9%(train)
    Loss:0.0038(valid) | Acc:47.1%(valid)
Epoch:2 | time in 0 minutes, 0 seconds
    Loss:0.0002(train) | Acc:70.4%(train)
    Loss:0.0031(valid) | Acc:67.9%(valid)
Epoch:3 | time in 0 minutes, 0 seconds
    Loss:0.0004(train) | Acc:82.4%(train)
    Loss:0.0126(valid) | Acc:52.6%(valid)
Epoch:4 | time in 0 minutes, 0 seconds
    Loss:0.0001(train) | Acc:88.3%(train)
    Loss:0.0026(valid) | Acc:60.8%(valid)
Epoch:5 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:91.9%(train)
    Loss:0.0002(valid) | Acc:79.7%(valid)
Epoch:6 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:94.8%(train)
    Loss:0.0001(valid) | Acc:81.8%(valid)
Epoch:7 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:96.7%(train)
    Loss:0.0001(valid) | Acc:83.4%(valid)
Epoch:8 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:98.5%(train)
    Loss:0.0001(valid) | Acc:83.4%(valid)
Epoch:9 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:99.3%(train)
    Loss:0.0001(valid) | Acc:81.1%(valid)
Epoch:10 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:99.6%(train)
    Loss:0.0002(valid) | Acc:82.1%(valid)
Epoch:11 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:99.8%(train)
    Loss:0.0001(valid) | Acc:84.7%(valid)
Epoch:12 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:99.9%(train)
    Loss:0.0001(valid) | Acc:83.2%(valid)
Epoch:13 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:100.0%(train)
    Loss:0.0001(valid) | Acc:83.7%(valid)
Epoch:14 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:100.0%(train)
    Loss:0.0001(valid) | Acc:83.2%(valid)
Epoch:15 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:100.0%(train)
    Loss:0.0001(valid) | Acc:84.2%(valid)
Epoch:16 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:100.0%(train)
    Loss:0.0001(valid) | Acc:85.0%(valid)
Epoch:17 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:100.0%(train)
    Loss:0.0001(valid) | Acc:85.0%(valid)
Epoch:18 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:100.0%(train)
    Loss:0.0000(valid) | Acc:85.5%(valid)
Epoch:19 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:100.0%(train)
    Loss:0.0001(valid) | Acc:84.2%(valid)
Epoch:20 | time in 0 minutes, 0 seconds
    Loss:0.0000(train) | Acc:100.0%(train)
    Loss:0.0001(valid) | Acc:84.5%(valid)
This is a Sports news