赞
踩
# NOTE(review): recovered from a scraped page where the whole script was
# collapsed onto two lines; reformatted conventionally. This is TF1.x-style
# code (tf.contrib, placeholders, Session) and requires TensorFlow 1.x.
# The original also imported a non-existent local module
# `tensorflow_pass_warning_wgs` (author's warning-suppression helper);
# that import is removed so the script runs outside the author's machine.
import tensorflow as tf
import numpy
import tensorflow.contrib as con

tf.set_random_seed(1)


class model():
    """Character-level language model.

    A 2-layer stacked LSTM trained to predict the next character of
    `sentence` from a stride-1 sliding window of fixed length.
    """

    def __init__(self, sentence):
        self.sentence = sentence
        # Sliding-window length (original attribute was misspelled "sequnec_length").
        self.sequence_length = 10
        self.X_data, self.Y_data = self.get_data(sentence)
        self.hidden_size = 50                   # LSTM units per layer
        self.num_classes = len(self.idx2char)   # vocabulary size (idx2char set by get_data)
        self.batch_size = len(self.X_data)      # one "batch" = all windows

    def get_data(self, sentence):
        """Build (input, target) index sequences with a sliding window.

        Targets are the inputs shifted one character to the right.
        Side effect: sets self.idx2char (index -> character vocabulary).
        Returns (x_data, y_data): lists of lists of vocabulary indices.
        """
        self.idx2char = list(set(sentence))                       # vocabulary
        char_dict = {w: i for i, w in enumerate(self.idx2char)}   # char -> index
        x_data, y_data = [], []
        for i in range(0, len(sentence) - self.sequence_length):
            x_str = sentence[i: i + self.sequence_length]
            y_str = sentence[i + 1: i + self.sequence_length + 1]  # shifted by one
            print(i, x_str, '->', y_str)
            x = [char_dict[c] for c in x_str]
            y = [char_dict[c] for c in y_str]
            x_data.append(x)
            y_data.append(y)
        return x_data, y_data

    def train(self):
        """Build the graph, train with Adam on sequence loss, and print the
        reconstructed sentence each step; stops early once the model
        reproduces the target text exactly."""
        X = tf.placeholder(tf.int32, [None, self.sequence_length])
        Y = tf.placeholder(tf.int32, [None, self.sequence_length])
        X_oneHot = tf.one_hot(X, self.num_classes)

        # Two stacked LSTM layers.
        cells = [con.rnn.LSTMCell(num_units=self.hidden_size) for _ in range(2)]
        mul_cells = con.rnn.MultiRNNCell(cells)
        outputs, state = tf.nn.dynamic_rnn(mul_cells, X_oneHot, dtype=tf.float32)

        # Project every timestep's hidden state onto the vocabulary.
        outputs = tf.reshape(outputs, shape=[-1, self.hidden_size])
        logits = con.layers.fully_connected(outputs, self.num_classes,
                                            activation_fn=None)
        logits = tf.reshape(logits, shape=[self.batch_size, self.sequence_length,
                                           self.num_classes])

        # Equal weight for every timestep in the sequence loss.
        weight = tf.ones(shape=[self.batch_size, self.sequence_length])
        cost = tf.reduce_mean(con.seq2seq.sequence_loss(logits=logits, targets=Y,
                                                        weights=weight))
        optimizer = tf.train.AdamOptimizer(0.1).minimize(cost)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(500):
                c_, lo_, _ = sess.run([cost, logits, optimizer],
                                      feed_dict={X: self.X_data, Y: self.Y_data})
                # Stitch predictions into one string: the full first window,
                # then only the last character of each subsequent overlapping
                # window.
                for j, res in enumerate(lo_):
                    index = numpy.argmax(res, 1)
                    if j == 0:
                        ret = ''.join([self.idx2char[c] for c in index])
                    else:
                        ret += self.idx2char[index[-1]]
                print(i, c_, ret)
                if ret == self.sentence[1:]:  # perfect reconstruction -> stop early
                    break


sentence = ("if you want to build a ship, don't drum up people together to "
            "collect wood and don't assign them tasks and work, but rather "
            "teach them to long for the endless immensity of the sea.")
