当前位置:   article > 正文

AG_NEWS数据集的编码分类和预测_agnews数据集LSTM

agnews数据集LSTM
# Beginner NLP notes. The AG_NEWS dataset is available from many download
# sources online and can be saved locally; this script combines code from
# several bloggers.
# Local AG_NEWS root. The leading '自己的地址' ("your own location") is a
# placeholder to replace. NOTE: `path` is only referenced by the commented-out
# torchtext `AG_NEWS(root=path)` calls; the CSV files are actually read from
# the hard-coded paths passed to load_data() below.
path = r'自己的地址datasets\AG_NEWS.data\datasets\AG_NEWS'

import torch
from torchtext.datasets import AG_NEWS

# train_iter = AG_NEWS(root=path, split='train')
#####
import pandas as pd
def load_data(csv_file):
    """Read an AG_NEWS-style CSV into a list of ``(label, text)`` tuples.

    The file is read without a header row. Column 0 is taken as the label;
    columns 1 and 2 (title and description) are joined into one text string.

    Args:
        csv_file: path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        A list of ``(label, text)`` tuples, one per CSV row, in file order.
    """
    # header=None: otherwise pandas would treat the first data row as the
    # column names and that sample would be lost.
    df = pd.read_csv(csv_file, header=None)
    samples = []
    for label, title, description in df.iloc[:, :3].itertuples(index=False, name=None):
        # Join title and description with a space. Plain concatenation glued
        # the last title word onto the first description word
        # ("Title" + "Body" -> "TitleBody"), producing bogus tokens. Missing
        # cells (NaN) are skipped instead of raising on str + float.
        parts = [str(part) for part in (title, description) if not pd.isna(part)]
        samples.append((label, " ".join(parts)))
    return samples
# Parse both CSV splits once, when the module is loaded. The *_iter names are
# what the vocabulary-building code iterates over; *_dataset refer to the same
# lists.
train_dataset = load_data(r"D:\study\dataset\AG News\train.csv")
train_iter = train_dataset
test_dataset = load_data(r"D:\study\dataset\AG News\test.csv")
test_iter = test_dataset
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab

tokenizer = get_tokenizer('basic_english')

# Count token frequencies over the TRAINING split only. Words never seen in
# training are mapped to "<unk>" (see set_default_index below), so test-set
# text no longer has to be added to the vocabulary just to keep lookups of
# unseen words from raising.
counter = Counter()
for (_, line) in train_iter:
    counter.update(tokenizer(line))

# NOTE: this rebinds the name `vocab` from the torchtext factory function to
# the resulting Vocab object; the rest of the script uses the object.
vocab = vocab(counter, min_freq=1, specials=["<unk>"])
# Out-of-vocabulary tokens resolve to the "<unk>" index instead of raising.
vocab.set_default_index(vocab["<unk>"])


def text_pipeline(x):
    """Tokenize raw text and map each token to its vocabulary index."""
    return [vocab[token] for token in tokenizer(x)]


def label_pipeline(x):
    """Map a 1-based AG_NEWS label (1..4) to a 0-based class index."""
    return int(x) - 1


from torch.utils.data import DataLoader

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def collate_batch(batch):
    """Collate ``(label, text)`` samples into the flat EmbeddingBag input format.

    Every text is converted to an int64 tensor of token ids with
    ``text_pipeline`` and all of them are concatenated into one 1-D tensor.
    ``offsets[i]`` is the position in that tensor where sample ``i`` starts.

    Returns:
        ``(labels, text, offsets)``, all moved to ``device``; labels are
        int64 class indices produced by ``label_pipeline``.
    """
    labels = []
    token_tensors = []
    lengths = []
    for raw_label, raw_text in batch:
        labels.append(label_pipeline(raw_label))
        token_ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
        token_tensors.append(token_ids)
        lengths.append(token_ids.size(0))

    label_tensor = torch.tensor(labels, dtype=torch.int64)
    # Exclusive prefix sum of the lengths: start index of each sample.
    starts = [0, *lengths][:-1]
    offsets = torch.tensor(starts).cumsum(dim=0)
    text_tensor = torch.cat(token_tensors)
    return label_tensor.to(device), text_tensor.to(device), offsets.to(device)


from torch import nn
class TextClassificationModel(nn.Module):
    """Bag-of-embeddings text classifier: EmbeddingBag followed by a Linear layer.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: size of each embedding vector.
        num_class: number of output logits.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # sparse=True: the embedding table receives sparse gradients.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw embedding and fc weights from U(-0.5, 0.5); zero the fc bias."""
        bound = 0.5
        nn.init.uniform_(self.embedding.weight, -bound, bound)
        nn.init.uniform_(self.fc.weight, -bound, bound)
        nn.init.zeros_(self.fc.bias)

    def forward(self, text, offsets):
        """Pool each bag of token ids (delimited by ``offsets``) and return logits."""
        pooled = self.embedding(text, offsets)
        return self.fc(pooled)


