
NLP Practice: Implementing an LSTM-based Pretrained Model and a POS Tagging Task in PyTorch

Environment and version configuration

1. CUDA version: Cuda compilation tools, release 11.8, V11.8.89
Check it in cmd with:

nvcc -V

2. cuDNN version: 8700

import torch  # check it with this
print(torch.backends.cudnn.version())

3. Python version: 3.9
4. PyTorch version: torch 2.0.0+cu118
5. nltk: 3.8.1
Install it with pip from cmd (Win+R, then cmd):

pip install nltk

6. GPU: 2070S 8G (laptop version)

Environment setup

1. Install the nltk library and the corpus files: see this article,
or download a package containing both the code and the corpus from my Baidu Netdisk:

Link: https://pan.baidu.com/s/1geJ6bTNrV-WsPoBTMdxgWg?pwd=hf66
Extraction code: hf66
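
If you would rather not fetch the archive manually, nltk's own downloader can usually pull the Penn Treebank sample that this project uses. A minimal sketch, assuming you point download_dir at your own nltk_data folder (the path below is the one used elsewhere in this project; replace it with yours):

import nltk
from nltk import data

# download the tagged Penn Treebank sample used later by tools.load_data
nltk.download("treebank", download_dir=r"S:\downloads\nltk\nltk_data")
data.path.append(r"S:\downloads\nltk\nltk_data")

from nltk.corpus import treebank
print(treebank.tagged_sents()[0][:5])   # a few (word, POS) pairs to confirm the corpus loads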

2. Install PyTorch.
There is a CPU build and a GPU build; the former is slower, the latter faster. If you don't want to fiddle with CUDA, just install the CPU build, which should also work:

pip3 install torch torchvision torchaudio

If you do want the CUDA build, there are plenty of guides, for example this article.
3. Everything should now be able to run.
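
A quick sanity check of the whole environment (a minimal sketch; the printed values will depend on your own installation):

import torch

print(torch.__version__)                  # e.g. 2.0.0+cu118
print(torch.cuda.is_available())          # True means the GPU build and driver work together
print(torch.backends.cudnn.version())     # e.g. 8700
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # name of the first CUDA device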

Code

There are two .py files: a self-written utility library and the main script. Put them in the same folder (and try to avoid Chinese characters in the path).

Main script:

import torch,os,pickle,re,keyboard,sys
from tqdm.auto import tqdm
from torch import Tensor, nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from tools import path_get, load_data, read_vocab,Vocab,base_check
from datetime import datetime
from time import sleep

embedding_dim = 128  # model hyperparameters
hidden_dim = 256
batch_size = 32
WEIGHT_INIT_RANGE = 0.1

path1,path2,path3,path4,path5,path6=path_get()
#vocab_path(1),target_path(2),train_path(3),test_path(4),nltk_path(5),mode_path(6)
def torchcheck():  # check that torch can be imported
    try:
        import torch
    except ImportError:
        return "error"

def cudacheck():  # check CUDA / device availability
    import torch
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        device = torch.device("cuda:0")
        # get the cuDNN version reported by torch
        cudnn_version = torch.backends.cudnn.version()
        note = (f"{device_count} CUDA device(s) available", f"torch cuDNN version: {cudnn_version}")
    else:
        device = torch.device("cpu")
        note = "No GPU available"
    return device, note
 
class LstmDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

def collate_fn(examples):  # batch preprocessing
    lengths = torch.tensor([len(ex[0]) for ex in examples])  # length of every sample in the batch
    inputs = [torch.tensor(ex[0]) for ex in examples]
    targets = [torch.tensor(ex[1]) for ex in examples]
    # convert each sample's input and target to torch.tensor and collect them in inputs / targets
    inputs = pad_sequence(inputs, batch_first=True, padding_value=vocab["<pad>"])  # pad to equal length
    targets = pad_sequence(targets, batch_first=True, padding_value=vocab["<pad>"])  # pad to equal length
    return inputs, lengths, targets, inputs != vocab["<pad>"]

def init_weights(model):
    for param in model.parameters():
        torch.nn.init.uniform_(param, a=-WEIGHT_INIT_RANGE, b=WEIGHT_INIT_RANGE)

class LSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class):
        super(LSTM, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # map input word indices to their embedding vectors
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        # takes the embeddings as input and outputs hidden states and cell states;
        # the hidden state represents the information at the current time step,
        # the cell state stores and carries long-range dependencies through the sequence
        self.output = nn.Linear(hidden_dim, num_class)
        # a linear layer (like a fully connected layer) that maps the hidden state to the output classes
        init_weights(self)

    def forward(self, inputs, lengths):
        embeddings = self.embeddings(inputs)  # embedding lookup
        x_pack = pack_padded_sequence(embeddings, lengths, batch_first=True, enforce_sorted=False)
        # pack the padded batch so the LSTM skips the padding positions
        hidden, (hn, cn) = self.lstm(x_pack)
        # run the packed sequence through the LSTM, obtaining hidden and cell states
        hidden, _ = pad_packed_sequence(hidden, batch_first=True)
        # unpack back to a padded (batch, seq_len, hidden) tensor
        outputs = self.output(hidden)  # project every time step's hidden state to the class space
        log_probs = F.log_softmax(outputs, dim=-1)  # log-probabilities over POS tags for each position
        return log_probs

if os.path.exists(path1) and os.path.exists(path2) and os.path.exists(path3) and os.path.exists(path4):  # data files already exist
    # load from files
    print("Preparing data......")
    with open(path3, 'rb') as f:
        train_data = pickle.load(f)
    with open(path4, 'rb') as f:
        test_data = pickle.load(f)
        # print(test_data[0])
    vocab = read_vocab(path1)
    pos_vocab = read_vocab(path2)
    # train_data, test_data, vocab, pos_vocab = load_data(mode)
    train_dataset = LstmDataset(train_data)
    test_dataset = LstmDataset(test_data)
    # print(test_dataset)
    # print(type(test_dataset))
else:
    print("Data not found, importing it")
    load_data("nltk")
    print("Data import finished, reading.....")
    with open(path3, 'rb') as f:
        train_data = pickle.load(f)
    with open(path4, 'rb') as f:
        test_data = pickle.load(f)
        # print(test_data[0])
    vocab = read_vocab(path1)
    pos_vocab = read_vocab(path2)  # fixed: the POS vocab lives at path2 (vocab_target.txt), not path4
    # train_data, test_data, vocab, pos_vocab = load_data(mode)
    train_dataset = LstmDataset(train_data)
    test_dataset = LstmDataset(test_data)

def clear_input_buffer():
    try:
        import msvcrt
        while msvcrt.kbhit():
            msvcrt.getch()
    except ImportError:
        import termios
        termios.tcflush(sys.stdin, termios.TCIOFLUSH)
# num_class = len(pos_vocab)   #模型相关参数配置
# model = LSTM(len(vocab), embedding_dim, hidden_dim, num_class)
# model.to(device)

def mode_training(device):  # train the model and save it
    num_class = len(pos_vocab)   # model configuration
    model = LSTM(len(vocab), embedding_dim, hidden_dim, num_class)
    model.to(device)
    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    # print("type: test_data_loader",type(test_data_loader))
    nll_loss = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer
    model.train()
    check = 10000  # initial value for the stopping criterion
    i = 1          # epoch counter
    while check > 0.5:   # train until the total epoch loss drops below 0.5 (don't set this too low, or it may never converge)
        # for epoch in range(num_epoch):
        total_loss = 0
        progress = tqdm(train_data_loader, desc=f"Training Epoch {i}")
        for batch in progress:
            inputs, lengths, targets, mask = [x.to(device) for x in batch]
            log_probs = model(inputs, lengths.cpu())  # forward pass
            loss = nll_loss(log_probs[mask], targets[mask])  # loss over non-padding positions only
            optimizer.zero_grad()  # reset gradients
            loss.backward()        # back-propagate
            optimizer.step()       # update parameters
            total_loss += loss.item()  # accumulate as a Python float (avoids keeping the graph alive)
            # print(f"Loss: {total_loss:.2f}")
            progress.set_postfix({"Loss": f"{total_loss:.2f}"})
        i = i + 1
            # if total_loss<=2:
            #     print("train have done")
            # break
        print()
        check = float(f"{total_loss:.2f}")
    print("training is over")
    now = datetime.now()
    modname = now.strftime("mod%m%d%H%M.pth")  # name the checkpoint with the current date and time
    path = os.path.join(path6, modname)
    print("Saving model to:", path)
    with open(os.path.join(path6, "path.txt"), 'a') as file:  # remember the checkpoint path
        file.write(path + "\n")
    torch.save(model.state_dict(), path)

def mode_testing(path, device):  # evaluate a saved model on the test set
    num_class = len(pos_vocab)   # model configuration
    model = LSTM(len(vocab), embedding_dim, hidden_dim, num_class)
    model.to(device)
    test_data_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)
    if os.path.exists(path):
        print("Model file found")
        model.load_state_dict(torch.load(path), strict=False)
    acc = 0
    total = 0
    for batch in tqdm(test_data_loader, desc="Testing"):
        inputs, lengths, targets, mask = [x.to(device) for x in batch]
        with torch.no_grad():
            # print(inputs,lengths)
            output = model(inputs,lengths.cpu())
            # print(type(output.argmax))
            acc += (output.argmax(dim=-1) == targets)[mask].sum().item()
            total += mask.sum().item()
    print(f"Acc: {acc / total:.2f}")

def Testify(sentence, device):  # tag a user-supplied sentence with the latest saved model
    device, note = cudacheck()
    # mode_training(device)
    # fetch the most recently trained model
    openpath = os.path.join(path6, "path.txt")
    with open(openpath, "r", encoding="utf-8") as file:
        # read every line of the file
        lines = file.readlines()
    # strip trailing newlines
    lines = [line.strip() for line in lines]
    # the last line is the newest checkpoint path
    test_mod_path = lines[-1]
    num_class = len(pos_vocab)   # model configuration
    model = LSTM(len(vocab), embedding_dim, hidden_dim, num_class)
    model.to(device)

    # sentence = "I love you"  # example input
    se = sentence.lower()  # lowercase everything
    sen = re.sub(r",", " , ", se)  # put spaces around commas
    print(sen)
    tokens = sen.split()
    print("Tokenized:", tokens)
    word_vocab = read_vocab(path1)
    input_ids = word_vocab.convert_tokens_to_ids(tokens)  # tokens -> indices
    # print(input_ids)
    length = [len(input_ids)]
    tensor_input = torch.tensor([input_ids]).to(device)
    tensor_lengths = torch.tensor(length).to(device)

    # inputs=torch.tensor([[1327,    0,  368]]).to(device)  # known-good test data
    # lengths=torch.tensor([3]).to(device)
    # print("-------------",type(lengths))

    model.load_state_dict(torch.load(test_mod_path), strict=False)
    with torch.no_grad():
        # output = model(inputs,lengths.cpu())  # run on the known-good test data
        output = model(tensor_input, tensor_lengths.cpu())
        # print("out",output)
        target = output.argmax(dim=-1).tolist()
        # print(target)
        for ids in target:
            tag = pos_vocab.convert_ids_to_tokens(ids)
    print("POS tagging result:", tag)

def printtt():  # simple console menu
    print("=========================================")
    print("=----Press ESC to quit------------------==")
    print("=---------------------------------------==")
    print("=----Press T to train a model-----------==")
    print("=---------------------------------------==")
    print("=----Press P to check the torch setup---==")
    print("=---------------------------------------==")
    print("=----Press C to test a saved model file-==")
    print("=---------------------------------------==")
    print("=----Press D to check the base data-----==")
    print("=---------------------------------------==")
    print("=----Press A to tag a typed sentence----==")
    print("=---------------------------------------==")
    print("=========================================")
    print("=--------------------------------BYHHF-==")
    sleep(1)
def print_block(i):  # print i blank lines
    for i in range(i):
        print()

def main():
    printtt()
    # print_block(2)
    while True:
        device,note=cudacheck()
        if keyboard.is_pressed('esc'):
            print('Welcome back')
            break
        elif keyboard.is_pressed('t'):
            print("OK, starting training")
            mode_training(device)
            printtt()
            print_block(3)
        elif keyboard.is_pressed('c'):
            print("OK, evaluating the latest saved model")
            openpath = os.path.join(path6, "path.txt")
            with open(openpath, "r", encoding="utf-8") as file:
                lines = file.readlines()
            lines = [line.strip() for line in lines]
            test_mod_path = lines[-1]
            mode_testing(test_mod_path, device)
            print_block(2)
            printtt()
            print_block(3)
        elif keyboard.is_pressed('d'):
            base_check()
            printtt()
            print_block(3)
        elif keyboard.is_pressed('p'):
            print(note)
            printtt()
            print_block(3)
        elif keyboard.is_pressed('a'):
            while True:
                clear_input_buffer()
                print("Type a sentence, or type 'break' to leave tagging mode")
                print()
                sentence = input()
                if sentence == 'break':
                    break
                else:
                    Testify(sentence, device)
                    print()
            printtt()
            print_block(3)

if __name__ == '__main__':
    main()
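
To make the padding, packing and masking logic in collate_fn and LSTM.forward easier to follow, here is a small self-contained sketch on toy data (the token indices and the pad id 0 are made up for illustration and are not taken from the real vocab):

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# two toy "sentences" of different lengths, already converted to indices
seqs = [torch.tensor([5, 2, 7]), torch.tensor([4, 9])]
lengths = torch.tensor([3, 2])

# pad to a (batch, max_len) tensor; 0 plays the role of vocab["<pad>"] here
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
mask = padded != 0                      # True for real tokens, False for padding
print(padded)                           # tensor([[5, 2, 7], [4, 9, 0]])
print(mask)                             # tensor([[True, True, True], [True, True, False]])

# packing tells the LSTM to skip the padded positions
emb = torch.nn.Embedding(10, 4)
lstm = torch.nn.LSTM(4, 6, batch_first=True)
packed = pack_padded_sequence(emb(padded), lengths, batch_first=True, enforce_sorted=False)
out, _ = lstm(packed)
out, _ = pad_packed_sequence(out, batch_first=True)
print(out.shape)                        # torch.Size([2, 3, 6])

# the mask is later used to keep only real positions in the loss and accuracy,
# e.g. log_probs[mask] and targets[mask] in the training loop above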

Utility library code:

# self-written (adapted) utility library
from collections import defaultdict
import pickle
from nltk import data
from nltk.corpus import treebank
import os


class Vocab:  # vocabulary class
    def __init__(self, tokens=None):
        self.idx_to_token = list()
        self.token_to_idx = dict()
        if tokens is not None:
            if "<unk>" not in tokens:
                tokens = tokens + ["<unk>"]
            for token in tokens:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1
            self.unk = self.token_to_idx['<unk>']

    @classmethod
    def build(cls, text, min_freq=1, reserved_tokens=None):
        token_freqs = defaultdict(int)
        for sentence in text:
            for token in sentence:
                token_freqs[token] += 1
        uniq_tokens = ["<unk>"] + (reserved_tokens if reserved_tokens else [])
        uniq_tokens += [token for token, freq in token_freqs.items() \
                        if freq >= min_freq and token != "<unk>"]
        return cls(uniq_tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, token):
        return self.token_to_idx.get(token, self.unk)

    def convert_tokens_to_ids(self, tokens):
        return [self[token] for token in tokens]

    def convert_ids_to_tokens(self, indices):
        return [self.idx_to_token[index] for index in indices]

def load_data(mode):   # build or load the corpus data
    vocab_path, target_path, train_path, test_path, nltk_path, mode_path = path_get()
    if mode == "nltk":
        data.path.append(nltk_path)
        # sents, postags = zip(*(zip(*sent) for sent in treebank.tagged_sents()))
        sents, postags = zip(*(zip(*[(word.lower(), pos) for word, pos in sent]) for sent in treebank.tagged_sents()))
    if os.path.exists(vocab_path) and os.path.exists(target_path):
        print("Found vocab files")
        with open(vocab_path, 'r') as reader:
            vocab_tokens = reader.read().splitlines()
        with open(target_path, 'r') as reader:
            tag_vocab_tokens = reader.read().splitlines()
        vocab = Vocab(vocab_tokens)
        tag_vocab = Vocab(tag_vocab_tokens)
    else:
        print("Vocab files not found, generating them")
        vocab = Vocab.build(sents, reserved_tokens=["<pad>"])
        tag_vocab = Vocab.build(postags)
        with open(vocab_path, 'w') as writer:
            writer.write("\n".join(vocab.idx_to_token))
        with open(target_path, 'w') as writer:
            writer.write("\n".join(tag_vocab.idx_to_token))
    train_data = [(vocab.convert_tokens_to_ids(sentence), 
                   tag_vocab.convert_tokens_to_ids(tags)) for sentence, 
                   tags in zip(sents[:3000], postags[:3000])]
    test_data = [(vocab.convert_tokens_to_ids(sentence), 
                  tag_vocab.convert_tokens_to_ids(tags)) for sentence, 
                  tags in zip(sents[3000:], postags[3000:])]
    with open(train_path, 'wb') as f:
        pickle.dump(train_data, f)
    # save test_data to a file
    with open(test_path, 'wb') as f:
        pickle.dump(test_data, f)
    return train_data, test_data, vocab, tag_vocab

def read_vocab(path):  # read a vocabulary file into a Vocab object
    with open(path, 'r') as f:
        tokens = f.read().split('\n')
    return Vocab(tokens)

def word_tokenize(text):  # tokenizer: takes a sentence, returns a list of tokens
    data.path.append(r"S:\downloads\nltk\nltk_data")  # local nltk_data path; change it to yours
    from nltk.tokenize import word_tokenize
    words = word_tokenize(text)
    return words

def path_get():     # resolve the project's folder and file paths
    current_folder = os.path.dirname(os.path.abspath(__file__))  # current folder (where the code lives)
    parent_folder = os.path.dirname(current_folder)  # one level up (project root)
    locate = os.path.join(parent_folder, "vocab")  # data folder
    # print("1",locate)
    nltklocate = os.path.join(parent_folder, "nltk")  # nltk folder
    # print("2",nltklocate)
    modlocate = os.path.join(parent_folder, "mode")  # model folder

    vocab_path = os.path.join(locate, "vocab.txt")
    target_path = os.path.join(locate, "vocab_target.txt")
    train_path = os.path.join(locate, "train_data.pkl")
    test_path = os.path.join(locate, "test_data.pkl")
    nltk_path=os.path.join(nltklocate, "nltk_data")
    # mode_path=os.path.join(modlocate, "mode.pth")

    return vocab_path,target_path,train_path,test_path,nltk_path,modlocate

def base_check():
    vocab_path, target_path, train_path, test_path, nltk_path, modlocate = path_get()
    if os.path.exists(vocab_path):
        print("Word vocabulary file exists")
    if os.path.exists(target_path):
        print("POS tag vocabulary file exists")
    if os.path.exists(train_path):
        print("Prepared training set exists")
    if os.path.exists(test_path):
        print("Prepared test set exists")
  
# def main():   # for testing the tools module
# if __name__ == '__main__':
#     main()
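
A small usage sketch of the Vocab helpers above, assuming load_data has already generated the vocab files that path_get points to:

from tools import read_vocab, path_get

vocab_path, target_path, *_ = path_get()
vocab = read_vocab(vocab_path)        # word vocabulary
pos_vocab = read_vocab(target_path)   # POS tag vocabulary

tokens = ["the", "cat", "sat"]
ids = vocab.convert_tokens_to_ids(tokens)   # out-of-vocabulary words map to <unk>
print(ids)
print(vocab.convert_ids_to_tokens(ids))     # round-trips back to the tokens (or <unk>)
print(len(pos_vocab))                       # number of POS tag classes (num_class in the main script)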

