当前位置:   article > 正文

用训练好的Alexnet模型测试迁移学习_基于迁移学习的原理,加载预训练模型alexnet的参数

基于迁移学习的原理,加载预训练模型alexnet的参数

最近用Tensorflow实现了一个Alexnet的模型,并在Imagenet的数据集上跑了一下训练,测试结果是能达到Top5接近71%的准确度。我想测试一下这个训练好的模型是否可以用于其他的图像分类的任务中,因此我选取了Tensorflow的迁移学习教程里面提到的Flowers分类任务的数据来做一个测试。

首先是下载Flowers的数据,具体可以参见Tensorflow里面的介绍。下载后的数据解压之后有5个文件夹,分别带有5种不同的花的图像。我编写了一个程序来把图像转换为TFRECORD格式的数据,方便后续的处理。代码如下:

  import os
  import cv2
  import tensorflow as tf
  import numpy as np


  def make_example(image, label):
      """Wrap one encoded image and its label into a tf.train.Example.

      Args:
          image: JPEG-encoded image bytes (stored under the 'image' key).
          label: the label serialized as raw bytes (stored under the 'label' key;
              the training script reads it back with tf.decode_raw).

      Returns:
          A tf.train.Example ready to be serialized into a TFRecord file.
      """
      return tf.train.Example(features=tf.train.Features(feature={
          'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
          'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
      }))


  # Sub-folder name (one folder per flower species) -> integer class id.
  flower_classes = {"daisy": 0, "dandelion": 1, "roses": 2, "sunflowers": 3, "tulips": 4}

  for flower_class, class_id in flower_classes.items():
      folder_path = flower_class
      # The reader decodes labels with tf.decode_raw(..., tf.int64), so the dtype
      # is pinned to int64: NumPy's default integer is only 32 bits on some
      # platforms (e.g. Windows), which would make the decoded label wrong.
      label = np.array([class_id], dtype=np.int64)
      # One TFRecord file per class, written to the current directory. The
      # context manager closes the writer even if an image fails mid-way.
      with tf.python_io.TFRecordWriter(flower_class + ".tfrecord") as writer:
          # Sorted so the record order is reproducible across runs/platforms.
          for jpgfile in sorted(os.listdir(folder_path)):
              image_path = os.path.join(folder_path, jpgfile)
              img = cv2.imread(image_path, cv2.IMREAD_COLOR)
              if img is None:
                  # cv2.imread returns None for files it cannot decode.
                  print("skipping unreadable file: %s" % image_path)
                  continue
              ok, img_jpg = cv2.imencode('.jpg', img)
              if not ok:
                  print("skipping file that could not be re-encoded: %s" % image_path)
                  continue
              ex = make_example(img_jpg.tobytes(), label.tobytes())
              writer.write(ex.SerializeToString())

程序执行完后会生成5个TFRECORD文件,每个文件对应一种花的图像数据。

