当前位置:   article > 正文

使用自己的数据集训练MobileNet、ResNet实现图像分类(TensorFlow)_shufflenet怎么进行分类

shufflenet怎么进行分类

使用自己的数据集训练MobileNet、ResNet实现图像分类(TensorFlow)

之前鄙人写了一篇博客《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)https://panjinquan.blog.csdn.net/article/details/81560537,本博客就是此博客的框架基础上,完成对MobileNet的图像分类模型的训练,其相关项目的代码也会统一更新到一个Github中,强烈建议先看这篇博客《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》后,再来看这篇博客。

TensorFlow官网中使用高级API -slim实现了很多常用的模型,如VGG,GoogLenet V1、V2和V3以及MobileNet、resnet模型:可详看这里https://github.com/tensorflow/models/tree/master/research/slim,当然TensorFlow官网也提供了训练这些模型的脚本文件,但灵活性太差了,要想增加log或者其他信息,真的很麻烦。本人花了很多时间,去搭建一个较为通用的模型训练框架《tensorflow_models_nets》,目前几乎可以支持所有模型的训练,由于训练过程是自己构建的,所以你可以在此基础上进行任意的修改,也可以搭建自己的训练模型。

重要说明

(1)项目Github源码:https://github.com/PanJinquan/tensorflow_models_learning麻烦给个“Star”

(2)你需要一台显卡不错的服务器,不然会卡的一比,慢到花都谢了

(2)对于MobileNet、resnet等大型的网络模型,重头开始训练,是很难收敛的。但迁移学习finetune部分我还没有实现,大神要是现实了,分享一下哈。

(3)注意训练mobilenet时,在迭代10000次以前,loss和准确率几乎不会提高。一开始我以为是训练代码写错了,后来寻思了很久,才发现是模型太复杂了,所以收敛慢的一比,大概20000次迭代后,准确率才开始蹭蹭的往上长,迭代十万次后准确率才70%,


目录

使用自己的数据集训练MobileNet图像识别(TensorFlow)

1、项目文件结构说明 

2、MobileNet的网络:

3、图片数据集

4、制作tfrecords数据格式

5、MobileNet模型

6、训练方法实现过程

7、模型预测

8、其他模型训练方法 


1、项目文件结构说明 

tensorflow_models_nets:

|__dataset   #数据文件

    |__record #里面存放record文件

    |__train    #train原始图片

    |__val      #val原始图片

|__models  #保存训练的模型

|__slim        #这个是拷贝自slim模块:https://github.com/tensorflow/models/tree/master/research/slim

|__test_image #存放测试的图片

|__create_labels_files.py #制作trian和val TXT的文件

|__create_tf_record.py #制作tfrecord文件

|__inception_v1_train_val.py #inception V1的训练文件

|__inception_v3_train_val.py # inception V3训练文件

|__mobilenet_train_val.py#mobilenet训练文件

|__resnet_v1_train_val.py#resnet训练文件

|__predict.py # 模型预测文件

2、MobileNet的网络:

关于MobileNet模型,请详看这篇博客《轻量级网络--MobileNet论文解读》https://blog.csdn.net/u011974639/article/details/79199306 ,本博客不会纠结于模型原理和论文,主要分享的是用自己的数据集去训练MobileNet的方法。

3、图片数据集

下面是我下载的数据集,共有五类图片,分别是:flower、guitar、animal、houses和plane,每组数据集大概有800张左右。为了照顾网友,下面的数据集,都已经放在Github项目的文件夹dataset上了,不需要你下载了,记得给个“star”哈

