
LSTM: Short-Sentence Prediction

Wrapped in a class

# Author's local helper module that silences TensorFlow warnings (not part of TensorFlow)
import tensorflow_pass_warning_wgs
import tensorflow as tf
import numpy
import tensorflow.contrib as con


class Model:
    def __init__(self, sentences):
        self.x_data, self.y_data = self.get_data(sentences)
        self.hidden_size = len(self.idx2char) * 2   # hidden units in the cell (26 here); need not equal the sequence length
        self.batch_size = 1                         # batch size
        self.sequence_length = len(sentences) - 1   # sequence length (number of time steps)
        self.num_classes = len(self.idx2char)       # size of the final output (vocabulary size)
        self.sentences = sentences

    def get_data(self, sentences):
        # Build the vocabulary: the unique characters, sorted for a stable index order
        self.idx2char = list(set(sentences))
        self.idx2char.sort()
        return self.char2collection(self.idx2char, sentences)

    # Encode characters as indices (the bag-of-characters step)
    def char2collection(self, idx2char, sentences):
        # Turn the vocabulary list into a char -> index dictionary
        idx_dict = {c: i for i, c in enumerate(idx2char)}
        # print(idx_dict)

        # Encode the sample sentence as a list of indices
        sample_list = [idx_dict[s] for s in sentences]
        # print(sample_list)

        # Build x and y: the input drops the last character, the target drops the first
        x_data = sample_list[:-1]
        y_data = sample_list[1:]
        # print(self.collection2str(x_data))
        # print(self.collection2str(y_data))
        return [x_data], [y_data]

    # Decode a list of indices back into a string
    def collection2str(self, sample):
        return ''.join([self.idx2char[c] for c in numpy.squeeze(sample)])

    # Open the session
    def sess_start(self):
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    # Close the session
    def sess_close(self):
        self.sess.close()

    # Training
    def train(self, teststr):
        # Placeholders
        X = tf.placeholder(tf.int32, shape=[None, self.sequence_length])
        Y = tf.placeholder(tf.int32, shape=[None, self.sequence_length])
        X_oneHot = tf.one_hot(X, self.num_classes)  # one-hot encode, e.g. 1 -> 0 1 0 0 0 0 0 0 0 0 0 0 0
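        # X_oneHot has shape (None, sequence_length, num_classes), here (?, 20, 13)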

        # Build the model
        cell = con.rnn.LSTMCell(num_units=self.hidden_size)
        initial_state = cell.zero_state(self.batch_size, tf.float32)
        outputs, state = tf.nn.dynamic_rnn(cell, X_oneHot, initial_state=initial_state, dtype=tf.float32)
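        # outputs has shape (batch_size, sequence_length, hidden_size) = (1, 20, 26)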

        # Fully connected layer: flatten to 2-D, project to num_classes, then restore the 3-D shape
        outputs = tf.reshape(outputs, [-1, self.hidden_size])
        logits = con.layers.fully_connected(outputs, self.num_classes, activation_fn=None)
        logits = tf.reshape(logits, [self.batch_size, self.sequence_length, self.num_classes])

        # Loss and optimizer
        weight = tf.ones([self.batch_size, self.sequence_length])
        cost = tf.reduce_mean(con.seq2seq.sequence_loss(logits=logits, targets=Y, weights=weight))
        optimizer = tf.train.AdamOptimizer(0.1).minimize(cost)

        # Prediction: argmax over the class axis picks one character per time step
        prediction = tf.argmax(logits, 2)

        self.sess_start()

        for i in range(100):
            cost_, pre, _ = self.sess.run([cost, prediction, optimizer], feed_dict={X: self.x_data, Y: self.y_data})

            preStr = self.collection2str(pre)
            print(f'epoch: {i + 1}, cost: {cost_}, prediction: {preStr}')

            if self.sentences[1:] == preStr:
                break

        # Test: teststr must have the same length as the training sentence
        # (the placeholder fixes sequence_length) and use only vocabulary characters
        xtest, ytest = self.char2collection(self.idx2char, teststr)
        p = self.sess.run(prediction, feed_dict={X: xtest})
        print('prediction: ', self.collection2str(p), ', label: ', teststr)



sample = " if you want you like"
sample_tets = " if you want you like"
m = model(sample)
m.train(sample_tets)
m.sess_close()
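
To make the head/tail split concrete, here is a minimal standalone sketch of what get_data produces, using the hypothetical five-character sample "hello" (not from the post; any string works the same way):

sample = "hello"
idx2char = sorted(set(sample))                      # ['e', 'h', 'l', 'o']
char2idx = {c: i for i, c in enumerate(idx2char)}   # {'e': 0, 'h': 1, 'l': 2, 'o': 3}

indices = [char2idx[c] for c in sample]             # [1, 0, 2, 2, 3]
x_data = indices[:-1]                               # [1, 0, 2, 2] -> "hell" (drop the tail)
y_data = indices[1:]                                # [0, 2, 2, 3] -> "ello" (drop the head)

At each time step the network is trained to emit the next character of the input. This is also why train feeds teststr through the same placeholder: the test string must match the training sentence's length and vocabulary.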



Plain (non-class) version

# Recurrent neural network: training on a short sentence
import tensorflow as tf
from tensorflow import contrib
import numpy as np
from tensorflow.python.ops.rnn import dynamic_rnn
# import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Author's local helper module that silences TensorFlow warnings (not part of TensorFlow)
import tensorflow_pass_warning_wgs

tf.set_random_seed(777)

# Build the dataset
sample = " if you want you like"
# sample = " take your time come on"
idx2char = list(set(sample))  # unique characters (13 of them); set order is arbitrary
idx2char.sort()               # sort so the index order is reproducible
print(idx2char)  # [' ', 'a', 'e', 'f', 'i', 'k', 'l', 'n', 'o', 't', 'u', 'w', 'y']

char2idx = {c: i for i, c in enumerate(idx2char)}  # char -> index dictionary
print(char2idx)  # {' ': 0, 'a': 1, 'e': 2, 'f': 3, 'i': 4, 'k': 5, 'l': 6, 'n': 7, 'o': 8, 't': 9, 'u': 10, 'w': 11, 'y': 12}

sample_idx = [char2idx[c] for c in sample]  # look up the index of each character
print(sample_idx)  # [0, 4, 3, 0, 12, 8, 10, 0, 11, 1, 7, 9, 0, 12, 8, 10, 0, 6, 4, 5, 2]

x_data = [sample_idx[:-1]]  # input: drop the last character
y_data = [sample_idx[1:]]   # target: drop the first character

# Parameters for the RNN
dic_size = len(char2idx)             # vocabulary size: 13
rnn_hidden_size = len(char2idx) * 2  # hidden units in the cell: 26; need not equal the sequence length
batch_size = 1                       # batch size
sequence_length = len(sample) - 1    # sequence length (number of time steps): 20
num_classes = len(char2idx)          # size of the final output: 13

# Placeholders and one-hot encoding
X = tf.placeholder(tf.int32, [None, sequence_length])  # X data, shape (?, 20)
Y = tf.placeholder(tf.int32, [None, sequence_length])  # Y labels, shape (?, 20)
X_one_hot = tf.one_hot(X, num_classes)  # one-hot encode, e.g. 1 -> 0 1 0 0 0 0 0 0 0 0 0 0 0
# print(X_one_hot.shape)  # (?, 20, 13)

# Build the RNN
# cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
# cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
cell = tf.contrib.rnn.LSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)  # state_is_tuple fixes the state format
initial_state = cell.zero_state(batch_size, tf.float32)  # initial RNN state
outputs, _states = dynamic_rnn(cell, X_one_hot, initial_state=initial_state, dtype=tf.float32)
# print(outputs.shape)  # (1, 20, 26)
outputs = tf.reshape(outputs, [-1, rnn_hidden_size])  # flatten to 2-D before the fully connected layer: (20, 26)
# One fully connected layer adds depth and sharpens the prediction
outputs = contrib.layers.fully_connected(inputs=outputs, num_outputs=num_classes, activation_fn=None)
# print(outputs.shape)  # now (1*20, 13), i.e. [batch_size*sequence_length, num_classes]

outputs = tf.reshape(outputs, [batch_size, sequence_length, num_classes])  # back to 3-D: (1, 20, 13)
weights = tf.ones([batch_size, sequence_length])  # per-time-step loss weights (uniform here)
# Sequence loss: per-time-step softmax cross-entropy, averaged with the weights
sequence_loss = tf.contrib.seq2seq.sequence_loss(logits=outputs, targets=Y, weights=weights)  # expects 3-D logits
loss = tf.reduce_mean(sequence_loss)

train = tf.train.AdamOptimizer(learning_rate=0.1).minimize(loss)

# Prediction
prediction = tf.argmax(outputs, axis=2)  # outputs is 3-D, so take argmax over the class axis (axis=2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(100):
        sl, l, _ = sess.run([sequence_loss, loss, train], feed_dict={X: x_data, Y: y_data})
        result = sess.run(prediction, feed_dict={X: x_data})
        # Decode the predicted indices back into characters
        result_str = [idx2char[c] for c in np.squeeze(result)]
        print(i, sl, "loss:", l, "prediction: [", ''.join(result_str), ']')
        if sample[1:] == ''.join(result_str):
            break
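
For intuition, here is a minimal NumPy sketch of what tf.contrib.seq2seq.sequence_loss computes when the weights are all ones, as above: the mean over time steps of the softmax cross-entropy between the logits and the integer targets. This is a simplified single-batch illustration, not the library implementation:

import numpy as np

def sequence_loss_np(logits, targets):
    # logits: (sequence_length, num_classes); targets: (sequence_length,) of int indices
    shifted = logits - logits.max(axis=1, keepdims=True)   # subtract max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # negative log-probability of the correct class at each step, averaged over steps
    return -log_probs[np.arange(len(targets)), targets].mean()

With uniform weights this matches sequence_loss under its default averaging options (average_across_timesteps and average_across_batch both True), which is also why the extra tf.reduce_mean above is effectively a no-op on an already-scalar loss.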
