
PyTorch: A Chinese Chatbot Based on seq2seq

"enc_batch = batch[\"input_batch\"] src_mask = enc_batch.data.eq(config.pad_idx"

Because the dataset is small, and I currently can't handle the fact that the introduced "UNK" values inflate the accuracy, this still needs further optimization. For now I'm posting this code and will re-upload once it's improved. The process has three main steps: first, data preprocessing; second, model construction; third, test-set processing.

Step 1 consists of:

  • Building the data: constructing enc_input, dec_input, and dec_output
  • Segmenting with jieba and removing stop words
  • Appending an end token ('E') to enc_input, prepending a start token ('S') to the decoder input, and appending an end token ('E') to the decoder output (see the short sketch after this list)
  • Converting words to integer indices
  • Converting the data to PyTorch's Dataset/DataLoader types for convenient batching
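As a quick illustration of the target format after step 1 (a minimal sketch with made-up tokens; the segmentation shown is only illustrative, and 'UNK' doubles as the padding token, which is part of the accuracy issue mentioned above):

n_step = 4                                  # max sequence length over the corpus
quest = ['今天', '打', '王者', '嘛']        # segmented question
anwer = ['打', '呀', 'UNK', 'UNK']          # segmented answer, padded to n_step with 'UNK'

enc_input  = quest + ['E']                  # ['今天', '打', '王者', '嘛', 'E']
dec_input  = ['S'] + anwer                  # ['S', '打', '呀', 'UNK', 'UNK']
dec_output = anwer + ['E']                  # ['打', '呀', 'UNK', 'UNK', 'E']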

Step 2, briefly:

  • An embedding layer to turn word indices into vectors
  • Two RNNs, one as the encoder and one as the decoder
  • A final fully connected layer

Step 3 uses the same preprocessing as step 1, but note that dec_input must be "empty" at inference time: there is no ground-truth answer, so the decoder input is just the start token followed by padding (a short sketch follows).
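A minimal sketch of that inference-time decoder input (mirroring make_data further down; 'S' is the start token, 'UNK' the padding token):

n_step = 4                                  # same max length as in training
dec_input = ['S'] + ['UNK'] * n_step        # no real answer tokens at test time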

When I have time I'll go through the seq2seq model in detail; for now I'm focused on improving accuracy. If any experts have pointers, I'd appreciate the advice. :)

import pandas as pd
import jieba
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
import warnings

warnings.filterwarnings('ignore')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def cal_clear_word(test):
    """Segment each question/answer with jieba and drop stop words."""
    stoplist = [' ', '\n', ',']

    def function(a):
        word_list = [w for w in jieba.cut(a) if w not in stoplist]
        return word_list

    test['quest'] = test.apply(lambda x: function(x['quest']), axis=1)
    test['anwer'] = test.apply(lambda x: function(x['anwer']), axis=1)
    return test

def cal_update_date(test, sequence_length):
    """Truncate or pad every sequence to sequence_length; 'UNK' doubles as the padding token."""
    def prepare_sequence(seq):
        idxs = list(seq)
        if len(idxs) >= sequence_length:
            idxs = idxs[:sequence_length]
        else:
            idxs.extend(['UNK'] * (sequence_length - len(idxs)))
        return idxs

    test['quest'] = test.apply(lambda x: prepare_sequence(x['quest']), axis=1)
    test['anwer'] = test.apply(lambda x: prepare_sequence(x['anwer']), axis=1)
    return test


def cal_add_status(test):
    """Build the three seq2seq columns: 'E' ends enc_input and dec_output, 'S' starts dec_input."""
    # Copy the lists so that dec_input and dec_output do not share the same underlying object.
    test['enc_input'] = test['quest'].apply(list)
    test['dec_input'] = test['anwer'].apply(list)
    test['dec_output'] = test['anwer'].apply(list)
    test = test[['enc_input', 'dec_input', 'dec_output']]
    for i, j, h in test.values:
        i.append('E')       # encoder input ends with 'E'
        j.insert(0, 'S')    # decoder input starts with 'S'
        h.append('E')       # decoder target ends with 'E'
    return test



def cal_word_to_ix(test):
    word_to_ix = {}  # word-to-index dictionary
    for column in ['enc_input', 'dec_input', 'dec_output']:
        for seq in test[column]:
            for word in seq:
                if word not in word_to_ix:
                    word_to_ix[word] = len(word_to_ix)

    def prepare_sequence(seq, to_ix):
        return [to_ix[w] for w in seq]

    test['enc_input'] = test.apply(lambda x: prepare_sequence(x['enc_input'], word_to_ix), axis=1)
    test['dec_input'] = test.apply(lambda x: prepare_sequence(x['dec_input'], word_to_ix), axis=1)
    test['dec_output'] = test.apply(lambda x: prepare_sequence(x['dec_output'], word_to_ix), axis=1)
    return test, len(word_to_ix), word_to_ix

class TestDataset(data.Dataset):  # subclass of Dataset
    def __init__(self, test):
        self.enc_input = test['enc_input']
        self.dec_input = test['dec_input']
        self.dec_output = test['dec_output']

    def __getitem__(self, index):
        # Convert numpy arrays to Tensors
        enc_input = torch.from_numpy(np.array(self.enc_input[index]))
        dec_input = torch.from_numpy(np.array(self.dec_input[index]))
        dec_output = torch.from_numpy(np.array(self.dec_output[index]))

        return enc_input, dec_input, dec_output

    def __len__(self):
        return len(self.enc_input)


class Seq2Seq(nn.Module):
    def __init__(self,n_class,n_hidden):
        super(Seq2Seq, self).__init__()
        self.W = nn.Embedding(vocab_size, embedding_size)  # shared embedding for encoder and decoder inputs
        # Note: dropout on nn.RNN only applies between layers, so it has no effect with num_layers=1.
        self.encoder = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)  # encoder
        self.decoder = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)  # decoder
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, enc_input, enc_hidden, dec_input):
        # enc_input(=input_batch): [batch_size, n_step+1], token indices
        # dec_input(=output_batch): [batch_size, n_step+1], token indices

        enc_input = self.W(enc_input)  # [batch_size, sequence_length, embedding_size]
        dec_input = self.W(dec_input)  # [batch_size, sequence_length, embedding_size]

        enc_input = enc_input.transpose(0, 1)  # enc_input: [n_step+1, batch_size, n_class]
        dec_input = dec_input.transpose(0, 1)  # dec_input: [n_step+1, batch_size, n_class]

        # h_t : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
        _, h_t = self.encoder(enc_input, enc_hidden)
        # outputs : [n_step+1, batch_size, num_directions(=1) * n_hidden(=128)]
        outputs, _ = self.decoder(dec_input, h_t)

        model = self.fc(outputs)  # model : [n_step+1, batch_size, n_class]
        return model
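A quick shape check of the forward pass with dummy tensors (a minimal sketch; the sizes are made up, and since the class also reads the module-level vocab_size and embedding_size, the sketch sets them first):

# Hypothetical sizes just for the shape check; they are overwritten by the real script below.
vocab_size = embedding_size = 10
toy_model = Seq2Seq(n_class=10, n_hidden=16)
enc = torch.randint(0, 10, (2, 5))   # [batch_size, n_step+1] token indices
dec = torch.randint(0, 10, (2, 5))   # [batch_size, n_step+1] token indices
h0 = torch.zeros(1, 2, 16)           # [num_layers * num_directions, batch_size, n_hidden]
out = toy_model(enc, h0, dec)
print(out.shape)                     # torch.Size([5, 2, 10]) -> [n_step+1, batch_size, n_class]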

# Build the data: construct enc_input, dec_input, dec_output
data_dict={'quest':['好好写博客','我想去大厂','今天打王者嘛','明天要加班'],
      'anwer':['加油噢','肯定可以的','打呀,放假为啥不打','五一加屁班']}
train_df = pd.DataFrame(data_dict)
# Segment with jieba and remove stop words
return_df = cal_clear_word(train_df)
n_step = max([max(len(i), len(j)) for i, j in return_df.values])
return_df = cal_update_date(return_df, n_step)
# Append 'E' to enc_input, prepend 'S' to dec_input, and append 'E' to dec_output
return_df = cal_add_status(return_df)
# Convert words to integer indices
return_df, vocab_size, letter2idx = cal_word_to_ix(return_df)
# Wrap the data in a PyTorch Dataset/DataLoader for convenient batching
result_df = TestDataset(return_df)
batch_size = 2
test_loader = data.DataLoader(result_df, batch_size, shuffle=False)
# Build the model
n_class = vocab_size
embedding_size = n_class
n_hidden = 128
model = Seq2Seq(n_class, n_hidden).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(1000):
    for enc_input_batch, dec_input_batch, dec_output_batch in test_loader:
        # make hidden shape [num_layers * num_directions, batch_size, n_hidden]
        h_0 = torch.zeros(1, batch_size, n_hidden).to(device)

        (enc_input_batch, dec_input_batch, dec_output_batch) = (
        enc_input_batch.to(device).long(), dec_input_batch.to(device).long(), dec_output_batch.to(device).long())
        # enc_input_batch : [batch_size, n_step+1], token indices
        # dec_input_batch : [batch_size, n_step+1], token indices
        # dec_output_batch : [batch_size, n_step+1], not one-hot
        pred = model(enc_input_batch, h_0, dec_input_batch)
        # pred : [n_step+1, batch_size, n_class]
        pred = pred.transpose(0, 1)  # [batch_size, n_step+1, n_class]
        loss = 0
        for i in range(len(dec_output_batch)):
            # pred[i] : [n_step+1, n_class]
            # dec_output_batch[i] : [n_step+1]
            loss += criterion(pred[i], dec_output_batch[i])
        if (epoch + 1) % 500 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
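One possible tweak for the padding/accuracy problem mentioned at the top (just a sketch of an idea, not something the code above does): tell the loss to ignore the 'UNK' padding positions so they stop dominating the objective.

# Hypothetical variant: ignore padding positions in the loss.
# Assumes 'UNK' is in the vocabulary built by cal_word_to_ix.
pad_idx = letter2idx['UNK']
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx).to(device)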


def make_data(word, n_step, to_ix):
    stoplist = [' ', '\n', ',']
    word_list = [w for w in jieba.cut(word) if w not in stoplist]
    # Truncate or pad to n_step, exactly as in training
    if len(word_list) >= n_step:
        word_list = word_list[:n_step]
    else:
        word_list = word_list + ['UNK'] * (n_step - len(word_list))
    enc_input = word_list + ['E']
    # Note: any word not seen during training is missing from to_ix and raises a KeyError here;
    # handling truly unknown words is the open issue mentioned at the top.
    enc_input = [to_ix[n] for n in enc_input]
    # At inference time the decoder input is only the start token plus padding
    dec_input = ['S'] + ['UNK'] * n_step
    dec_input = [to_ix[n] for n in dec_input]
    enc_input = torch.unsqueeze(torch.Tensor(enc_input), 0)  # [1, n_step+1]
    dec_input = torch.unsqueeze(torch.Tensor(dec_input), 0)  # [1, n_step+1]

    return enc_input, dec_input
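A possible way to avoid the KeyError on unseen words (again just a sketch; mapping everything unknown to 'UNK' is a crude fallback, not what the original code does):

# Hypothetical fallback for make_data: map out-of-vocabulary words to the 'UNK' index instead of crashing.
def to_ids(tokens, to_ix):
    unk_idx = to_ix['UNK']
    return [to_ix.get(w, unk_idx) for w in tokens]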

# Test
letter = {value:key for key, value in letter2idx.items()}
def translate(word):
    enc_input, dec_input = make_data(word, n_step, letter2idx)
    enc_input, dec_input = enc_input.to(device).long(), dec_input.to(device).long()
    # make hidden shape [num_layers * num_directions, batch_size, n_hidden]
    hidden = torch.zeros(1, 1, n_hidden).to(device)
    output = model(enc_input, hidden, dec_input)
    # output : [n_step+1, batch_size, n_class]
    predict = output.data.max(2, keepdim=True)[1]  # index of the best class at each step
    predict = predict.view(n_step + 1)
    predict = predict.cpu().numpy()  # move to CPU before converting, in case the model ran on CUDA
    decoded = [letter[i] for i in predict]
    translated = ''.join(decoded)
    translated = translated.replace('UNK', ' ')
    translated = translated.replace('S', ' ')
    return translated


print('test')
print('今天打王者嘛 ->', translate('今天打王者嘛'))
Epoch: 0500 cost = 0.001520
Epoch: 0500 cost = 0.001702
Epoch: 1000 cost = 0.000421
Epoch: 1000 cost = 0.000471
test
今天打王者嘛 ->  肯定     