animal:http://www.robots.ox.ac.uk/~vgg/data/pets/ 
flower:http://www.robots.ox.ac.uk/~vgg/data/flowers/ 
plane:http://www.robots.ox.ac.uk/~vgg/data/airplanes_side/airplanes_side.tar 
house:http://www.robots.ox.ac.uk/~vgg/data/houses/houses.tar 
guitar:http://www.robots.ox.ac.uk/~vgg/data/guitars/guitars.tar 

    下载图片数据集后,需要划分为train和val数据集,前者用于训练模型的数据,后者主要用于验证模型。这里提供一个create_labels_files.py脚本,可以直接生成训练train和验证val的数据集txt文件。

  1. #-*-coding:utf-8-*-
  2. """
  3. @Project: googlenet_classification
  4. @File : create_labels_files.py
  5. @Author : panjq
  6. @E-mail : pan_jinquan@163.com
  7. @Date : 2018-08-11 10:15:28
  8. """
  9. import os
  10. import os.path
  11. def write_txt(content, filename, mode='w'):
  12. """保存txt数据
  13. :param content:需要保存的数据,type->list
  14. :param filename:文件名
  15. :param mode:读写模式:'w' or 'a'
  16. :return: void
  17. """
  18. with open(filename, mode) as f:
  19. for line in content:
  20. str_line = ""
  21. for col, data in enumerate(line):
  22. if not col == len(line) - 1:
  23. # 以空格作为分隔符
  24. str_line = str_line + str(data) + " "
  25. else:
  26. # 每行最后一个数据用换行符“\n”
  27. str_line = str_line + str(data) + "\n"
  28. f.write(str_line)
  29. def get_files_list(dir):
  30. '''
  31. 实现遍历dir目录下,所有文件(包含子文件夹的文件)
  32. :param dir:指定文件夹目录
  33. :return:包含所有文件的列表->list
  34. '''
  35. # parent:父目录, filenames:该目录下所有文件夹,filenames:该目录下的文件名
  36. files_list = []
  37. for parent, dirnames, filenames in os.walk(dir):
  38. for filename in filenames:
  39. # print("parent is: " + parent)
  40. # print("filename is: " + filename)
  41. # print(os.path.join(parent, filename)) # 输出rootdir路径下所有文件(包含子文件)信息
  42. curr_file=parent.split(os.sep)[-1]
  43. if curr_file=='flower':
  44. labels=0
  45. elif curr_file=='guitar':
  46. labels=1
  47. elif curr_file=='animal':
  48. labels=2
  49. elif curr_file=='houses':
  50. labels=3
  51. elif curr_file=='plane':
  52. labels=4
  53. files_list.append([os.path.join(curr_file, filename),labels])
  54. return files_list
  55. if __name__ == '__main__':
  56. train_dir = 'dataset/train'
  57. train_txt='dataset/train.txt'
  58. train_data = get_files_list(train_dir)
  59. write_txt(train_data,train_txt,mode='w')
  60. val_dir = 'dataset/val'
  61. val_txt='dataset/val.txt'
  62. val_data = get_files_list(val_dir)
  63. write_txt(val_data,val_txt,mode='w')

注意,上面Python代码,已经定义每组图片对应的标签labels:

  1. flower ->labels=0
  2. guitar ->labels=1
  3. animal ->labels=2
  4. houses ->labels=3
  5. plane ->labels=4

4、制作tfrecords数据格式

有了 train.txt和val.txt数据集,我们就可以制作train.tfrecords和val.tfrecords文件了,项目提供一个用于制作tfrecords数据格式的Python文件:create_tf_record.py,鄙人已经把代码放在另一篇博客:Tensorflow生成自己的图片数据集TFrecordshttps://blog.csdn.net/guyuealian/article/details/80857228 ,代码有详细注释了,所以这里不贴出来了.

注意:

(1)create_tf_record.py将train和val数据分别保存为单个record文件,当图片数据很多时候,会导致单个record文件超级巨大的情况,解决方法就是,将数据分成多个record文件保存,读取时,只需要将多个record文件的路径列表交给“tf.train.string_input_producer”即可。

(2)如何将数据保存为多个record文件呢?请参考鄙人的博客:《Tensorflow生成自己的图片数据集TFrecords》https://blog.csdn.net/guyuealian/article/details/80857228

为了方便大家,项目以及适配了“create_tf_record.py”文件,dataset已经包含了训练和测试的图片,请直接运行create_tf_record.py即可生成tfrecords文件。

