当前位置:   article > 正文

DNN在TensorFlow框架的实现_基于tensorflow的dnn模型结构

基于tensorflow的dnn模型结构

一、使用tensorflow表达一个batch的数据

tensorflow提供placeholder机制用于提供输入数据,其相当于一种占位符如%d这种,这个位置的数据在程序运行时再指定,这样在程序中就不需要生成大量常量来提供输入数据
placeholder定义时这个位置上的数据类型是需要指定的
例子:

x = tf.placeholder(tf.float32,shape(3,2))
  • 1

feed_dict来指定x的取值,它是一个字典

feed_dict = {x: [[7,9],[1,2],[3,4]]}
  • 1

二、设计以及优化神经网络

1、去线性化

2、损失函数

3、反向传播算法

4、避免过拟合的正则化

5、滑动平均模型在未知数据上更健壮

经典损失函数:

交叉熵(cross_entropy)刻画了输出向量和期望向量的概率分布距离

Softmax函数将神经网络输出变成一个概率分布,从而可以通过交叉熵来计算预测概率分布和真实答案的概率分布之间的距离

然后通过反向传播算法来调整神经网络参数的取值使得差距可以被缩小

以下代码定义了一个简单的损失函数

定义损失函数y_代表正确结果,y代表预测值

cross_entropy = -tf.reduce_mean(y_*tf.log(tf.clip_by_value(y,le-10,1.0)))
  • 1

定义反向传播算法来优化神经网络中的参数,学习率定义了每次参数移动的幅度,学习率 #的设置要靠经验

train_step=tf.train.Adamoptimizer(learning_rate).minimize(loss)
  • 1

神经网络训练过程模型结构

batch_size =n

每次选取一小部分数据作为当前训练数据来执行反向传播算法

x= tf.placeholder(tf.float32,shape=batch_size,2),name=’x-input’)

y_= tf.placeholder(tf.float32,shape=batch_size,1),name=’y-input’)’

定义神经网络结构和优化算法

loss=...

train_step=tf.train.Adamoptimizer(learning_rate).minimize(loss)

训练神经网络

with tf.Session() as sess:

初始化参数

...

迭代的更新参数

for I in range(STEPS):

准备batch_size个训练数据。一般将所有训练数据随即打乱之后再选取,可以得到更好的#优化效果

current_X,current_Y=...

sess.run(train_step,feed_dict={x: current_X ,y_= current_Y})
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

过拟合问题

过拟合指的是当一个模型过为复杂之后,它可以很好的“记忆”每一个训练数据中的随机噪音的部分而忽略了要去“学习”训练数据中的通用趋势

为了避免过拟合问题,一般采用正则化,其思想大概为在损失函数中加入刻画模型复杂程度指标,希望通过限制权重的大小,使得模型不能任意拟合数据中的随机噪音

变量的滑动平均模型

滑动平均模型会使神经网络在未知的数据上更加健壮

当初始化某个变量时,指定了trainable = False程序运行时不会计算这个变量的滑动平均值

三、变量管理

通过tf.variable_scope()函数来控制tf.get_variable()函数获取已经创建过的变量,使用不同的命名空间来隔离不同层的变量
可以让每一层中的变量命名只需考虑在当前层的作用,不必担心重命名问题

四、tensorflow模型持久化

通过tf.train.saver()来初始化tensorflow持久化类
保存模型的文件名末尾加上训练轮数
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)

tf.train.get_checkpoint_state()函数会通过checkpoint文件自动找到目录中最新的模型文件名
通过saver.restore()加载模型

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/697769
推荐阅读
相关标签
  

闽ICP备14008679号