m = model(sentence)
m.train()
from __future__ import print_function import tensorflow as tf import numpy as np from tensorflow.contrib import rnn import tensorflow.contrib as rnnn from tensorflow.python.ops.rnn import dynamic_rnn tf.set_random_seed(777) # reproducibility sentence = ("if you want to build a ship, don't drum up people together to " "collect wood and don't assign them tasks and work, but rather " "teach them to long for the endless immensity of the sea.") char_set = list(set(sentence)) # print(len(char_set)) #25 char_dic = {w: i for i, w in enumerate(char_set)} hidden_size = 50 # len(char_set) #25 sequence_length = 10 # Any arbitrary number data_dim = len(char_set) # 25 num_classes = len(char_set) # 25 learning_rate = 0.1 # 构造数据集 dataX = [] dataY = [] for i in range(0, len(sentence) - sequence_length): x_str = sentence[i:i + sequence_length] y_str = sentence[i + 1: i + sequence_length + 1] print(i, x_str, '->', y_str) x = [char_dic[c] for c in x_str] # x str to index 字符转数字 y = [char_dic[c] for c in y_str] # y str to index dataX.append(x) dataY.append(y) batch_size = len(dataX) # 170 print(batch_size) X = tf.placeholder(tf.int32, [None, sequence_length]) Y = tf.placeholder(tf.int32, [None, sequence_length]) X_one_hot = tf.one_hot(X, num_classes) # 独热编码 #print(X_one_hot) (?, 10, 25) # 建一个有隐藏单元的LSTM,Make a lstm cell with hidden_size (each unit output vector size) def cell(): # cell = rnn.BasicLSTMCell(hidden_size, state_is_tuple=True) # cell = rnn.GRUCell(hidden_size) cell = rnn.LSTMCell(hidden_size, state_is_tuple=True) return cell multi_cells = rnn.MultiRNNCell([cell() for _ in range(2)], state_is_tuple=True) # outputs:展开隐藏层 unfolding size x hidden size, state = hidden size outputs, _states = dynamic_rnn(multi_cells, X_one_hot, dtype=tf.float32) # 全连接层FC layer X_for_fc = tf.reshape(outputs, [-1, hidden_size]) outputs = tf.contrib.layers.fully_connected(X_for_fc, num_classes, activation_fn=None) # outputs = rnnn.layers.fully_connected(X_for_fc,num_classes,activation_fn=None) # 
print(outputs.shape) #(?,25) # 改变维度准备计算序列损失reshape out for sequence_loss outputs = tf.reshape(outputs, [batch_size, sequence_length, num_classes]) # (170, 10, 25) weights = tf.ones([batch_size, sequence_length]) # 所有的权重都是1 All weights are 1 (equal weights) # 计算损失值 sequence_loss = tf.contrib.seq2seq.sequence_loss(logits=outputs, targets=Y, weights=weights) # sequence_loss = rnnn.seq2seq.sequence_loss(logits=outputs,targets=Y,weights=weights) mean_loss = tf.reduce_mean(sequence_loss) train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(mean_loss) sess = tf.Session() sess.run(tf.global_variables_initializer()) for i in range(500): # (500) _, lossval, results = sess.run( [train_op, mean_loss, outputs], feed_dict={X: dataX, Y: dataY}) # print(results.shape) (170,10,25) # if i == 49: # for j, result in enumerate(results): #j:[0,170) result:(10,25) # index = np.argmax(result, axis=1) # print(i, j, ''.join([char_set[t] for t in index]), l) results = sess.run(outputs, feed_dict={X: dataX}) # (170,10,25) for j, result in enumerate(results): index = np.argmax(result, axis=1) # print('----------------------------------------') # print(index) if j is 0: # 第一个结果10个字符组成一个句子 print all for the first result to make a sentence ret = ''.join([char_set[t] for t in index]) else: # 其它取最后一个字符 ret = ret + char_set[index[-1]] print(i, lossval, ret) if ret == sentence[1:]: break # #输出每个结果的最后一个字符检测效果 Let's print the last char of each result to check it works # results = sess.run(outputs, feed_dict={X: dataX}) #(170,10,25) # for j, result in enumerate(results): # index = np.argmax(result, axis=1) # if j is 0: #第一个结果10个字符组成一个句子 print all for the first result to make a sentence # # print(''.join([char_set[t] for t in index]), end='') # ret =''.join([char_set[t] for t in index]) # else: #其它取最后一个字符 # # print(char_set[index[-1]], end='') # ret = ret + char_set[index[-1]]
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。