当前位置:   article > 正文

LSTM--长句预测_lstm-长句预测

lstm-长句预测

封装成类

import tensorflow_pass_warning_wgs
import tensorflow as tf, numpy
import tensorflow.contrib as con

# Fix TensorFlow's graph-level random seed so weight initialisation is repeatable.
tf.set_random_seed(1)


class model():
    """Character-level next-character predictor built on a two-layer LSTM.

    The sentence is cut into overlapping windows of ``sequnec_length``
    characters.  The target of each window is the same window shifted one
    character to the right.  ``train`` fits the network on all windows at once.
    """

    def __init__(self, sentence):
        """Build the vocabulary and the training windows for ``sentence``."""
        self.sentence = sentence
        # Window length.  The attribute name (typo included) is kept so that
        # existing callers reading it keep working.
        self.sequnec_length = 10
        self.X_data, self.Y_data = self.get_data(sentence)
        self.hidden_size = 50  # units per LSTM layer
        self.num_classes = len(self.idx2char)
        self.batch_size = len(self.X_data)

    def get_data(self, sentence):
        """Return ``(x_data, y_data)``: lists of character-id windows.

        Also sets ``self.idx2char``, the id -> character lookup table.
        Each window pair is printed as it is built.
        """
        # Sorted so every run assigns the same id to the same character.
        # Plain set iteration order changes with string hash randomisation,
        # which would defeat the fixed TensorFlow seed.
        self.idx2char = sorted(set(sentence))

        # Character -> id lookup.
        char_dict = {w: i for i, w in enumerate(self.idx2char)}

        x_data, y_data = [], []
        for i in range(0, len(sentence) - self.sequnec_length):
            # Input window and its target, staggered by one character.
            x_str = sentence[i: i + self.sequnec_length]
            y_str = sentence[i + 1: i + self.sequnec_length + 1]

            print(i, x_str, '->', y_str)

            x_data.append([char_dict[c] for c in x_str])
            y_data.append([char_dict[c] for c in y_str])
        return x_data, y_data

    def _decode(self, logits):
        """Turn per-window logits into one predicted string.

        ``logits`` is indexed as (window, time step, class).  The first window
        contributes all of its characters; every later window contributes only
        its last character, because the rest overlaps the previous window.
        """
        best = numpy.argmax(logits, axis=2)
        if len(best) == 0:
            return ''
        chars = [self.idx2char[c] for c in best[0]]
        chars.extend(self.idx2char[row[-1]] for row in best[1:])
        return ''.join(chars)

    def train(self, epochs=500, learning_rate=0.1):
        """Build the graph and fit it on every window in one full batch.

        Args:
            epochs: maximum number of optimisation steps.
            learning_rate: step size handed to the Adam optimizer.

        Each step prints the step index, the loss and the decoded prediction.
        Training stops early once the prediction equals ``sentence[1:]``.
        """
        X = tf.placeholder(tf.int32, [None, self.sequnec_length])
        Y = tf.placeholder(tf.int32, [None, self.sequnec_length])
        X_oneHot = tf.one_hot(X, self.num_classes)

        # Deep RNN: two LSTM cells stacked on top of each other.
        cells = [con.rnn.LSTMCell(num_units=self.hidden_size) for _ in range(2)]
        mul_cells = con.rnn.MultiRNNCell(cells)
        outputs, state = tf.nn.dynamic_rnn(mul_cells, X_oneHot, dtype=tf.float32)

        # Flatten the time steps so one dense layer scores every position.
        outputs = tf.reshape(outputs, shape=[-1, self.hidden_size])
        logits = con.layers.fully_connected(outputs, self.num_classes, activation_fn=None)
        # -1 keeps the batch dimension open, matching the placeholders, instead
        # of pinning the graph to the training-set size.
        logits = tf.reshape(logits, shape=[-1, self.sequnec_length, self.num_classes])

        # Equal weight for every position; shaped like the fed targets.
        weight = tf.ones_like(Y, dtype=tf.float32)
        cost = tf.reduce_mean(con.seq2seq.sequence_loss(logits=logits, targets=Y, weights=weight))
        optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(epochs):
                c_, lo_, _ = sess.run([cost, logits, optimizer], feed_dict={X: self.X_data, Y: self.Y_data})

                ret = self._decode(lo_)
                print(i, c_, ret)
                if ret == self.sentence[1:]:
                    break


# Training corpus: one long sentence (adjacent literals are joined by the parser).
sentence = (
    "if you want to build a ship, don't drum up people together to "
    "collect wood and don't assign them tasks and work, but rather "
    "teach them to long for the endless immensity of the sea."
)

# Build the vocabulary and windows, then run the training loop.
m = model(sentence)
m.train()



  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79

普通方法

from __future__ import print_function

import tensorflow as tf
import numpy as np
from tensorflow.contrib import rnn
import tensorflow.contrib as rnnn
from tensorflow.python.ops.rnn import dynamic_rnn

