[TensorFlow] Implementing a Basic RNN

The model is a Seq2Seq model:

Input: a (0,1) sequence; e.g. x = (1,1,1,0,0,0,1,1)

Label: the input sequence shifted right by some number of positions; e.g. shifting x right by 2 gives y = (0,0,1,1,1,0,0,0)

So, given a (0,1) sequence, the model learns to predict that sequence shifted right by a fixed number of positions (an "echo").
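As a quick sanity check, the shift-and-zero relationship can be reproduced with plain numpy (this mirrors the generateData logic shown below):

import numpy as np

x = np.array([1, 1, 1, 0, 0, 0, 1, 1])
y = np.roll(x, 2)  # circular right shift by 2: [1, 1, 1, 1, 1, 0, 0, 0]
y[:2] = 0          # zero the wrapped-around positions
print(y)           # [0 0 1 1 1 0 0 0]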

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
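Note that the code in this post uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session). If you are on TensorFlow 2.x, one way to run it (my assumption, not part of the original post) is through the v1 compatibility module:

# Run TF1-style graph code under TensorFlow 2.x (compatibility shim)
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()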

Parameters

num_epochs = 100
total_series_length = 50000      # total length of the generated sequence
truncated_backprop_length = 15   # truncated-BPTT window length
state_size = 4                   # size of the hidden state
num_classes = 2
echo_step = 3                    # how far the label sequence is shifted
batch_size = 5
# total_series_length // batch_size == 10000: each of the batch_size rows holds 10000 steps
# num_batches: number of 15-step windows needed to consume one row
num_batches = total_series_length // batch_size // truncated_backprop_length
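Concretely, 50000 // 5 == 10000 steps per row, and 10000 // 15 == 666 windows per epoch; the final 10000 % 15 == 10 steps of each row are simply dropped.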

The data-generating function

def generateData():
    '''
    Randomly generate a 0/1 sequence x of length total_series_length;
    circularly shift x right by echo_step steps and zero the wrapped-in
    positions to produce y.
    '''
    # sample 0 and 1 with probabilities 0.5/0.5, total_series_length times
    x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5]))
    y = np.roll(x, echo_step)  # circularly shift x right by echo_step steps
    y[0:echo_step] = 0

    x = x.reshape((batch_size, -1))
    y = y.reshape((batch_size, -1))

    return x, y
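A quick shape check (a sketch, assuming the parameters above):

x, y = generateData()
print(x.shape, y.shape)  # (5, 10000) (5, 10000)

Note that the roll and the zeroing happen on the flat 50000-step sequence before reshaping, so only the first echo_step entries of row 0 are zeroed; the first echo_step entries of every other row echo the tail of the previous row.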

Placeholders and Variables

X = tf.placeholder(tf.float32, [batch_size, truncated_backprop_length])  # input
Y = tf.placeholder(tf.int64, [batch_size, truncated_backprop_length])    # labels
init_state = tf.placeholder(tf.float32, [batch_size, state_size])        # initial state

# Input-side weights and bias.
# The input at each step has size 1 and the hidden state has size state_size;
# W1 and b1 map the concatenated [input, state] vector to the next cell's
# state, so W1.shape == (state_size + 1, state_size)
W1 = tf.Variable(np.random.rand(state_size + 1, state_size), dtype=tf.float32)
b1 = tf.Variable(np.random.rand(1, state_size), dtype=tf.float32)
# Output-side weights and bias
W2 = tf.Variable(np.random.rand(state_size, num_classes), dtype=tf.float32)
b2 = tf.Variable(np.random.rand(1, num_classes), dtype=tf.float32)

# Split X and Y along axis=1 into per-time-step tensors:
# len(inputs_series) == truncated_backprop_length, and each element has shape (batch_size,)
inputs_series = tf.unstack(X, axis=1)
labels_series = tf.unstack(Y, axis=1)
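To see what tf.unstack does here, consider a toy tensor (hypothetical shapes, for illustration only):

t = tf.constant([[1., 2., 3.],
                 [4., 5., 6.]])  # shape (2, 3)
cols = tf.unstack(t, axis=1)     # a list of 3 tensors, each of shape (2,)
# cols[0] == [1., 4.], cols[1] == [2., 5.], cols[2] == [3., 6.]

The unstacked dimension is removed, which is why each element of inputs_series has shape (batch_size,) and must be reshaped to (batch_size, 1) in the forward pass below.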

Forward pass

current_state = init_state
state_series = []  # stores the hidden state of every cell

# compute all hidden states
for current_input in inputs_series:
    # reshape from (batch_size,) to (batch_size, 1)
    current_input = tf.reshape(current_input, [batch_size, 1])
    # input_and_state_concatenated.shape == (batch_size, state_size + 1)
    input_and_state_concatenated = tf.concat([current_input, current_state], axis=1)

    next_state = tf.tanh(tf.matmul(input_and_state_concatenated, W1) + b1)
    state_series.append(next_state)
    current_state = next_state
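The recurrence above, next_state = tanh([input, state] · W1 + b1), is the same update that TF 1.x's built-in basic RNN cell computes. As a sketch (assuming the tf.nn.rnn_cell / tf.nn.static_rnn API), the hand-written loop could be replaced by:

# Equivalent built-in cell: concatenates input and state and applies one
# tanh layer -- the same computation as the manual loop above
cell = tf.nn.rnn_cell.BasicRNNCell(state_size)
state_series, current_state = tf.nn.static_rnn(
    cell,
    [tf.reshape(i, [batch_size, 1]) for i in inputs_series],
    initial_state=init_state)

Writing the loop out by hand, as this post does, makes the state-passing explicit.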

Loss and backpropagation

# map every state in state_series to logits through W2, b2
logits_series = [tf.matmul(state, W2) + b2 for state in state_series]
# apply softmax to each cell's logits (for inspection; the loss takes raw logits)
predictions_series = [tf.nn.softmax(logit) for logit in logits_series]

# sparse_softmax_cross_entropy_with_logits expects raw logits and integer labels
losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
          for logits, labels in zip(logits_series, labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)
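Only logits_series feeds the loss; predictions_series is kept for inspection. If you also want a training accuracy, one possible addition (hypothetical, not in the original post) is:

# fraction of positions where argmax(softmax) matches the label
correct_series = [tf.equal(tf.argmax(p, axis=1), l)
                  for p, l in zip(predictions_series, labels_series)]
accuracy = tf.reduce_mean(tf.cast(tf.concat(correct_series, axis=0), tf.float32))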

Training

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loss_list = []

    for epoch_idx in range(num_epochs):
        x,y = generateData()
        _current_state = np.zeros((batch_size, state_size))

        print("New data, epoch ", epoch_idx)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * truncated_backprop_length
            end_idx = start_idx + truncated_backprop_length

            batchX = x[:,start_idx:end_idx]
            batchY = y[:,start_idx:end_idx]

            _total_loss, _train_step, _current_state, _predictions_series = sess.run(
                [total_loss, train_step, current_state, predictions_series],
                feed_dict={
                    X:batchX,
                    Y:batchY,
                    init_state:_current_state
                })
            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Loss", _total_loss)
New data, epoch  0
Step 0 Loss 0.6896054
Step 100 Loss 0.53661335
Step 200 Loss 0.48168364
Step 300 Loss 0.55712247
Step 400 Loss 0.42068735
Step 500 Loss 0.5304155
Step 600 Loss 0.38075384
......
New data, epoch  99
Step 0 Loss 0.23018573
Step 100 Loss 4.39578e-05
Step 200 Loss 3.580574e-05
Step 300 Loss 2.5321053e-05
Step 400 Loss 3.820721e-05
Step 500 Loss 3.697888e-05
Step 600 Loss 4.187376e-05
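matplotlib is imported at the top but never used in the original code; a natural final step (my addition, not shown in the post) is to plot the recorded losses:

# visualize the training curve collected in loss_list
plt.plot(loss_list)
plt.xlabel('step')
plt.ylabel('total_loss')
plt.show()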