对于InceptionNet V1:设置resize_height和resize_width = 224 
对于InceptionNet V3:设置resize_height和resize_width = 299 
其他模型,请根据输入需要设置resize_height和resize_width的大小

  1. if __name__ == '__main__':
  2. # 参数设置
  3. resize_height = 224 # 指定存储图片高度
  4. resize_width = 224 # 指定存储图片宽度
  5. shuffle=True
  6. log=5
  7. # 产生train.record文件
  8. image_dir='dataset/train'
  9. train_labels = 'dataset/train.txt' # 图片路径
  10. train_record_output = 'dataset/record/train{}.tfrecords'.format(resize_height)
  11. create_records(image_dir,train_labels, train_record_output, resize_height, resize_width,shuffle,log)
  12. train_nums=get_example_nums(train_record_output)
  13. print("save train example nums={}".format(train_nums))
  14. # 产生val.record文件
  15. image_dir='dataset/val'
  16. val_labels = 'dataset/val.txt' # 图片路径
  17. val_record_output = 'dataset/record/val{}.tfrecords'.format(resize_height)
  18. create_records(image_dir,val_labels, val_record_output, resize_height, resize_width,shuffle,log)
  19. val_nums=get_example_nums(val_record_output)
  20. print("save val example nums={}".format(val_nums))
  21. # 测试显示函数
  22. # disp_records(train_record_output,resize_height, resize_width)
  23. batch_test(train_record_output,resize_height, resize_width)

  create_tf_record.py提供几个重要的函数:

  1. create_records():用于制作records数据的函数,
  2. read_records():用于读取records数据的函数,
  3. get_batch_images():用于生成批训练数据的函数
  4. get_example_nums:统计tf_records图像的个数(example个数)
  5. disp_records(): 解析record文件,并显示图片,主要用于验证生成record文件是否成功

5、MobileNet模型

  官网TensorFlow已经提供了使用TF-slim实现的MobileNet模型。

1、官网模型地址:https://github.com/tensorflow/models/tree/master/research/slim/nets

2、slim/nets下的模型都是用TF-slim实现的网络结构,关系TF-slim的用法,可参考:

tensorflow中slim模块api介绍》:https://blog.csdn.net/guvcolie/article/details/77686555

6、训练方法实现过程