之后我们可以利用之前训练好的Alexnet来进行迁移学习了。在上一篇博客中我已介绍了我的Alexnet的模型,我将增加一个新的全连接层来取代原有模型中的最后一层的全连接层,保留之前几层的训练好的参数,只用新的图像数据来训练新加的全连接层。为此,我需要把新的图像数据用原有的Alexnet模型计算后,把倒数第2层的计算结果输出,然后再用新加的全连接层进行计算。原有的Alexnet模型的这几层的参数都要设置为不可训练,只训练新加的层即可。以下是原有的Alexnet模型的代码,注意从第一层到倒数第2层(含倒数第2层)的所有参数都需要设置为trainable=False,模型代码如下:

  def _conv_layer(inputs, scope, kernel_shape, strides, pool_name=None):
      """Frozen convolution + ReLU (+ optional max-pool) in variable scope `scope`.

      Creates `<scope>/weights` (truncated-normal init, stddev 0.1) and
      `<scope>/biases` (constant 0.1), both with trainable=False. The ReLU output
      is named after the scope. When `pool_name` is given, a 3x3 / stride-2
      VALID max-pool is applied and its output returned instead.
      Local response normalization is intentionally not applied.
      """
      with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
          kernel_init = tf.truncated_normal(kernel_shape, dtype=tf.float32, stddev=1e-1)
          kernel = tf.get_variable(name='weights', initializer=kernel_init, trainable=False)
          conv = tf.nn.conv2d(inputs, kernel, strides, padding='SAME')
          bias_init = tf.constant(0.1, shape=[kernel_shape[-1]], dtype=tf.float32)
          biases = tf.get_variable(name='biases', initializer=bias_init, trainable=False)
          activation = tf.nn.relu(tf.nn.bias_add(conv, biases), name=scope)
          if pool_name is None:
              return activation
          return tf.nn.max_pool(activation, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                                padding='VALID', name=pool_name)


  def _fc_params(shape, stddev, trainable, wd):
      """Create `weights` / `biases` of a fully connected layer in the current scope.

      weights: [shape[0], shape[1]], truncated-normal init with `stddev`.
      biases: [shape[1]], constant 1.0.
      If `wd` is not None, wd * l2_loss(weights) is added to the 'losses'
      collection (created between the two variables, as before).
      """
      weights_init = tf.truncated_normal(shape, dtype=tf.float32, stddev=stddev)
      weights = tf.get_variable(name='weights', initializer=weights_init, trainable=trainable)
      if wd is not None:
          decay = tf.multiply(tf.nn.l2_loss(weights), wd, name='weight_loss')
          tf.add_to_collection('losses', decay)
      bias_init = tf.constant(1.0, shape=[shape[1]], dtype=tf.float32)
      biases = tf.get_variable(name='biases', initializer=bias_init, trainable=trainable)
      return weights, biases


  def inference(images, dropout_rate=1.0, wd=None):
      """AlexNet forward pass whose layers up to local2 are frozen (trainable=False).

      Args:
          images: NHWC float32 image batch. The flatten size 6*6*256 expected by
              local1 holds for 224x224 inputs (56 -> 27 -> 13 -> 6 through the
              strides/pools below).
          dropout_rate: passed positionally to tf.nn.dropout after local1 and
              local2; in TF 1.x that argument is keep_prob, so 1.0 disables dropout.
          wd: optional weight-decay factor for local1/local2 weights (see
              `_fc_params`); None adds no decay term.

      Returns:
          (local3, local2): the 1000-way logits of the original classifier layer
          (the only trainable variables here) and the 4096-d local2 features.
      """
      pool1 = _conv_layer(images, 'conv1', [11, 11, 3, 96], [1, 4, 4, 1], pool_name='pool1')
      pool2 = _conv_layer(pool1, 'conv2', [5, 5, 96, 256], [1, 1, 1, 1], pool_name='pool2')
      conv3 = _conv_layer(pool2, 'conv3', [3, 3, 256, 384], [1, 1, 1, 1])
      conv4 = _conv_layer(conv3, 'conv4', [3, 3, 384, 384], [1, 1, 1, 1])
      pool5 = _conv_layer(conv4, 'conv5', [3, 3, 384, 256], [1, 1, 1, 1], pool_name='pool5')
      flatten = tf.layers.flatten(inputs=pool5, name='flatten')

      with tf.variable_scope('local1', reuse=tf.AUTO_REUSE):
          w1, b1 = _fc_params([6*6*256, 4096], 1/4096.0, False, wd)
          hidden1 = tf.nn.relu(tf.nn.xw_plus_b(flatten, w1, b1), name='local1')
          hidden1 = tf.nn.dropout(hidden1, dropout_rate)

      with tf.variable_scope('local2', reuse=tf.AUTO_REUSE):
          w2, b2 = _fc_params([4096, 4096], 1/4096.0, False, wd)
          hidden2 = tf.nn.relu(tf.nn.xw_plus_b(hidden1, w2, b2), name='local2')
          hidden2 = tf.nn.dropout(hidden2, dropout_rate)

      with tf.variable_scope('local3', reuse=tf.AUTO_REUSE):
          # No weight decay on the classifier layer.
          w3, b3 = _fc_params([4096, 1000], 1e-3, True, None)
          logits = tf.nn.xw_plus_b(hidden2, w3, b3, name='local3')

      return logits, hidden2