# Fix TensorFlow's graph-level random seed so repeated runs start from the same weights.
tf.set_random_seed(777)  # reproducibility

# Training corpus: one long sentence (adjacent literals are joined by the parser).
sentence = ("if you want to build a ship, don't drum up people together to "
            "collect wood and don't assign them tasks and work, but rather "
            "teach them to long for the endless immensity of the sea.")

# Vocabulary, sorted so every run assigns the same id to the same character.
# Plain set iteration order changes with string hash randomisation, which would
# make results differ between runs even with a fixed TensorFlow seed.
char_set = sorted(set(sentence))
char_dic = {w: i for i, w in enumerate(char_set)}  # character -> id
hidden_size = 50  # units per LSTM layer

sequence_length = 10  # window length; an arbitrary choice
data_dim = len(char_set)  # width of a one-hot input vector
num_classes = len(char_set)  # number of output classes
learning_rate = 0.1

# Build the dataset: each input window is paired with the same window shifted
# one character to the right.
dataX = []
dataY = []
for i in range(0, len(sentence) - sequence_length):
    x_str = sentence[i:i + sequence_length]
    y_str = sentence[i + 1: i + sequence_length + 1]
    print(i, x_str, '->', y_str)
    x = [char_dic[c] for c in x_str]  # input characters -> ids
    y = [char_dic[c] for c in y_str]  # target characters -> ids
    dataX.append(x)
    dataY.append(y)

# The whole dataset is fed as a single batch.
batch_size = len(dataX)
print(batch_size)
# Integer character ids for inputs and targets; the batch dimension is left open.
X = tf.placeholder(dtype=tf.int32, shape=[None, sequence_length])
Y = tf.placeholder(dtype=tf.int32, shape=[None, sequence_length])

# Expand every input id into a one-hot vector of width num_classes.
X_one_hot = tf.one_hot(indices=X, depth=num_classes)


# Factory for one recurrent layer: an LSTM cell with hidden_size units.
def cell():
    """Return a new ``rnn.LSTMCell`` with ``hidden_size`` units.

    ``rnn.BasicLSTMCell`` and ``rnn.GRUCell`` were tried here as alternatives.
    """
    return rnn.LSTMCell(hidden_size, state_is_tuple=True)


# Stack two LSTM layers into one multi-layer cell.
multi_cells = rnn.MultiRNNCell([cell() for _ in range(2)], state_is_tuple=True)

# Unroll the stacked cell over the sequence; one hidden vector per time step.
outputs, _states = dynamic_rnn(multi_cells, X_one_hot, dtype=tf.float32)
# Fully connected layer: flatten the time steps so one dense layer scores
# every position.
X_for_fc = tf.reshape(outputs, [-1, hidden_size])
outputs = tf.contrib.layers.fully_connected(X_for_fc, num_classes, activation_fn=None)
# Restore the (batch, time, class) layout expected by sequence_loss.  -1 keeps
# the batch dimension open, matching the placeholders, instead of pinning the
# graph to the training-set size.
outputs = tf.reshape(outputs, [-1, sequence_length, num_classes])
# Equal weight for every position; shaped like the fed targets.
weights = tf.ones_like(Y, dtype=tf.float32)
# Loss.
sequence_loss = tf.contrib.seq2seq.sequence_loss(logits=outputs, targets=Y, weights=weights)
mean_loss = tf.reduce_mean(sequence_loss)
train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(mean_loss)

# NOTE(review): the session is deliberately left open; the commented-out
# evaluation snippet further down the file appears to reuse `sess`.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(500):
    # One optimisation step on the full batch.  The predictions are fetched
    # separately below, so they are not requested here.
    _, lossval = sess.run([train_op, mean_loss], feed_dict={X: dataX, Y: dataY})
    # Re-evaluate the predictions after the update.
    results = sess.run(outputs, feed_dict={X: dataX})
    for j, result in enumerate(results):
        index = np.argmax(result, axis=1)
        if j == 0:
            # The first window contributes all of its characters.
            ret = ''.join([char_set[t] for t in index])
        else:
            # Later windows overlap the previous one; keep only the last character.
            ret = ret + char_set[index[-1]]
    print(i, lossval, ret)
    # Stop as soon as the sentence is reproduced exactly.
    if ret == sentence[1:]:
        break

# #输出每个结果的最后一个字符检测效果 Let's print the last char of each result to check it works
# results = sess.run(outputs, feed_dict={X: dataX})  #(170,10,25)
# for j, result in enumerate(results):
#     index = np.argmax(result, axis=1)
#     if j is 0:  #第一个结果10个字符组成一个句子 print all for the first result to make a sentence
#         # print(''.join([char_set[t] for t in index]), end='')
#         ret =''.join([char_set[t] for t in index])
#     else: #其它取最后一个字符
#         # print(char_set[index[-1]], end='')
#         ret = ret + char_set[index[-1]]

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/575483
推荐阅读
相关标签
  

闽ICP备14008679号