训练文件源码已经给了较为详细的注释,不明白请在评论区留言吧

  1. #coding=utf-8
  2. import tensorflow as tf
  3. import numpy as np
  4. import pdb
  5. import os
  6. from datetime import datetime
  7. import slim.nets.mobilenet_v1 as mobilenet_v1
  8. from create_tf_record import *
  9. import tensorflow.contrib.slim as slim
  10. '''
  11. 参考资料:https://www.cnblogs.com/adong7639/p/7942384.html
  12. '''
  13. labels_nums = 5 # 类别个数
  14. batch_size = 16 #
  15. resize_height = 224 # mobilenet_v1.default_image_size 指定存储图片高度
  16. resize_width = 224 # mobilenet_v1.default_image_size 指定存储图片宽度
  17. depths = 3
  18. data_shape = [batch_size, resize_height, resize_width, depths]
  19. # 定义input_images为图片数据
  20. input_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, depths], name='input')
  21. # 定义input_labels为labels数据
  22. # input_labels = tf.placeholder(dtype=tf.int32, shape=[None], name='label')
  23. input_labels = tf.placeholder(dtype=tf.int32, shape=[None, labels_nums], name='label')
  24. # 定义dropout的概率
  25. keep_prob = tf.placeholder(tf.float32,name='keep_prob')
  26. is_training = tf.placeholder(tf.bool, name='is_training')
  27. def net_evaluation(sess,loss,accuracy,val_images_batch,val_labels_batch,val_nums):
  28. val_max_steps = int(val_nums / batch_size)
  29. val_losses = []
  30. val_accs = []
  31. for _ in range(val_max_steps):
  32. val_x, val_y = sess.run([val_images_batch, val_labels_batch])
  33. # print('labels:',val_y)
  34. # val_loss = sess.run(loss, feed_dict={x: val_x, y: val_y, keep_prob: 1.0})
  35. # val_acc = sess.run(accuracy,feed_dict={x: val_x, y: val_y, keep_prob: 1.0})
  36. val_loss,val_acc = sess.run([loss,accuracy], feed_dict={input_images: val_x, input_labels: val_y, keep_prob:1.0, is_training: False})
  37. val_losses.append(val_loss)
  38. val_accs.append(val_acc)
  39. mean_loss = np.array(val_losses, dtype=np.float32).mean()
  40. mean_acc = np.array(val_accs, dtype=np.float32).mean()
  41. return mean_loss, mean_acc
  42. def step_train(train_op,loss,accuracy,
  43. train_images_batch,train_labels_batch,train_nums,train_log_step,
  44. val_images_batch,val_labels_batch,val_nums,val_log_step,
  45. snapshot_prefix,snapshot):
  46. '''
  47. 循环迭代训练过程
  48. :param train_op: 训练op
  49. :param loss: loss函数
  50. :param accuracy: 准确率函数
  51. :param train_images_batch: 训练images数据
  52. :param train_labels_batch: 训练labels数据
  53. :param train_nums: 总训练数据
  54. :param train_log_step: 训练log显示间隔
  55. :param val_images_batch: 验证images数据
  56. :param val_labels_batch: 验证labels数据
  57. :param val_nums: 总验证数据
  58. :param val_log_step: 验证log显示间隔
  59. :param snapshot_prefix: 模型保存的路径
  60. :param snapshot: 模型保存间隔
  61. :return: None
  62. '''
  63. saver = tf.train.Saver(max_to_keep=5)
  64. max_acc = 0.0
  65. with tf.Session() as sess:
  66. sess.run(tf.global_variables_initializer())
  67. sess.run(tf.local_variables_initializer())
  68. coord = tf.train.Coordinator()
  69. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  70. for i in range(max_steps + 1):
  71. batch_input_images, batch_input_labels = sess.run([train_images_batch, train_labels_batch])
  72. _, train_loss = sess.run([train_op, loss], feed_dict={input_images: batch_input_images,
  73. input_labels: batch_input_labels,
  74. keep_prob: 0.8, is_training: True})
  75. # train测试(这里仅测试训练集的一个batch)
  76. if i % train_log_step == 0:
  77. train_acc = sess.run(accuracy, feed_dict={input_images: batch_input_images,
  78. input_labels: batch_input_labels,
  79. keep_prob: 1.0, is_training: False})
  80. print("%s: Step [%d] train Loss : %f, training accuracy : %g" % (
  81. datetime.now(), i, train_loss, train_acc))
  82. # val测试(测试全部val数据)
  83. if i % val_log_step == 0:
  84. mean_loss, mean_acc = net_evaluation(sess, loss, accuracy, val_images_batch, val_labels_batch, val_nums)
  85. print("%s: Step [%d] val Loss : %f, val accuracy : %g" % (datetime.now(), i, mean_loss, mean_acc))
  86. # 模型保存:每迭代snapshot次或者最后一次保存模型
  87. if (i % snapshot == 0 and i > 0) or i == max_steps:
  88. print('-----save:{}-{}'.format(snapshot_prefix, i))
  89. saver.save(sess, snapshot_prefix, global_step=i)
  90. # 保存val准确率最高的模型
  91. if mean_acc > max_acc and mean_acc > 0.7:
  92. max_acc = mean_acc
  93. path = os.path.dirname(snapshot_prefix)
  94. best_models = os.path.join(path, 'best_models_{}_{:.4f}.ckpt'.format(i, max_acc))
  95. print('------save:{}'.format(best_models))
  96. saver.save(sess, best_models)
  97. coord.request_stop()
  98. coord.join(threads)
  99. def train(train_record_file,
  100. train_log_step,
  101. train_param,
  102. val_record_file,
  103. val_log_step,
  104. labels_nums,
  105. data_shape,
  106. snapshot,
  107. snapshot_prefix):
  108. '''
  109. :param train_record_file: 训练的tfrecord文件
  110. :param train_log_step: 显示训练过程log信息间隔
  111. :param train_param: train参数
  112. :param val_record_file: 验证的tfrecord文件
  113. :param val_log_step: 显示验证过程log信息间隔
  114. :param val_param: val参数
  115. :param labels_nums: labels数
  116. :param data_shape: 输入数据shape
  117. :param snapshot: 保存模型间隔
  118. :param snapshot_prefix: 保存模型文件的前缀名
  119. :return:
  120. '''
  121. [base_lr,max_steps]=train_param
  122. [batch_size,resize_height,resize_width,depths]=data_shape
  123. # 获得训练和测试的样本数
  124. train_nums=get_example_nums(train_record_file)
  125. val_nums=get_example_nums(val_record_file)
  126. print('train nums:%d,val nums:%d'%(train_nums,val_nums))
  127. # 从record中读取图片和labels数据
  128. # train数据,训练数据一般要求打乱顺序shuffle=True
  129. train_images, train_labels = read_records(train_record_file, resize_height, resize_width, type='normalization')
  130. train_images_batch, train_labels_batch = get_batch_images(train_images, train_labels,
  131. batch_size=batch_size, labels_nums=labels_nums,
  132. one_hot=True, shuffle=True)
  133. # val数据,验证数据可以不需要打乱数据
  134. val_images, val_labels = read_records(val_record_file, resize_height, resize_width, type='normalization')
  135. val_images_batch, val_labels_batch = get_batch_images(val_images, val_labels,
  136. batch_size=batch_size, labels_nums=labels_nums,
  137. one_hot=True, shuffle=False)
  138. # Define the model:
  139. with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
  140. out, end_points = mobilenet_v1.mobilenet_v1(inputs=input_images, num_classes=labels_nums,
  141. dropout_keep_prob=keep_prob, is_training=is_training,
  142. global_pool=True)
  143. # Specify the loss function: tf.losses定义的loss函数都会自动添加到loss函数,不需要add_loss()了
  144. tf.losses.softmax_cross_entropy(onehot_labels=input_labels, logits=out) # 添加交叉熵损失loss=1.6
  145. # slim.losses.add_loss(my_loss)
  146. loss = tf.losses.get_total_loss(add_regularization_losses=True) # 添加正则化损失loss=2.2
  147. # Specify the optimization scheme:
  148. # 在定义训练的时候, 注意到我们使用了`batch_norm`层时,需要更新每一层的`average`和`variance`参数,
  149. # 更新的过程不包含在正常的训练过程中, 需要我们去手动像下面这样更新
  150. # 通过`tf.get_collection`获得所有需要更新的`op`
  151. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  152. # 使用`tensorflow`的控制流, 先执行更新算子, 再执行训练
  153. with tf.control_dependencies(update_ops):
  154. print("update_ops:{}".format(update_ops))
  155. # create_train_op that ensures that when we evaluate it to get the loss,
  156. # the update_ops are done and the gradient updates are computed.
  157. # train_op = tf.train.MomentumOptimizer(learning_rate=base_lr, momentum=0.9).minimize(loss)
  158. train_op = tf.train.AdadeltaOptimizer(learning_rate=base_lr).minimize(loss)
  159. accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), tf.argmax(input_labels, 1)), tf.float32))
  160. # 循环迭代过程
  161. step_train(train_op=train_op, loss=loss, accuracy=accuracy,
  162. train_images_batch=train_images_batch,
  163. train_labels_batch=train_labels_batch,
  164. train_nums=train_nums,
  165. train_log_step=train_log_step,
  166. val_images_batch=val_images_batch,
  167. val_labels_batch=val_labels_batch,
  168. val_nums=val_nums,
  169. val_log_step=val_log_step,
  170. snapshot_prefix=snapshot_prefix,
  171. snapshot=snapshot)
  172. if __name__ == '__main__':
  173. train_record_file='dataset/record/train224.tfrecords'
  174. val_record_file='dataset/record/val224.tfrecords'
  175. train_log_step=100
  176. base_lr = 0.001 # 学习率
  177. # 重头开始训练的话,mobilenet收敛慢的一比,大概20000次迭代后,准确率开始蹭蹭的往上长,迭代十万次后准确率才70%
  178. max_steps = 100000 # 迭代次数
  179. train_param=[base_lr,max_steps]
  180. val_log_step=500
  181. snapshot=2000#保存文件间隔
  182. snapshot_prefix='models/model.ckpt'
  183. train(train_record_file=train_record_file,
  184. train_log_step=train_log_step,
  185. train_param=train_param,
  186. val_record_file=val_record_file,
  187. val_log_step=val_log_step,
  188. labels_nums=labels_nums,
  189. data_shape=data_shape,
  190. snapshot=snapshot,
  191. snapshot_prefix=snapshot_prefix)

