当前位置:   article > 正文

【Tensorflow slim】slim learning包_slim.learning.train

slim.learning.train

TF-Slim在learning.py中的训练模型提供了一套简单但功能强大的工具。 这些功能包括一个训练函数可以反复测量损失,计算梯度并将模型保存到磁盘,以及用于操纵梯度的几个便利函数。 例如,一旦我们指定了模型,损失函数和优化方案,我们可以调用slim.learning.create_train_op和slim.learning.train来执行优化:

  1. g = tf.Graph()
  2. # Create the model and specify the losses...
  3. ...
  4. total_loss = slim.losses.get_total_loss()
  5. optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  6. # create_train_op ensures that each time we ask for the loss, the update_ops
  7. # are run and the gradients being computed are applied too.
  8. train_op = slim.learning.create_train_op(total_loss, optimizer)
  9. logdir = ... # Where checkpoints are stored.
  10. slim.learning.train(
  11. train_op,
  12. logdir,
  13. number_of_steps=1000,
  14. save_summaries_secs=300,
  15. save_interval_secs=600):

在这个例子中,slim.learning.train与train_op一起提供,用于(a)计算损失和(b)应用梯度步骤。 logdir指定检查点和事件文件的存储目录。 我们可以限制采取任何数字的梯度步数。 在这种情况下,我们要求采取1000个步骤。 最后,save_summaries_secs = 300表示我们将每隔5分钟计算摘要,save_interval_secs = 600表示我们将每10分钟保存一次模型检查点。

Working Example: Training the VGG16 Model

  1. import tensorflow as tf
  2. slim = tf.contrib.slim
  3. vgg = tf.contrib.slim.nets.vgg
  4. ...
  5. train_log_dir = ...
  6. if not tf.gfile.Exists(train_log_dir):
  7. tf.gfile.MakeDirs(train_log_dir)
  8. with tf.Graph().as_default():
  9. # Set up the data loading:
  10. images, labels = ...
  11. # Define the model:
  12. predictions = vgg.vgg_16(images, is_training=True)
  13. # Specify the loss function:
  14. slim.losses.softmax_cross_entropy(predictions, labels)
  15. total_loss = slim.losses.get_total_loss()
  16. tf.summary.scalar('losses/total_loss', total_loss)
  17. # Specify the optimization scheme:
  18. optimizer = tf.train.GradientDescentOptimizer(learning_rate=.001)
  19. # create_train_op that ensures that when we evaluate it to get the loss,
  20. # the update_ops are done and the gradient updates are computed.
  21. train_tensor = slim.learning.create_train_op(total_loss, optimizer)
  22. # Actually runs training.
  23. slim.learning.train(train_tensor, train_log_dir)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/880897
推荐阅读
相关标签
  

闽ICP备14008679号