然后我们就可以编写代码来读取之前训练的参数来进行新的训练了,代码如下:

  import tensorflow as tf
  import alexnet_model

  imageWidth = 224
  imageHeight = 224
  imageDepth = 3
  batch_size = 10
  resize_min = 256

  # Split sizes: the five flower TFRecord files hold 3670 images in total, split
  # into ~80% training (2920), ~10% validation (370) and ~10% test (the last 380).
  num_examples = 3670
  num_train = 2920
  num_valid = 370
  num_epochs = 50
  # Fixed seed so the one-off shuffle that defines the split is reproducible.
  split_seed = 2018


  def _parse_function(example_proto):
      """Decode one serialized Example into a center-cropped float image.

      The shorter side is resized to `resize_min` (keeping the aspect ratio),
      pixels are rescaled to float32 in [0, 1], and the central
      imageHeight x imageWidth patch is cut out.

      Returns:
          (image, label): image is float32 [imageHeight, imageWidth, 3]; label is
          the raw label bytes, decoded later with tf.decode_raw.
      """
      features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
                  "label": tf.FixedLenFeature((), tf.string, default_value="")}
      parsed_features = tf.parse_single_example(example_proto, features)
      image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
      shape = tf.shape(image_decoded)
      height, width = shape[0], shape[1]
      # Scale so that the shorter side becomes resize_min.
      resized_height, resized_width = tf.cond(
          height < width,
          lambda: (resize_min, tf.cast(tf.multiply(tf.cast(width, tf.float64), tf.divide(resize_min, height)), tf.int32)),
          lambda: (tf.cast(tf.multiply(tf.cast(height, tf.float64), tf.divide(resize_min, width)), tf.int32), resize_min))
      image_resized = tf.image.resize_images(image_decoded, [resized_height, resized_width])
      image_resized = tf.cast(image_resized, tf.uint8)
      image_resized = tf.image.convert_image_dtype(image_resized, tf.float32)
      # Center crop.
      shape = tf.shape(image_resized)
      height, width = shape[0], shape[1]
      crop_top = (height - imageHeight) // 2
      crop_left = (width - imageWidth) // 2
      image_valid = tf.slice(image_resized, [crop_top, crop_left, 0], [imageHeight, imageWidth, -1])
      return image_valid, parsed_features["label"]


  # The records are grouped by class (one file per class), so they must be
  # shuffled before splitting. Shuffle the *serialized* records exactly once:
  # with a fixed seed and reshuffle_each_iteration=False, every iterator built
  # from `records` sees the same order, so take()/skip() yield disjoint
  # train/validation/test sets. (With default reshuffling, each iterator
  # initialization draws a new permutation and the three sets overlap.)
  # Shuffling encoded records rather than decoded images also keeps the
  # 3670-element shuffle buffer small.
  filenames = ["flower_photos/daisy.tfrecord", "flower_photos/dandelion.tfrecord", "flower_photos/roses.tfrecord", "flower_photos/sunflowers.tfrecord", "flower_photos/tulips.tfrecord"]
  records = tf.data.TFRecordDataset(filenames).shuffle(
      num_examples, seed=split_seed, reshuffle_each_iteration=False)

  # Training set: the first num_train records, reshuffled every epoch.
  train_dataset = records.take(num_train).shuffle(num_train)
  train_dataset = train_dataset.map(_parse_function).batch(batch_size).repeat(num_epochs)
  train_iterator = train_dataset.make_initializable_iterator()
  images, labels = train_iterator.get_next()
  images_batch = tf.reshape(images, [batch_size, imageHeight, imageWidth, imageDepth])
  labels_batch = tf.reshape(tf.decode_raw(labels, tf.int64), [batch_size])

  # Validation set: the next num_valid records.
  valid_dataset = records.skip(num_train).take(num_valid).map(_parse_function).batch(batch_size)
  valid_iterator = valid_dataset.make_initializable_iterator()
  valid_images, valid_labels = valid_iterator.get_next()
  valid_images_batch = tf.reshape(valid_images, [batch_size, imageHeight, imageWidth, imageDepth])
  valid_labels_batch = tf.reshape(tf.decode_raw(valid_labels, tf.int64), [batch_size])

  # Test set: all remaining records.
  test_dataset = records.skip(num_train + num_valid).map(_parse_function).batch(batch_size)
  test_iterator = test_dataset.make_initializable_iterator()
  test_images, test_labels = test_iterator.get_next()
  test_images_batch = tf.reshape(test_images, [batch_size, imageHeight, imageWidth, imageDepth])
  test_labels_batch = tf.reshape(tf.decode_raw(test_labels, tf.int64), [batch_size])

  # Placeholders for an input image batch and its labels.
  input_images = tf.placeholder(tf.float32, (batch_size, imageHeight, imageWidth, imageDepth))
  input_labels = tf.placeholder(tf.int64, (batch_size))
  # Bottleneck features: the 4096-d second-to-last AlexNet layer (frozen).
  _, bottleneck_input = alexnet_model.inference(input_images, dropout_rate=1.0, wd=None)
  # New trainable fully connected layer mapping the bottleneck to 5 class logits.
  output_logits = tf.layers.dense(bottleneck_input, units=5, activation=None,
                                  kernel_initializer=tf.initializers.truncated_normal,
                                  name="output_layer", reuse=tf.AUTO_REUSE)
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=input_labels, logits=output_logits)
  loss_mean = tf.reduce_mean(loss)
  # Softmax is monotonic, so the argmax of the logits is the predicted class.
  output_result = tf.argmax(output_logits, 1)
  accuracy_batch = tf.reduce_mean(tf.cast(tf.equal(input_labels, output_result), tf.float32))

  # Learning-rate schedule by epoch: 0.01 for epochs [0, 30), 0.005 for [30, 40),
  # 0.001 for [40, 50), then 0.0005 (not reached when training num_epochs = 50).
  global_step = tf.Variable(0, trainable=False)
  epoch_steps = num_train // batch_size
  boundaries = [epoch_steps*30, epoch_steps*40, epoch_steps*50]
  values = [0.01, 0.005, 0.001, 0.0005]
  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
  # GradientDescentOptimizer's constructor has no global_step argument; it is
  # passed to minimize(), which increments it once per update so the
  # piecewise-constant schedule above actually advances.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  opt_op = optimizer.minimize(loss_mean, global_step=global_step)

  # Restore the pre-trained AlexNet variables. The new output layer and
  # global_step are created in this script, so they are excluded and keep
  # their fresh initialization.
  restore_vars = [v for v in tf.global_variables()
                  if not v.op.name.startswith("output_layer/")
                  and v.op.name != global_step.op.name]
  saver = tf.train.Saver(restore_vars)


  def _evaluate(sess, iterator, images_t, labels_t):
      """Return the mean per-batch accuracy over one full pass of `iterator`."""
      sess.run(iterator.initializer)
      total_accuracy = 0.0
      num_batches = 0
      while True:
          try:
              images_e, labels_e = sess.run([images_t, labels_t])
          except tf.errors.OutOfRangeError:
              break
          total_accuracy += sess.run(accuracy_batch,
                                     feed_dict={input_images: images_e, input_labels: labels_e})
          num_batches += 1
      return total_accuracy / num_batches if num_batches else float("nan")


  # Train; print the mean loss every 100 steps and, after each epoch, the
  # accuracy on the validation and test sets.
  with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      saver.restore(sess, "model/model.ckpt-450000")
      sess.run(train_iterator.initializer)
      step = 0
      total_loss = 0.0
      epoch = 0
      while True:
          try:
              images_i, labels_i = sess.run([images_batch, labels_batch])
          except tf.errors.OutOfRangeError:
              break  # all num_epochs epochs have been consumed
          step += 1
          loss_a, lr, _ = sess.run([loss_mean, learning_rate, opt_op],
                                   feed_dict={input_images: images_i, input_labels: labels_i})
          total_loss += loss_a
          if step % 100 == 0:
              print("step %i Learning_rate: %f Loss: %f" % (step, lr, total_loss/100))
              total_loss = 0.0
          if step % epoch_steps == 0:
              epoch += 1
              accuracy_valid = _evaluate(sess, valid_iterator, valid_images_batch, valid_labels_batch)
              print("epoch %i validation accuracy: %f" % (epoch, accuracy_valid))
              accuracy_test = _evaluate(sess, test_iterator, test_images_batch, test_labels_batch)
              print("epoch %i test accuracy: %f" % (epoch, accuracy_test))

训练第一个EPOCH后,测试集准确率就可以达到60%,可见迁移学习确实利用了原有模型有效地提取了图像的特征,加快了训练速度。训练50个EPOCH后,在测试集上达到96%左右的准确率,和Tensorflow官网上用Inception V3模型迁移学习测试的结果相近。需要注意的是,只有当训练集、验证集和测试集互不重叠时(例如在划分数据之前用固定的随机种子只打乱一次),这个测试准确率才有参考意义。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/123975
推荐阅读
相关标签
  

闽ICP备14008679号