
Deep Learning: Plant Image Recognition with TensorFlow and a Visual Interface

The final interface of this small project is shown in the figure below:

All plant images in this project were photographed by me: 12 species with 1,250 images each, 15,000 images in total. The network can be VGG-16, ResNet-50, or AlexNet, and you can switch between them freely for training.

This article is offered for study and discussion only. I drew on programs from many experienced authors; if you spot mistakes, please point them out.

If you need the complete code, please consider supporting this struggling student. It is not about making money, only about improving my study and living conditions. My sincere thanks.

The implementation is described below (excluding the GUI code; the GUI is written with PyQt5).

First, import the required packages:

import logging
import os
import pickle
import random
import time

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from PIL import Image

# train() below logs through this module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
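
The snippets below read their configuration from a FLAGS object (FLAGS.charset_size, FLAGS.image_size, the augmentation switches, and so on), but the article never shows the flag definitions. Here is a minimal sketch of what they might look like with tf.app.flags; every default value is an assumption except those the article states (224×224×3 images, 12 species):

tf.app.flags.DEFINE_integer('charset_size', 12, 'number of plant classes')
tf.app.flags.DEFINE_integer('image_size', 224, 'network input width/height')
tf.app.flags.DEFINE_integer('pic_channel', 3, 'number of image channels')
tf.app.flags.DEFINE_integer('batch_size', 32, 'training batch size (assumed)')
tf.app.flags.DEFINE_integer('epoch', 10, 'number of epochs (assumed)')
tf.app.flags.DEFINE_integer('max_steps', 100000, 'stop after this many steps (assumed)')
tf.app.flags.DEFINE_integer('eval_steps', 100, 'evaluate every N steps (assumed)')
tf.app.flags.DEFINE_integer('save_steps', 500, 'checkpoint every N steps (assumed)')
tf.app.flags.DEFINE_boolean('restore', True, 'resume from the latest checkpoint')
tf.app.flags.DEFINE_boolean('random_flip_up_down', True, 'augmentation switch')
tf.app.flags.DEFINE_boolean('random_flip_left_right', True, 'augmentation switch')
tf.app.flags.DEFINE_boolean('random_brightness', True, 'augmentation switch')
tf.app.flags.DEFINE_boolean('random_contrast', True, 'augmentation switch')
tf.app.flags.DEFINE_boolean('random_saturation', True, 'augmentation switch')
tf.app.flags.DEFINE_boolean('random_hue', True, 'augmentation switch')
tf.app.flags.DEFINE_boolean('resize_image_with_crop_or_pad', False, 'augmentation switch')
tf.app.flags.DEFINE_string('train_data_dir', './data/train/', 'training set root (assumed)')
tf.app.flags.DEFINE_string('test_data_dir', './data/test/', 'test set root (assumed)')
tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoint/', 'checkpoint directory (assumed)')
tf.app.flags.DEFINE_string('log_dir', './log', 'TensorBoard log directory (assumed)')
# 'net_name' is not from the original; only the illustrative cnn() sketch
# further down uses it.
tf.app.flags.DEFINE_string('net_name', 'vgg16', 'vgg16 | resnet50 | alexnet')
FLAGS = tf.app.flags.FLAGS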

The images are organized as follows: each class folder is named with five digits, the last two of which are the species code. The images inside each folder are named so that their first two digits match the species code of the folder, and the last three digits, running from 000 to 999, are the individual index. All images are 224×224×3.
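
Concretely, a dataset organized this way would look roughly like the following (all folder and file names here are illustrative):

train/
    00000/            # species 00
        00000.jpg     # species 00, individual 000
        00001.jpg
        ...
        00999.jpg
    00001/            # species 01
        01000.jpg     # species 01, individual 000
        ...
    ...
    00011/            # species 11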

To read the image labels, the whole data directory is walked to collect the image names; the leading digits of each path give the image's label. The relevant code is as follows.

class DataIterator:
    def __init__(self, data_dir):
        # Class folders are named 00000..000NN; the string comparison below
        # keeps only folders that sort before data_dir + '%05d' % charset_size.
        truncate_path = data_dir + ('%05d' % FLAGS.charset_size)
        self.image_names = []
        for root, sub_folder, file_list in os.walk(data_dir):
            # print(root)  # debug
            if root < truncate_path:
                self.image_names += [os.path.join(root, file_path) for file_path in file_list]
        random.shuffle(self.image_names)
        # print(self.image_names)  # debug
        # The label is the numeric folder name relative to data_dir.
        self.labels = [int(file_name[len(data_dir):].split(os.sep)[0])
                       for file_name in self.image_names]
        # print(self.labels)  # debug

To compensate for the small dataset, image augmentation is used to expand it: random vertical and horizontal flips, plus random adjustments of brightness, contrast, saturation, and hue within limited ranges. Each operation can be switched on or off, or retuned, as needed.

    # (still inside class DataIterator; shown as a @staticmethod so the
    # self.data_augmentation(...) call in input_pipeline below works)
    @staticmethod
    def data_augmentation(images):
        # Every operation is gated by a command-line flag, so it can be
        # toggled or retuned without code changes.
        if FLAGS.random_flip_up_down:
            images = tf.image.random_flip_up_down(images)
        if FLAGS.random_flip_left_right:
            images = tf.image.random_flip_left_right(images)
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.1)
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.9, 1.1)
        if FLAGS.resize_image_with_crop_or_pad:
            images = tf.image.resize_image_with_crop_or_pad(images, FLAGS.image_size, FLAGS.image_size)
        if FLAGS.random_saturation:
            images = tf.image.random_saturation(images, 0.9, 1.1)
        if FLAGS.random_hue:
            images = tf.image.random_hue(images, max_delta=0.1)
        return images

Next, construct the batch queue, putting the image paths and labels into it.

    # (still inside class DataIterator)
    def input_pipeline(self, batch_size, num_epochs=None):
        images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
        labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
        input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)
        labels = input_queue[1]
        images_content = tf.read_file(input_queue[0])
        images = tf.image.convert_image_dtype(tf.image.decode_jpeg(images_content, channels=3), tf.float32)
        images = self.data_augmentation(images)
        new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
        images = tf.image.resize_images(images, new_size)
        image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size,
                                                          capacity=150, min_after_dequeue=10)
        return image_batch, label_batch

The network definition itself is skipped here; build_graph() simply calls cnn(images) and returns the relevant tensors: accuracy, top-k accuracy, loss, global step, and so on. Since cnn() never appears in the article, a placeholder sketch of it is given first.
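
The following is a minimal stand-in for the omitted network, purely so the section is self-contained. It dispatches between the three backbones using slim's bundled reference implementations; the net_name flag and this dispatch are my assumptions, not the author's code, and in the full project the keep_prob placeholder presumably drives dropout inside the real network:

from tensorflow.contrib.slim.nets import alexnet, resnet_v1, vgg

def cnn(images):
    # Hypothetical stand-in for the network the article omits; the 'net_name'
    # flag and this dispatch are illustrative assumptions.
    num_classes = FLAGS.charset_size
    if FLAGS.net_name == 'vgg16':
        logits, _ = vgg.vgg_16(images, num_classes=num_classes)
    elif FLAGS.net_name == 'resnet50':
        net, _ = resnet_v1.resnet_v1_50(images, num_classes=num_classes)
        # Depending on the TF 1.x version, ResNet may return [N, 1, 1, C].
        logits = net if net.shape.ndims == 2 else tf.squeeze(net, axis=[1, 2])
    else:
        logits, _ = alexnet.alexnet_v2(images, num_classes=num_classes)
    return logits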

def build_graph(top_k):
    with tf.device('/gpu:0'):
        keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
        images = tf.placeholder(dtype=tf.float32,
                                shape=[None, FLAGS.image_size, FLAGS.image_size, FLAGS.pic_channel],
                                name='image_batch')
        labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')
        logits = cnn(images)
    with tf.device('/gpu:0'):
        with tf.name_scope("loss"):
            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    with tf.device('/gpu:0'):
        with tf.name_scope("accuracy"):
            accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))
    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    # The learning rate decays by 3% every 2000 steps from an initial 2e-4.
    rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)
    with tf.name_scope("probabilities"):
        with tf.device('/gpu:0'):
            probabilities = tf.nn.softmax(logits)
            predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
            accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('top_k', accuracy_in_top_k)
    merged_summary_op = tf.summary.merge_all()
    return {'images': images,
            'labels': labels,
            'keep_prob': keep_prob,
            'top_k': top_k,
            'global_step': global_step,
            'train_op': train_op,
            'loss': loss,
            'accuracy': accuracy,
            'accuracy_top_k': accuracy_in_top_k,
            'merged_summary_op': merged_summary_op,
            'predicted_distribution': probabilities,
            'predicted_index_top_k': predicted_index_top_k,
            'predicted_val_top_k': predicted_val_top_k}

Then open a session and train, saving the model and writing logs along the way.

def train():
    print('Begin training')
    train_feeder = DataIterator(data_dir=FLAGS.train_data_dir)
    test_feeder = DataIterator(data_dir=FLAGS.test_data_dir)
    with tf.Session() as sess:
        train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=FLAGS.epoch)
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
        graph = build_graph(top_k=5)
        sess.run(tf.global_variables_initializer())
        # slice_input_producer with num_epochs creates local variables.
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
        start_step = 0
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))
                start_step += int(ckpt.split('-')[-1])
        logger.info(':::Training Start:::')
        try:
            while not coord.should_stop():
                start_time = time.time()
                train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                feed_dict = {graph['images']: train_images_batch,
                             graph['labels']: train_labels_batch,
                             graph['keep_prob']: 0.8}
                _, loss_val, accuracy_train, train_summary, step = sess.run([graph['train_op'],
                                                                             graph['loss'],
                                                                             graph['accuracy'],
                                                                             graph['merged_summary_op'],
                                                                             graph['global_step']],
                                                                            feed_dict=feed_dict)
                train_writer.add_summary(train_summary, step)
                # print(train_labels_batch)
                end_time = time.time()
                logger.info("the step: {0} takes {1}s loss: {2} accuracy: {3}%".format(
                    round(step, 0), round(end_time - start_time, 2),
                    round(loss_val, 2), round(accuracy_train * 100, 2)))
                if step > FLAGS.max_steps:
                    break
                if step % FLAGS.eval_steps == 1:
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    accuracy_test, test_summary = sess.run([graph['accuracy'],
                                                            graph['merged_summary_op']],
                                                           feed_dict=feed_dict)
                    test_writer.add_summary(test_summary, step)
                    logger.info('======================= Eval a batch =======================')
                    logger.info('the step: {0} test accuracy: {1} %'.format(step, round(accuracy_test * 100, 2)))
                    logger.info('======================= Eval a batch =======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'),
                               global_step=graph['global_step'])
        except tf.errors.OutOfRangeError:  # raised when an operation iterates past the valid input range
            logger.info('==================Train Finished================')
            saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'),
                       global_step=graph['global_step'])
        finally:
            coord.request_stop()
            coord.join(threads)
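
Finally, the GUI needs to run the trained model on a single photo. The article does not show its inference code; below is a minimal sketch of what it could look like, consistent with build_graph() above (the function name predict_one and the preprocessing are my assumptions):

def predict_one(image_path, top_k=3):
    # Rebuild the graph, restore the latest checkpoint, and return the
    # top-k (probability, class index) pairs for one image.
    tf.reset_default_graph()
    graph = build_graph(top_k=top_k)
    img = Image.open(image_path).convert('RGB').resize((FLAGS.image_size, FLAGS.image_size))
    img = np.asarray(img, dtype=np.float32) / 255.0   # match convert_image_dtype scaling
    img = img[np.newaxis, :]                          # add the batch dimension
    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
        probs, idx = sess.run([graph['predicted_val_top_k'],
                               graph['predicted_index_top_k']],
                              feed_dict={graph['images']: img,
                                         graph['keep_prob']: 1.0})
    return list(zip(probs[0], idx[0]))

The PyQt5 interface can then simply call predict_one() on whatever photo the user loads.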