7、模型预测

模型预测,项目只提供一个predict.py,实质上,你只需要稍微改改,就可以预测其他模型

  1. #coding=utf-8
  2. import tensorflow as tf
  3. import numpy as np
  4. import pdb
  5. import cv2
  6. import os
  7. import glob
  8. import slim.nets.inception_v3 as inception_v3
  9. from create_tf_record import *
  10. import tensorflow.contrib.slim as slim
  11. def predict(models_path,image_dir,labels_filename,labels_nums, data_format):
  12. [batch_size, resize_height, resize_width, depths] = data_format
  13. labels = np.loadtxt(labels_filename, str, delimiter='\t')
  14. input_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, depths], name='input')
  15. #其他模型预测请修改这里
  16. with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
  17. out, end_points = inception_v3.inception_v3(inputs=input_images, num_classes=labels_nums, dropout_keep_prob=1.0, is_training=False)
  18. # 将输出结果进行softmax分布,再求最大概率所属类别
  19. score = tf.nn.softmax(out,name='pre')
  20. class_id = tf.argmax(score, 1)
  21. sess = tf.InteractiveSession()
  22. sess.run(tf.global_variables_initializer())
  23. saver = tf.train.Saver()
  24. saver.restore(sess, models_path)
  25. images_list=glob.glob(os.path.join(image_dir,'*.jpg'))
  26. for image_path in images_list:
  27. im=read_image(image_path,resize_height,resize_width,normalization=True)
  28. im=im[np.newaxis,:]
  29. #pred = sess.run(f_cls, feed_dict={x:im, keep_prob:1.0})
  30. pre_score,pre_label = sess.run([score,class_id], feed_dict={input_images:im})
  31. max_score=pre_score[0,pre_label]
  32. print("{} is: pre labels:{},name:{} score: {}".format(image_path,pre_label,labels[pre_label], max_score))
  33. sess.close()
  34. if __name__ == '__main__':
  35. class_nums=5
  36. image_dir='test_image'
  37. labels_filename='dataset/label.txt'
  38. models_path='models/model.ckpt-10000'
  39. batch_size = 1 #
  40. resize_height = 299 # 指定存储图片高度
  41. resize_width = 299 # 指定存储图片宽度
  42. depths=3
  43. data_format=[batch_size,resize_height,resize_width,depths]
  44. predict(models_path,image_dir, labels_filename, class_nums, data_format)

