
黑马程序员 NLP in Practice --- Training a News Classification Model


On training the news classification model in the NLP part of the 《黑马程序员》 course

I have recently been studying NLP and chose the fairly comprehensive 黑马程序员 course on the subject. In its hands-on news topic classification project, however, I found that the course code contains a large number of errors, including several logic bugs:

  1. First, the dataset download is very slow, because it can only be reached through a proxy, so in most cases loading the dataset fails with a Timeout exception.
  2. Second, preprocessing of the dataset is not covered in the course, and loading the local csv dataset files directly runs into format problems.
  3. Third, the tuples returned by the generate_batch() method are in the wrong form: each record of the news dataset is a 3-tuple (type, title, content), i.e. the news category, the title and the body, yet the course code only handles two fields (see the csv sketch right after this list).
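
To make item 3 concrete, here is a minimal sketch (not part of the course code; the helper name iter_raw_rows and its path argument are made up for illustration) of the raw AG_NEWS csv layout and of joining the title and body into one text field before tokenization:

import csv

# Each header-less row of train.csv / test.csv has three fields:
#   "<type>","<title>","<content>"
# i.e. the news category index, the title, and the body.
def iter_raw_rows(csv_path):
    with open(csv_path, encoding='utf-8') as f:
        for cls, title, content in csv.reader(f):
            # Join title and body into a single text field before tokenization.
            yield int(cls), title + ' ' + content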

!!! Note that the torchtext version used here is 0.4; in later releases this module was moved, so on any other version the line from torchtext.datasets.text_classification import ... may fail !!!
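
If you are not sure which version you have, pin it with pip install torchtext==0.4.0, or add a small guard like the sketch below (it assumes torchtext exposes a __version__ string) so the script fails fast with a clear message instead of an obscure import error:

import torchtext

# Abort early if torchtext is not on the 0.4 line, because
# torchtext.datasets.text_classification was moved in later releases.
assert torchtext.__version__.startswith('0.4'), (
    'This script expects torchtext 0.4.x, found %s' % torchtext.__version__)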

To address the problems above, I have put together a complete, runnable version of the code. If it helps you, please give it a little heart or follow me~

Let's start with the code.

from torchtext.datasets.text_classification import *
from torchtext.datasets.text_classification import _csv_iterator, _create_data_from_iterator
import os
import time
from torch import optim
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer




N_GRAMS = 2
if not os.path.isdir('./data'):
    os.mkdir('./data')

BATCH_SIZE = 16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the training and test datasets from the local archive
def _setup_data_set(dataset_tar='./data/ag_news_csv.tar.gz',
                    n_grams=N_GRAMS, vocab=None,
                    include_unk=False):
    extracted_files = extract_archive(dataset_tar)
    train_csv_path = ''
    test_csv_path = ''
    for file_name in extracted_files:
        if file_name.endswith('train.csv'):
            train_csv_path = file_name
        if file_name.endswith('test.csv'):
            test_csv_path = file_name
        
    if vocab is None:
        print("Building Vocab based on %s" % train_csv_path)
        # Build the vocabulary
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams=n_grams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    print('Vocab has %d entries' % len(vocab))
    print('Creating training data')
    # Note: the training data must come from train.csv, not test.csv.
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, n_grams, yield_cls=True), include_unk)
    print('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, n_grams, yield_cls=True), include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    # Return dataset instances
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))


train_data_set, test_data_set = _setup_data_set()
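
# Optional sanity check (an addition, not in the course code): each dataset
# item should be a (label, token-id tensor) pair, which is exactly what
# generate_batch() below expects.
sample_label, sample_tokens = train_data_set[0]
print('Sample label: %d | first token ids: %s' % (sample_label, sample_tokens[:10]))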

# Define the model
class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()
        
    def init_weights(self):
        init_range = 0.5
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.fc.weight.data.uniform_(-init_range, init_range)
        self.fc.bias.data.zero_()
    
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)


VOCAB_SIZE = len(train_data_set.get_vocab())
EMBED_DIM = 32
NUM_CLASS = len(train_data_set.get_labels())
# Instantiate the model
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)
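
# Note on the forward pass: nn.EmbeddingBag with its default mode='mean'
# averages the embedding vectors of all tokens in a sample (each "bag" is
# delimited by the offsets tensor), producing a (batch_size, EMBED_DIM)
# tensor that the Linear layer maps to (batch_size, NUM_CLASS) class scores.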

N_EPOCH = 5
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=4.0)
scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

train_len = int(len(train_data_set) * 0.95)
sub_train_, sub_valid_ = random_split(train_data_set, [train_len, len(train_data_set) - train_len])


def generate_batch(batch):
    # Each entry from TextClassificationDataset is a (label, token-id tensor)
    # pair; collate them into one flat text tensor plus per-sample offsets.
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)

    text = torch.cat(text)
    return text, offsets, label
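
# Worked example: for a batch of three samples whose token tensors have
# lengths 3, 5 and 2, offsets is first [0, 3, 5, 2]; dropping the last value
# and taking the cumulative sum gives tensor([0, 3, 8]), i.e. the index in
# the concatenated `text` tensor at which each sample starts.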

def train_function(sub_train_):

    loss_ = 0
    acc_ = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        # Keep the per-batch loss separate from the running total, otherwise
        # the accumulator is overwritten on every batch.
        loss = criterion(output, cls)
        loss_ += loss.item()
        loss.backward()
        optimizer.step()
        acc_ += (output.argmax(1) == cls).sum().item()
    # Adjust the learning rate
    scheduler.step()

    return loss_ / len(sub_train_), acc_ / len(sub_train_)

def test(data_):
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            # Accumulate the scalar loss instead of overwriting it.
            batch_loss = criterion(output, cls)
            loss += batch_loss.item()
            acc += (output.argmax(1) == cls).sum().item()

    return loss / len(data_), acc / len(data_)


for epoch in range(N_EPOCH):
    start_time = time.time()
    train_loss, train_acc = train_function(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)
    
    secs = int(time.time() - start_time)
    mins = secs / 60
    secs = secs % 60
    
    print('Epoch:%d' % (epoch + 1), "| time in %d minutes, %d seconds" % (mins, secs))
    print(f"\tLoss:{train_loss:.4f}(train)\t|\tAcc:{train_acc * 100:.1f}%(train)")
    print(f"\tLoss:{valid_loss:.6f}(valid)\t|\tAcc:{valid_acc * 100:.6f}%(valid)")
    
    
    
# Try the model on a new piece of text
ag_news_label = {
    1: "World",
    2: "Sports",
    3: "Business",
    4: "Sci/Tec"
}


def predict(text, model, vocab, ngrams):
    tokenizer = get_tokenizer("basic_english")
    with torch.no_grad():
        # Map the n-gram tokens to vocabulary ids.
        text = torch.tensor([vocab[token] for token in ngrams_iterator(tokenizer(text), ngrams)])
        # A single offset of 0 means the whole tensor is one "bag" (one sample).
        output = model(text, torch.tensor([0]))
        # Labels in ag_news_label are 1-based, model outputs are 0-based.
        return output.argmax(1).item() + 1

ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
        enduring the season’s worst weather conditions on Sunday at The \
        Open on his way to a closing 75 at Royal Portrush, which \
        considering the wind and the rain was a respectable showing. \
        Thursday’s first round at the WGC-FedEx St. Jude Invitational \
        was another story. With temperatures in the mid-80s and hardly any \
        wind, the Spaniard was 13 strokes better in a flawless round. \
        Thanks to his best putting performance on the PGA Tour, Rahm \
        finished with an 8-under 62 for a three-stroke lead, which \
        was even more impressive considering he’d never played the \
        front nine at TPC Southwind."

vocab = train_data_set.get_vocab()
model = model.to("cpu")

print("This is a %s news" % ag_news_label[predict(ex_text_str, model, vocab, 2)])

Output of a run:

Building Vocab based on ./data/ag_news_csv/train.csv
120000lines [00:06, 17621.91lines/s]
Vocab has 1308844 entries
Creating training data
7600lines [00:00, 8405.65lines/s]
Creating testing data
7600lines [00:00, 9790.11lines/s]
Epoch:1 | time in 0 minutes, 0 seconds
        Loss:0.0003(train)      |       Acc:40.9%(train)
        Loss:0.0038(valid)      |       Acc:47.1%(valid)
Epoch:2 | time in 0 minutes, 0 seconds
        Loss:0.0002(train)      |       Acc:70.4%(train)
        Loss:0.0031(valid)      |       Acc:67.9%(valid)
Epoch:3 | time in 0 minutes, 0 seconds
        Loss:0.0004(train)      |       Acc:82.4%(train)
        Loss:0.0126(valid)      |       Acc:52.6%(valid)
Epoch:4 | time in 0 minutes, 0 seconds
        Loss:0.0001(train)      |       Acc:88.3%(train)
        Loss:0.0026(valid)      |       Acc:60.8%(valid)
Epoch:5 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:91.9%(train)
        Loss:0.0002(valid)      |       Acc:79.7%(valid)
Epoch:6 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:94.8%(train)
        Loss:0.0001(valid)      |       Acc:81.8%(valid)
Epoch:7 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:96.7%(train)
        Loss:0.0001(valid)      |       Acc:83.4%(valid)
Epoch:8 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:98.5%(train)
        Loss:0.0001(valid)      |       Acc:83.4%(valid)
Epoch:9 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:99.3%(train)
        Loss:0.0001(valid)      |       Acc:81.1%(valid)
Epoch:10 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:99.6%(train)
        Loss:0.0002(valid)      |       Acc:82.1%(valid)
Epoch:11 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:99.8%(train)
        Loss:0.0001(valid)      |       Acc:84.7%(valid)
Epoch:12 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:99.9%(train)
        Loss:0.0001(valid)      |       Acc:83.2%(valid)
Epoch:13 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:83.7%(valid)
Epoch:14 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:83.2%(valid)
Epoch:15 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:84.2%(valid)
Epoch:16 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:85.0%(valid)
Epoch:17 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:85.0%(valid)
Epoch:18 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0000(valid)      |       Acc:85.5%(valid)
Epoch:19 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:84.2%(valid)
Epoch:20 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:84.5%(valid)
This is a Sports news