
【TensorFlow Deep Learning】8. Training an Image Classification Model

1. Three ways to train an image classification model

    (1) From scratch: choose the network architecture, prepare the training dataset, initialize the weights randomly, and train batch by batch from the beginning.

    (2) Start from a model that has already been trained, with its weights fixed, and train only the last layer. The earlier layers were trained on huge numbers of images, so their parameters are already good. Convolutional layers mainly extract image features, and since our own classifier also needs feature extraction, we can simply reuse the pretrained weights for that part.

    (3) The same as (2), except that the earlier parameters are also fine-tuned (see the sketch below).
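
    To make the difference between (2) and (3) concrete, here is a minimal TF 1.x sketch (the scopes, layer sizes, and learning rates are hypothetical, not taken from retrain.py): method (2) passes only the final layer's variables to the optimizer, while method (3) lets gradients update everything.

  import tensorflow as tf

  # A toy stand-in for a pretrained feature extractor plus a new final layer.
  # In a real setup the 'feature_extractor' weights would be restored from a checkpoint.
  x = tf.placeholder(tf.float32, [None, 2048])   # bottleneck features
  y = tf.placeholder(tf.float32, [None, 5])      # one-hot labels for 5 classes
  with tf.variable_scope('feature_extractor'):
      h = tf.layers.dense(x, 256, activation=tf.nn.relu)
  with tf.variable_scope('final_layer'):
      logits = tf.layers.dense(h, 5)
  loss = tf.losses.softmax_cross_entropy(y, logits)

  # Method (2): update only the newly added final layer.
  last_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='final_layer')
  train_last_only = tf.train.GradientDescentOptimizer(0.01).minimize(loss, var_list=last_vars)

  # Method (3): fine-tune everything, typically with a smaller learning rate.
  train_finetune = tf.train.GradientDescentOptimizer(0.0001).minimize(loss)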


2. Retraining an image classification model

    (1) Download the official package from https://github.com/tensorflow/tensorflow. It ships with official examples, including the retrain.py script used below.


    (2) Download an image dataset.

    URL: http://www.robots.ox.ac.uk/~vgg/data/

    

    (3) Then write the batch file:

  activate py3
  python E:/graduate_student/deep_learning/tensorflow-master/tensorflow-master/tensorflow/examples/image_retraining/retrain.py ^
  --bottleneck_dir bottleneck ^
  --how_many_training_steps 200 ^
  --model_dir D:/software/mycodes/python35/py3/inception_model/ ^
  --output_graph output_graph.pb ^
  --output_labels output_labels.txt ^
  --image_dir imagedata/
  pause

    bottleneck: the bottleneck values for each image are computed once during preprocessing and cached in this directory.

    output_graph: writes the trained model to the current folder.

    output_labels: writes the trained labels to the current folder.
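
    output_labels.txt is plain text with one class name per line; the test script in (4) reads it with enumerate, so the line order defines the class ids. With hypothetical classes it might look like:

  flower
  guitar
  animal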

    image_dir must be laid out so that each subfolder's name is the class label, for example as shown below. (The image files inside must not have uppercase letters or Chinese characters in their names!)
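
    An example layout (the class names here are hypothetical):

  imagedata/
      flower/
          image_0001.jpg
          image_0002.jpg
      guitar/
          image_0001.jpg
          image_0002.jpg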


    Run the batch file; it writes output_graph.pb and output_labels.txt to the current folder.



    (4) Test the trained model:

    Code:

  # coding: utf-8
  import tensorflow as tf
  import os
  import numpy as np
  import re
  from PIL import Image
  import matplotlib.pyplot as plt

  # Read the label file line by line
  lines = tf.gfile.GFile('E:/graduate_student/deep_learning/a-tensorflow/9/retain/output_labels.txt').readlines()
  uid_to_human = {}
  for uid, line in enumerate(lines):
      # Strip the trailing newline
      line = line.strip('\n')
      uid_to_human[uid] = line

  def id_to_string(node_id):
      if node_id not in uid_to_human:
          return ''
      return uid_to_human[node_id]

  # Create a graph to hold the retrained model
  with tf.gfile.FastGFile('E:/graduate_student/deep_learning/a-tensorflow/9/retain/output_graph.pb', 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      tf.import_graph_def(graph_def, name='')

  with tf.Session() as sess:
      softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
      # Walk the test image directory
      for root, dirs, files in os.walk('E:/graduate_student/deep_learning/a-tensorflow/9/retain/images/'):
          for file in files:
              # Load the image (must be in jpg format)
              image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()
              predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
              predictions = np.squeeze(predictions)  # flatten the result to 1-D
              # Print the image path and name
              image_path = os.path.join(root, file)
              print(image_path)
              # Display the image
              img = Image.open(image_path)
              plt.imshow(img)
              plt.axis('off')
              plt.show()
              # Sort the class ids by descending score
              top_k = predictions.argsort()[::-1]
              print(top_k)
              for node_id in top_k:
                  # Class name
                  human_string = id_to_string(node_id)
                  # Confidence for this class
                  score = predictions[node_id]
                  print('%s (score = %.5f)' % (human_string, score))
              print()

Result: for each test image the script prints its path, displays the image, and lists every class with its score.


    Advantages of fine-tuning: 1. Training is fast with little computation, because only the final layer is trained. 2. Few iterations are needed, because few weights are being updated. 3. Much less image data is required.

3. Training an image recognition model from scratch

    (1) Download the TensorFlow models repository: https://github.com/tensorflow/models


       

    (2) Prepare the images, grouped into class folders.

    (3) Preprocess the images and generate the tfrecord files.

    Program:

  # coding: utf-8
  import tensorflow as tf
  import os
  import random
  import math
  import sys

  # Number of test images
  _NUM_TEST = 500
  # Random seed
  _RANDOM_SEED = 0
  # Number of shards
  _NUM_SHARDS = 5
  # Dataset path
  DATASET_DIR = "E:/graduate_student/deep_learning/a-tensorflow/9/retain/imagedata/"
  # Label file name
  LABELS_FILENAME = "E:/graduate_student/deep_learning/a-tensorflow/9/retain/labels.txt"

  # Build the path + name of a tfrecord file
  def _get_dataset_filename(dataset_dir, split_name, shard_id):
      output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS)
      return os.path.join(dataset_dir, output_filename)

  # Check whether the tfrecord files already exist
  def _dataset_exists(dataset_dir):
      for split_name in ['train', 'test']:
          for shard_id in range(_NUM_SHARDS):
              # Path + name of the tfrecord file
              output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
              if not tf.gfile.Exists(output_filename):
                  return False
      return True

  # Collect all image files and class names
  def _get_filenames_and_classes(dataset_dir):
      # Data directories
      directories = []
      # Class names
      class_names = []
      for filename in os.listdir(dataset_dir):
          # Join the file path
          path = os.path.join(dataset_dir, filename)
          # Every subdirectory is one class
          if os.path.isdir(path):
              # Add the data directory
              directories.append(path)
              # Add the class name
              class_names.append(filename)
      photo_filenames = []
      # Loop over each class folder
      for directory in directories:
          for filename in os.listdir(directory):
              path = os.path.join(directory, filename)
              # Add the image to the image list
              photo_filenames.append(path)
      return photo_filenames, class_names

  def int64_feature(values):
      if not isinstance(values, (tuple, list)):
          values = [values]
      return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

  def bytes_feature(values):
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

  def image_to_tfexample(image_data, image_format, class_id):
      # Example is the protocol-buffer container for one record
      return tf.train.Example(features=tf.train.Features(feature={
          'image/encoded': bytes_feature(image_data),
          'image/format': bytes_feature(image_format),
          'image/class/label': int64_feature(class_id),
      }))

  def write_label_file(labels_to_class_names, dataset_dir, filename=LABELS_FILENAME):
      labels_filename = os.path.join(dataset_dir, filename)
      with tf.gfile.Open(labels_filename, 'w') as f:
          for label in labels_to_class_names:
              class_name = labels_to_class_names[label]
              f.write('%d:%s\n' % (label, class_name))

  # Convert the data to TFRecord format
  def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
      assert split_name in ['train', 'test']
      # How many images go into each shard
      num_per_shard = int(len(filenames) / _NUM_SHARDS)
      with tf.Graph().as_default():
          with tf.Session() as sess:
              for shard_id in range(_NUM_SHARDS):
                  # Path + name of this shard's tfrecord file
                  output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
                  with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                      # First index of this shard
                      start_ndx = shard_id * num_per_shard
                      # Last index of this shard
                      end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                      for i in range(start_ndx, end_ndx):
                          try:
                              sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i + 1, len(filenames), shard_id))
                              sys.stdout.flush()
                              # Read the raw image bytes ('rb', not 'r', so Python 3 does not try to decode them as text)
                              image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
                              # The class name is the name of the directory the image sits in
                              class_name = os.path.basename(os.path.dirname(filenames[i]))
                              # Map the class name to its id
                              class_id = class_names_to_ids[class_name]
                              # Write one tfrecord example
                              example = image_to_tfexample(image_data, b'jpg', class_id)
                              tfrecord_writer.write(example.SerializeToString())
                          except IOError as e:
                              print("Could not read:", filenames[i])
                              print("Error:", e)
                              print("Skip it\n")
      sys.stdout.write('\n')
      sys.stdout.flush()

  if __name__ == '__main__':
      # Skip the conversion if the tfrecord files already exist
      if _dataset_exists(DATASET_DIR):
          print('The tfrecord files already exist')
      else:
          # Collect all images and classes
          photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
          # Map class names to ids, e.g. {'house': 3, 'flower': 1, 'plane': 4, 'guitar': 2, 'animal': 0}
          class_names_to_ids = dict(zip(class_names, range(len(class_names))))
          # Split the data into a training set and a test set
          random.seed(_RANDOM_SEED)
          random.shuffle(photo_filenames)
          training_filenames = photo_filenames[_NUM_TEST:]
          testing_filenames = photo_filenames[:_NUM_TEST]
          # Convert both splits
          _convert_dataset('train', training_filenames, class_names_to_ids, DATASET_DIR)
          _convert_dataset('test', testing_filenames, class_names_to_ids, DATASET_DIR)
          # Write the labels file
          labels_to_class_names = dict(zip(range(len(class_names)), class_names))
          write_label_file(labels_to_class_names, DATASET_DIR)

This produces the tfrecord shard files and the labels file:
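
    As an optional sanity check (a minimal sketch; the shard file name follows the pattern built by _get_dataset_filename above), you can iterate over one shard and inspect a record:

  import tensorflow as tf

  # Count the records in one shard and print the label of the first one
  path = 'E:/graduate_student/deep_learning/a-tensorflow/9/retain/imagedata/image_train_00000-of-00005.tfrecord'
  count = 0
  for record in tf.python_io.tf_record_iterator(path):
      if count == 0:
          example = tf.train.Example()
          example.ParseFromString(record)
          print('label:', example.features.feature['image/class/label'].int64_list.value)
          print('format:', example.features.feature['image/format'].bytes_list.value)
      count += 1
  print('records in shard:', count)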




    (4) The training batch file:

        1. Inside slim, create an image folder and put the tfrecord files in it.


        2. In the datasets folder inside slim, create a new dataset description file; here it must be named myimages.py so that --dataset_name=myimages in the batch file below can find it.

        The program is as follows:

        

  1. """Provides data for the flowers dataset.
  2. The dataset scripts used to create the dataset can be found at:
  3. tensorflow/models/slim/datasets/download_and_convert_flowers.py
  4. """
  5. from __future__ import absolute_import
  6. from __future__ import division
  7. from __future__ import print_function
  8. import os
  9. import tensorflow as tf
  10. from datasets import dataset_utils
  11. slim = tf.contrib.slim
  12. _FILE_PATTERN = 'image_%s_*.tfrecord'
  13. SPLITS_TO_SIZES = {'train': 1000, 'test': 500}
  14. _NUM_CLASSES = 5
  15. _ITEMS_TO_DESCRIPTIONS = {
  16. 'image': 'A color image of varying size.',
  17. 'label': 'A single integer between 0 and 4',
  18. }
  19. def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  20. """Gets a dataset tuple with instructions for reading flowers.
  21. Args:
  22. split_name: A train/validation split name.
  23. dataset_dir: The base directory of the dataset sources.
  24. file_pattern: The file pattern to use when matching the dataset sources.
  25. It is assumed that the pattern contains a '%s' string so that the split
  26. name can be inserted.
  27. reader: The TensorFlow reader type.
  28. Returns:
  29. A `Dataset` namedtuple.
  30. Raises:
  31. ValueError: if `split_name` is not a valid train/validation split.
  32. """
  33. if split_name not in SPLITS_TO_SIZES:
  34. raise ValueError('split name %s was not recognized.' % split_name)
  35. if not file_pattern:
  36. file_pattern = _FILE_PATTERN
  37. file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
  38. # Allowing None in the signature so that dataset_factory can use the default.
  39. if reader is None:
  40. reader = tf.TFRecordReader
  41. keys_to_features = {
  42. 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
  43. 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
  44. 'image/class/label': tf.FixedLenFeature(
  45. [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  46. }
  47. items_to_handlers = {
  48. 'image': slim.tfexample_decoder.Image(),
  49. 'label': slim.tfexample_decoder.Tensor('image/class/label'),
  50. }
  51. decoder = slim.tfexample_decoder.TFExampleDecoder(
  52. keys_to_features, items_to_handlers)
  53. labels_to_names = None
  54. if dataset_utils.has_labels(dataset_dir):
  55. labels_to_names = dataset_utils.read_label_file(dataset_dir)
  56. return slim.dataset.Dataset(
  57. data_sources=file_pattern,
  58. reader=reader,
  59. decoder=decoder,
  60. num_samples=SPLITS_TO_SIZES[split_name],
  61. items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
  62. num_classes=_NUM_CLASSES,
  63. labels_to_names=labels_to_names)

        Then modify dataset_factory.py in the datasets folder, adding myimages to its datasets_map (an excerpt is sketched below). Note also that SPLITS_TO_SIZES and _NUM_CLASSES in myimages.py must match your own data; with the conversion script above, 'test' holds _NUM_TEST = 500 images and 'train' holds the rest.
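
        For reference, the registration in dataset_factory.py looks roughly like this (an excerpt; the other entries already exist in the file):

  from datasets import cifar10
  from datasets import flowers
  from datasets import imagenet
  from datasets import mnist
  from datasets import myimages   # the new module created above

  datasets_map = {
      'cifar10': cifar10,
      'flowers': flowers,
      'imagenet': imagenet,
      'mnist': mnist,
      'myimages': myimages,       # lets --dataset_name=myimages resolve to our dataset
  }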

           3. Create a batch file in the slim directory. train_image_classifier.py is one of the scripts downloaded in (1); it lives in the slim folder, and the slim folder needs to be copied to the current directory.

  D:\Anaconda2\envs\PY3\python E:/graduate_student/deep_learning/models-master/models-master/research/slim/train_image_classifier.py ^
  --train_dir=D:/software/mycodes/python3/python35/captcha/model/ ^
  --dataset_name=myimages ^
  --dataset_split_name=train ^
  --dataset_dir=D:/software/mycodes/python3/python35/captcha/image ^
  --batch_size=10 ^
  --max_number_of_steps=10000 ^
  --model_name=inception_v3
  pause

    Then run the batch file. Training runs slowly.
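
    While it runs, one way to watch progress (an optional aside; the path must match --train_dir above) is to point TensorBoard at the training directory:

  tensorboard --logdir=D:/software/mycodes/python3/python35/captcha/model/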