8、其他模型训练方法 

    上面的程序是训练MobileNet的完整过程,实质上,稍微改改就可以支持训练 inception V1,V2和resnet 啦,改动方法也很简单,以 MobileNe训练代码改为resnet_v1模型为例:

(1)import 改为:

  1. # 将
  2. import slim.nets.mobilenet_v1 as mobilenet_v1
  3. # 改为
  4. import slim.nets.resnet_v1 as resnet_v1

(2)record数据

 制作record数据时,需要根据模型输入设置:

resize_height = 224  # 指定存储图片高度
resize_width = 224  # 指定存储图片宽度

(3)定义模型和默认参数修改:

  1. # 将
  2. # Define the model:
  3. with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
  4. out, end_points = mobilenet_v1.mobilenet_v1(inputs=input_images, num_classes=labels_nums,
  5. dropout_keep_prob=keep_prob, is_training=is_training,
  6. global_pool=True)
  7. # 改为
  8. # Define the model:
  9. with slim.arg_scope(resnet_v1.resnet_arg_scope()):
  10. out, end_points = resnet_v1.resnet_v1_101(inputs=input_images, num_classes=labels_nums, is_training=is_training,global_pool=True)

(4)修改优化方案

对于大型的网络模型,重头开始训练,是很难收敛的。训练mobilenet时,在迭代10000次以前,loss和准确率几乎不会提高。一开始我以为是训练代码写错了,后来寻思了很久,才发现是模型太复杂了,所以收敛慢的一比,大概20000次迭代后,准确率才开始蹭蹭的往上长,迭代十万次后准确率才70%,若训练过程发现不收敛,请尝试修改:

1、等!!!!至少你要迭代50000次,才能说你的模型不收敛!

2、增大或减小学习率参数:base_lr(个人经验:模型越深越复杂时,学习率越小)

3、改变优化方案:如使用MomentumOptimizer或者AdadeltaOptimizer等优化方法

4、是否有设置默认的模型参数:如slim.arg_scope(inception_v1.inception_v1_arg_scope())

……最后,就可以Train了!是的,就是那么简单~

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/92736
推荐阅读
相关标签
  

闽ICP备14008679号