
A ResNet Implementation in TensorFlow (91.5% Accuracy on CIFAR10)

In the previous post I rewrote the TensorFlow CNN implementation and reached roughly 85% accuracy on the CIFAR10 test set. That model used two convolutional layers followed by two fully connected layers.

To push the accuracy further we can adopt a more advanced architecture, and one of the best known is ResNet, the residual network. Kaiming He introduced it in the 2015 paper "Deep Residual Learning for Image Recognition". The motivating observation is that with plain networks the training error actually grows as more layers are stacked. ResNet therefore introduces residual connections: the output of an earlier layer is added directly onto the output of a later layer. In theory, adding this structure means a deeper network should never do worse than a shallower one: if the shallow part has already approached the optimum, the extra layers only need to learn an identity mapping (each layer's output equals its input), so the added depth does not increase the training error.
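In code, the core of the idea is just an element-wise add of the block input (the shortcut) onto whatever the stacked layers compute. A minimal sketch, not the full block used later in this post:

import tensorflow as tf

def residual_add(shortcut, fx):
    # shortcut: the block input x; fx: the output F(x) of the stacked conv layers.
    # The block output is relu(F(x) + x), so learning an identity mapping only
    # requires driving F(x) toward zero.
    return tf.nn.relu(tf.add(shortcut, fx))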

In 2016 Kaiming He published a follow-up paper, "Identity Mappings in Deep Residual Networks", which further refines the block structure. According to the paper, the revised "pre-activation" block propagates the residual signal better in both the forward and the backward pass and therefore optimizes better; the difference between the original and the revised ordering is sketched below.
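The essential difference is only the order of batch normalization, ReLU and convolution inside a block. A schematic sketch (not the code used below; the conv and batch-norm layers are passed in as callables purely to show the sequence of operations):

import tensorflow as tf

def block_v1(x, conv1, conv2, bn1, bn2):
    # Original (v1) block: conv -> BN -> ReLU -> conv -> BN, add the shortcut,
    # then apply the final ReLU after the addition.
    y = tf.nn.relu(bn1(conv1(x)))
    y = bn2(conv2(y))
    return tf.nn.relu(x + y)

def block_v2(x, conv1, conv2, bn1, bn2):
    # Pre-activation (v2) block: BN -> ReLU -> conv -> BN -> ReLU -> conv,
    # then add the shortcut with nothing applied after the addition,
    # so the identity path stays clean in both directions.
    y = conv1(tf.nn.relu(bn1(x)))
    y = conv2(tf.nn.relu(bn2(y)))
    return x + y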

The official TensorFlow models repository already ships a ResNet implementation; trained at a depth of 110 layers it reaches about 92% accuracy on the CIFAR10 test set. Unfortunately that code is rather hard to read: it wraps a lot of helper functionality, and following the logic means constantly jumping between files. So I rewrote it as well, building the model the way Kaiming He's papers describe. At 110 layers (with the 6n+2 layer count used below, n=18 gives 110 layers) it reaches about 91% accuracy, very close to the official model.

The code is split into two parts. The model-construction code lives in its own file (saved as resnet_model.py, which the training script later imports), shown below; its _resnet_block_v1 and _resnet_block_v2 functions implement the original and the pre-activation block structures described above:

import tensorflow as tf

def _resnet_block_v1(inputs, filters, stride, projection, stage, blockname, TRAINING):
    # defining name basis
    conv_name_base = 'res' + str(stage) + blockname + '_branch'
    bn_name_base = 'bn' + str(stage) + blockname + '_branch'

    with tf.name_scope("conv_block_stage" + str(stage)):
        if projection:
            shortcut = tf.layers.conv2d(inputs, filters, (1, 1),
                                        strides=(stride, stride),
                                        name=conv_name_base + '1',
                                        kernel_initializer=tf.contrib.layers.variance_scaling_initializer(),
                                        reuse=tf.AUTO_REUSE, padding='same',
                                        data_format='channels_first')
            shortcut = tf.layers.batch_normalization(shortcut, axis=1, name=bn_name_base + '1',
                                                     training=TRAINING, reuse=tf.AUTO_REUSE)
        else:
            shortcut = inputs

        outputs = tf.layers.conv2d(inputs, filters,
                                   kernel_size=(3, 3),
                                   strides=(stride, stride),
                                   kernel_initializer=tf.contrib.layers.variance_scaling_initializer(),
                                   name=conv_name_base + '2a', reuse=tf.AUTO_REUSE, padding='same',
                                   data_format='channels_first')
        outputs = tf.layers.batch_normalization(outputs, axis=1, name=bn_name_base + '2a',
                                                training=TRAINING, reuse=tf.AUTO_REUSE)
        outputs = tf.nn.relu(outputs)
        outputs = tf.layers.conv2d(outputs, filters,
                                   kernel_size=(3, 3),
                                   strides=(1, 1),
                                   kernel_initializer=tf.contrib.layers.variance_scaling_initializer(),
                                   name=conv_name_base + '2b', reuse=tf.AUTO_REUSE, padding='same',
                                   data_format='channels_first')
        outputs = tf.layers.batch_normalization(outputs, axis=1, name=bn_name_base + '2b',
                                                training=TRAINING, reuse=tf.AUTO_REUSE)
        outputs = tf.add(shortcut, outputs)
        outputs = tf.nn.relu(outputs)
    return outputs

def _resnet_block_v2(inputs, filters, stride, projection, stage, blockname, TRAINING):
    # defining name basis
    conv_name_base = 'res' + str(stage) + blockname + '_branch'
    bn_name_base = 'bn' + str(stage) + blockname + '_branch'

    with tf.name_scope("conv_block_stage" + str(stage)):
        shortcut = inputs
        outputs = tf.layers.batch_normalization(inputs, axis=1, name=bn_name_base + '2a',
                                                training=TRAINING, reuse=tf.AUTO_REUSE)
        outputs = tf.nn.relu(outputs)
        if projection:
            shortcut = tf.layers.conv2d(outputs, filters, (1, 1),
                                        strides=(stride, stride),
                                        name=conv_name_base + '1',
                                        kernel_initializer=tf.contrib.layers.variance_scaling_initializer(),
                                        reuse=tf.AUTO_REUSE, padding='same',
                                        data_format='channels_first')
            shortcut = tf.layers.batch_normalization(shortcut, axis=1, name=bn_name_base + '1',
                                                     training=TRAINING, reuse=tf.AUTO_REUSE)
        outputs = tf.layers.conv2d(outputs, filters,
                                   kernel_size=(3, 3),
                                   strides=(stride, stride),
                                   kernel_initializer=tf.contrib.layers.variance_scaling_initializer(),
                                   name=conv_name_base + '2a', reuse=tf.AUTO_REUSE, padding='same',
                                   data_format='channels_first')
        outputs = tf.layers.batch_normalization(outputs, axis=1, name=bn_name_base + '2b',
                                                training=TRAINING, reuse=tf.AUTO_REUSE)
        outputs = tf.nn.relu(outputs)
        outputs = tf.layers.conv2d(outputs, filters,
                                   kernel_size=(3, 3),
                                   strides=(1, 1),
                                   kernel_initializer=tf.contrib.layers.variance_scaling_initializer(),
                                   name=conv_name_base + '2b', reuse=tf.AUTO_REUSE, padding='same',
                                   data_format='channels_first')
        outputs = tf.add(shortcut, outputs)
    return outputs

def inference(images, training, filters, n, ver):
    """Construct the resnet model

    Args:
        images: [batch*channel*height*width]
        training: boolean
        filters: integer, the filters of the first resnet stage, the next stage will have filters*2
        n: integer, how many resnet blocks in each stage, the total layers number is 6n+2
        ver: integer, can be 1 or 2, for resnet v1 or v2
    Returns:
        Tensor, model inference output
    """
    #Layer1 is a 3*3 conv layer, input channels are 3, output channels are 16
    inputs = tf.layers.conv2d(images, filters=16, kernel_size=(3, 3), strides=(1, 1),
                              name='conv1', reuse=tf.AUTO_REUSE, padding='same',
                              data_format='channels_first')

    #no need to batch normal and activate for version 2 resnet.
    if ver == 1:
        inputs = tf.layers.batch_normalization(inputs, axis=1, name='bn_conv1',
                                               training=training, reuse=tf.AUTO_REUSE)
        inputs = tf.nn.relu(inputs)

    for stage in range(3):
        stage_filter = filters * (2 ** stage)
        for i in range(n):
            stride = 1
            projection = False
            if i == 0 and stage > 0:
                stride = 2
                projection = True
            if ver == 1:
                inputs = _resnet_block_v1(inputs, stage_filter, stride, projection,
                                          stage, blockname=str(i), TRAINING=training)
            else:
                inputs = _resnet_block_v2(inputs, stage_filter, stride, projection,
                                          stage, blockname=str(i), TRAINING=training)

    #only need for version 2 resnet.
    if ver == 2:
        inputs = tf.layers.batch_normalization(inputs, axis=1, name='pre_activation_final_norm',
                                               training=training, reuse=tf.AUTO_REUSE)
        inputs = tf.nn.relu(inputs)

    axes = [2, 3]
    inputs = tf.reduce_mean(inputs, axes, keep_dims=True)
    inputs = tf.identity(inputs, 'final_reduce_mean')
    inputs = tf.reshape(inputs, [-1, filters * (2 ** 2)])
    inputs = tf.layers.dense(inputs=inputs, units=10, name='dense1', reuse=tf.AUTO_REUSE)
    return inputs
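As a quick illustration of how this file is used (assuming it is saved as resnet_model.py, the module name imported by the training script below), building a 110-layer pre-activation network looks like this:

import tensorflow as tf
import resnet_model

# Channels-first float input: a batch of CIFAR10 images, 3 x 32 x 32.
images = tf.placeholder(tf.float32, [None, 3, 32, 32])
# filters=16 is the width of the first stage; n=18 gives 6*18+2 = 110 layers;
# ver=2 selects the pre-activation blocks.
logits = resnet_model.inference(images, training=True, filters=16, n=18, ver=2)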

The other part of the code handles the CIFAR10 data. Of the 50,000 training images, 45,000 are used as the training set and the remaining 5,000 as the validation set, and all 10,000 test images form the test set. At a depth of 98 layers the test accuracy reaches roughly 92%.
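Each record in the CIFAR10 binary files is 1 label byte followed by 32*32*3 = 3072 image bytes, stored as the red, green and blue planes in that order; this is what the slicing and reshaping in the listing below relies on. A small offline sanity check (a sketch, assuming the files sit in cifar-10-batches-bin/):

import numpy as np

record_bytes = 1 + 32 * 32 * 3   # 3073 bytes per record
with open('cifar-10-batches-bin/data_batch_1.bin', 'rb') as f:
    raw = np.frombuffer(f.read(record_bytes), dtype=np.uint8)
label = int(raw[0])                 # class id in [0, 9]
image = raw[1:].reshape(3, 32, 32)  # channels-first, matching the reshape in the listing
print label, image.shape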

import tensorflow as tf
import numpy as np
import os
import resnet_model

#Construct the filenames that include the train cifar10 images
folderPath = 'cifar-10-batches-bin/'
filenames = [os.path.join(folderPath, 'data_batch_%d.bin' % i) for i in xrange(1, 6)]

#Define the parameters of the cifar10 image
imageWidth = 32
imageHeight = 32
imageDepth = 3
label_bytes = 1

#Define the train and test batch size
batch_size = 100
test_batch_size = 100
valid_batch_size = 100

#Calulate the per image bytes and record bytes
image_bytes = imageWidth * imageHeight * imageDepth
record_bytes = label_bytes + image_bytes

#Construct the dataset to read the train images
dataset = tf.data.FixedLengthRecordDataset(filenames, record_bytes)
dataset = dataset.shuffle(50000)

#Get the first 45000 records as train dataset records
train_dataset = dataset.take(45000)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.repeat(300)
iterator = train_dataset.make_initializable_iterator()

#Get the remain 5000 records as valid dataset records
valid_dataset = dataset.skip(45000)
valid_dataset = valid_dataset.batch(valid_batch_size)
validiterator = valid_dataset.make_initializable_iterator()

#Construct the dataset to read the test images
testfilename = os.path.join(folderPath, 'test_batch.bin')
testdataset = tf.data.FixedLengthRecordDataset(testfilename, record_bytes)
testdataset = testdataset.batch(test_batch_size)
testiterator = testdataset.make_initializable_iterator()

#Decode the train records from the iterator
record = iterator.get_next()
record_decoded_bytes = tf.decode_raw(record, tf.uint8)

#Get the labels from the records
record_labels = tf.slice(record_decoded_bytes, [0, 0], [batch_size, 1])
record_labels = tf.cast(record_labels, tf.int32)

#Get the images from the records
record_images = tf.slice(record_decoded_bytes, [0, 1], [batch_size, image_bytes])
record_images = tf.reshape(record_images, [batch_size, imageDepth, imageHeight, imageWidth])
record_images = tf.transpose(record_images, [0, 2, 3, 1])
record_images = tf.cast(record_images, tf.float32)

#Decode the records from the valid iterator
validrecord = validiterator.get_next()
validrecord_decoded_bytes = tf.decode_raw(validrecord, tf.uint8)

#Get the labels from the records
validrecord_labels = tf.slice(validrecord_decoded_bytes, [0, 0], [valid_batch_size, 1])
validrecord_labels = tf.cast(validrecord_labels, tf.int32)
validrecord_labels = tf.reshape(validrecord_labels, [-1])

#Get the images from the records
validrecord_images = tf.slice(validrecord_decoded_bytes, [0, 1], [valid_batch_size, image_bytes])
validrecord_images = tf.cast(validrecord_images, tf.float32)
validrecord_images = tf.reshape(validrecord_images,
                                [valid_batch_size, imageDepth, imageHeight, imageWidth])
validrecord_images = tf.transpose(validrecord_images, [0, 2, 3, 1])

#Decode the test records from the iterator
testrecord = testiterator.get_next()
testrecord_decoded_bytes = tf.decode_raw(testrecord, tf.uint8)

#Get the labels from the records
testrecord_labels = tf.slice(testrecord_decoded_bytes, [0, 0], [test_batch_size, 1])
testrecord_labels = tf.cast(testrecord_labels, tf.int32)
testrecord_labels = tf.reshape(testrecord_labels, [-1])

#Get the images from the records
testrecord_images = tf.slice(testrecord_decoded_bytes, [0, 1], [test_batch_size, image_bytes])
testrecord_images = tf.cast(testrecord_images, tf.float32)
testrecord_images = tf.reshape(testrecord_images,
                               [test_batch_size, imageDepth, imageHeight, imageWidth])
testrecord_images = tf.transpose(testrecord_images, [0, 2, 3, 1])

#Random crop the images after pad each side with 4 pixels
distorted_images = tf.image.resize_image_with_crop_or_pad(record_images,
                                                          imageHeight + 8, imageWidth + 8)
distorted_images = tf.random_crop(distorted_images, size=[batch_size, imageHeight, imageWidth, 3])

#Unstack the images as the follow up operations are on single train image
distorted_images = tf.unstack(distorted_images)
for i in xrange(len(distorted_images)):
    distorted_images[i] = tf.image.random_flip_left_right(distorted_images[i])
    distorted_images[i] = tf.image.random_brightness(distorted_images[i], max_delta=63)
    distorted_images[i] = tf.image.random_contrast(distorted_images[i], lower=0.2, upper=1.8)
    distorted_images[i] = tf.image.per_image_standardization(distorted_images[i])

#Stack the images
distorted_images = tf.stack(distorted_images)
#transpose to set the channel first
distorted_images = tf.transpose(distorted_images, perm=[0, 3, 1, 2])

#Unstack the images as the follow up operations are on single image
validrecord_images = tf.unstack(validrecord_images)
for i in xrange(len(validrecord_images)):
    validrecord_images[i] = tf.image.per_image_standardization(validrecord_images[i])

#Stack the images
validrecord_images = tf.stack(validrecord_images)
#transpose to set the channel first
validrecord_images = tf.transpose(validrecord_images, perm=[0, 3, 1, 2])

#Unstack the images as the follow up operations are on single image
testrecord_images = tf.unstack(testrecord_images)
for i in xrange(len(testrecord_images)):
    testrecord_images[i] = tf.image.per_image_standardization(testrecord_images[i])

#Stack the images
testrecord_images = tf.stack(testrecord_images)
#transpose to set the channel first
testrecord_images = tf.transpose(testrecord_images, perm=[0, 3, 1, 2])

global_step = tf.Variable(0, trainable=False)
boundaries = [10000, 15000, 20000, 25000]
values = [0.1, 0.05, 0.01, 0.005, 0.001]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
weight_decay = 2e-4
filters = 16  #the first resnet block filter number
n = 5         #the basic resnet block number, total network layers are 6n+2
ver = 2       #the resnet block version

#Get the inference logits by the model
result = resnet_model.inference(distorted_images, True, filters, n, ver)

#Calculate the cross entropy loss
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=record_labels, logits=result)
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')

#Add weight decay (l2 regularization of the weights) to the loss.
l2_loss = weight_decay * tf.add_n(
    # loss is computed using fp32 for numerical stability.
    [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()])
tf.summary.scalar('l2_loss', l2_loss)
loss = cross_entropy_mean + l2_loss

#Define the optimizer
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)

#Relate to the batch normalization
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    opt_op = optimizer.minimize(loss, global_step)

valid_accuracy = tf.placeholder(tf.float32)
test_accuracy = tf.placeholder(tf.float32)
tf.summary.scalar("valid_accuracy", valid_accuracy)
tf.summary.scalar("test_accuracy", test_accuracy)
tf.summary.scalar("learning_rate", learning_rate)

validresult = tf.argmax(resnet_model.inference(validrecord_images, False, filters, n, ver), axis=1)
testresult = tf.argmax(resnet_model.inference(testrecord_images, False, filters, n, ver), axis=1)

#Create the saver for the model checkpoints
saver = tf.train.Saver()

#Create the session and run the graph
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(iterator.initializer)

#Merge all the summary and write
summary_op = tf.summary.merge_all()
train_filewriter = tf.summary.FileWriter('train/', sess.graph)

step = 0
while(True):
    try:
        lossValue, lr, _ = sess.run([loss, learning_rate, opt_op])
        if step % 100 == 0:
            print "step %i: Learning_rate: %f Loss: %f" % (step, lr, lossValue)
        if step % 1000 == 0:
            saver.save(sess, 'model/my-model', global_step=step)
            truepredictNum = 0
            sess.run([testiterator.initializer, validiterator.initializer])
            accuracy1 = 0.0
            accuracy2 = 0.0
            while(True):
                try:
                    predictValue, testValue = sess.run([validresult, validrecord_labels])
                    truepredictNum += np.sum(predictValue == testValue)
                except tf.errors.OutOfRangeError:
                    print "valid correct num: %i" % (truepredictNum)
                    accuracy1 = truepredictNum / 5000.0
                    break
            truepredictNum = 0
            while(True):
                try:
                    predictValue, testValue = sess.run([testresult, testrecord_labels])
                    truepredictNum += np.sum(predictValue == testValue)
                except tf.errors.OutOfRangeError:
                    print "test correct num: %i" % (truepredictNum)
                    accuracy2 = truepredictNum / 10000.0
                    break
            summary = sess.run(summary_op, feed_dict={valid_accuracy: accuracy1, test_accuracy: accuracy2})
            train_filewriter.add_summary(summary, step)
        step += 1
    except tf.errors.OutOfRangeError:
        break
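The checkpoints written by saver.save above can later be restored for a standalone evaluation run. A minimal sketch, assuming the same graph-construction code from the script has already been executed (the step number in the checkpoint path is only an example):

#Restore a checkpoint and rerun the test-set evaluation
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'model/my-model-10000')  #hypothetical checkpoint step
    sess.run(testiterator.initializer)
    correct = 0
    while True:
        try:
            predictValue, testValue = sess.run([testresult, testrecord_labels])
            correct += np.sum(predictValue == testValue)
        except tf.errors.OutOfRangeError:
            break
    print "test accuracy: %f" % (correct / 10000.0)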

 