# Model hyper-parameters: four AG_NEWS classes (label_pipeline maps labels
# 1..4 to 0..3) and 64-dimensional token embeddings.
num_class = 4
emsize = 64
vocab_size = len(vocab)
model = TextClassificationModel(vocab_size, embed_dim=emsize, num_class=num_class).to(device)


import time
def train(dataloader, *, log_interval=500):
    """Train the module-level ``model`` for one pass over ``dataloader``.

    Uses the module-level ``optimizer`` and ``criterion``. The progress line
    also reads the module-level ``epoch``, which the ``__main__`` training
    loop sets; calling this outside that loop will fail at the first log line.

    Args:
        dataloader: iterable of ``(label, text, offsets)`` batches as built by
            ``collate_batch``; ``len(dataloader)`` is shown in the log.
        log_interval: print the running accuracy every this many batches
            (default 500, the previously hard-coded value).
    """
    model.train()
    total_acc, total_count = 0, 0

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        # Clip the gradient norm to 0.1 (presumably to tame the large LR = 5
        # SGD steps used in __main__).
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches, accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc / total_count))
            # Accuracy is reported per logging window, not cumulatively.
            total_acc, total_count = 0, 0


def evaluate(dataloader):
    """Return the accuracy of the module-level ``model`` on ``dataloader``.

    The model is switched to eval mode and no gradients are tracked. Each
    batch is a ``(label, text, offsets)`` triple as built by ``collate_batch``.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for label, text, offsets in dataloader:
            logits = model(text, offsets)
            n_correct += (logits.argmax(1) == label).sum().item()
            n_seen += label.size(0)
    return n_correct / n_seen


def predict(text, text_pipeline):
    """Classify one raw text string with the module-level ``model``.

    Side effect: the model is switched to eval mode.

    Args:
        text: raw input string.
        text_pipeline: callable mapping the string to a list of vocabulary
            indices.

    Returns:
        The predicted class as a 1-based label (the model's 0-based argmax
        plus one, undoing ``label_pipeline``'s shift).
    """
    model.eval()
    with torch.no_grad():
        # Explicit int64: torch.tensor([]) defaults to float32, which
        # EmbeddingBag rejects when the text yields no tokens at all.
        token_ids = torch.tensor(text_pipeline(text), dtype=torch.int64)
        # A single bag starting at index 0 covers the whole sequence.
        offsets = torch.tensor([0])
        # Move the inputs to the same device as the model.
        output = model(token_ids.to(device), offsets.to(device))
        return output.argmax(1).item() + 1


from torch.utils.data.dataset import random_split

if __name__ == '__main__':

    # Training hyper-parameters.
    EPOCHS = 10
    LR = 5
    BATCH_SIZE = 64

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    # StepLR(step_size=1, gamma=0.1) scales the LR by 0.1 on every
    # scheduler.step(). step() is only called below when validation accuracy
    # falls under the best value seen so far.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    total_accu = None  # best validation accuracy so far

    # train_dataset / test_dataset were already parsed from the CSV files when
    # the module was loaded; reuse them instead of reading both files again.
    # (Original torchtext loader, kept for reference:
    #  train_iter, test_iter = AG_NEWS(root=path))
    num_train = int(len(train_dataset) * 0.95)
    # Hold out 5% of the training data as a validation split.
    split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

    # Only the training loader is shuffled; evaluation accuracy does not
    # depend on batch order.
    train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

    # NOTE: train() reads `epoch` as a module-level name for its progress
    # output, so the loop variable must keep this name.
    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader)
        accu_val = evaluate(valid_dataloader)
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))
        print('-' * 59)

    print('Checking the results of test dataset.')

    accu_test = evaluate(test_dataloader)
    print('test accuracy {:8.3f}'.format(accu_test))

    torch.save(model.state_dict(), r'D:\study\dataset\AG News\model_TextClassification.pth')

    # Prediction demo (disabled): reload the saved weights and classify a
    # sample sentence.
    # model.load_state_dict(torch.load(r'D:\study\dataset\AG News\model_TextClassification.pth'))
    # #
    # ag_news_label = {1: "World",
    #                  2: "Sports",
    #                  3: "Business",
    #                  4: "Sci/Tec"}
    #
    # # ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind."
    # ex_text_str = 'Beijing of Automation, Beijing Institute of Technology'
    # # model = model.to("cpu")
    #
    # print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])

#参考的原博客如下(代码不完全一样,做了一些修改;我运行时虽然已把数据集下载到本地,但仍然遇到了问题):【PyTorch】7 文本分类TorchText实战——AG_NEWS四类别新闻分类_agnews依据新闻标题写content-CSDN博客;Torchtext下的AG_NEWS数据集进行分类(官方文档代码)-CSDN博客

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
  

闽ICP备14008679号