韦访 20190118


















  1. # encoding:utf-8
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5. import math
  6. import os
  7. import random
  8. import sys
  9. import tensorflow as tf
  10. from datasets import dataset_utils
  11. # The URL where the Flowers data can be downloaded.
  12. _DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
  13. # The number of images in the validation set.
  14. _NUM_VALIDATION = 350
  15. # Seed for repeatability.
  16. _RANDOM_SEED = 0
  17. # The number of shards per dataset split.
  18. _NUM_SHARDS = 4
  19. class ImageReader(object):
  20. """Helper class that provides TensorFlow image coding utilities."""
  21. def __init__(self):
  22. # Initializes function that decodes RGB JPEG data.
  23. self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
  24. self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
  25. def read_image_dims(self, sess, image_data):
  26. image = self.decode_jpeg(sess, image_data)
  27. return image.shape[0], image.shape[1]
  28. def decode_jpeg(self, sess, image_data):
  29. image = sess.run(self._decode_jpeg,
  30. feed_dict={self._decode_jpeg_data: image_data})
  31. assert len(image.shape) == 3
  32. assert image.shape[2] == 3
  33. return image
  34. def _get_filenames_and_classes(dataset_dir):
  35. """Returns a list of filenames and inferred class names.
  36. Args:
  37. dataset_dir: A directory containing a set of subdirectories representing
  38. class names. Each subdirectory should contain PNG or JPG encoded images.
  39. Returns:
  40. A list of image file paths, relative to `dataset_dir` and the list of
  41. subdirectories, representing class names.
  42. """
  43. # 将flower_photos改为eye_photos
  44. flower_root = os.path.join(dataset_dir, 'eye_open_and_close')
  45. directories = []
  46. class_names = []
  47. for filename in os.listdir(flower_root):
  48. path = os.path.join(flower_root, filename)
  49. if os.path.isdir(path):
  50. directories.append(path)
  51. class_names.append(filename)
  52. photo_filenames = []
  53. for directory in directories:
  54. for filename in os.listdir(directory):
  55. path = os.path.join(directory, filename)
  56. photo_filenames.append(path)
  57. return photo_filenames, sorted(class_names)
  58. def _get_dataset_filename(dataset_dir, split_name, shard_id):
  59. # 修改文件名,将flowersg改为eye
  60. output_filename = 'eye_%s_%05d-of-%05d.tfrecord' % (
  61. split_name, shard_id, _NUM_SHARDS)
  62. return os.path.join(dataset_dir, output_filename)
  63. def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
  64. """Converts the given filenames to a TFRecord dataset.
  65. Args:
  66. split_name: The name of the dataset, either 'train' or 'validation'.
  67. filenames: A list of absolute paths to png or jpg images.
  68. class_names_to_ids: A dictionary from class names (strings) to ids
  69. (integers).
  70. dataset_dir: The directory where the converted datasets are stored.
  71. """
  72. assert split_name in ['train', 'validation']
  73. num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
  74. with tf.Graph().as_default():
  75. image_reader = ImageReader()
  76. with tf.Session('') as sess:
  77. for shard_id in range(_NUM_SHARDS):
  78. output_filename = _get_dataset_filename(
  79. dataset_dir, split_name, shard_id)
  80. with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
  81. start_ndx = shard_id * num_per_shard
  82. end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
  83. for i in range(start_ndx, end_ndx):
  84. sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
  85. i + 1, len(filenames), shard_id))
  86. sys.stdout.flush()
  87. # Read the filename:
  88. image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
  89. height, width = image_reader.read_image_dims(sess, image_data)
  90. class_name = os.path.basename(os.path.dirname(filenames[i]))
  91. class_id = class_names_to_ids[class_name]
  92. example = dataset_utils.image_to_tfexample(
  93. image_data, b'jpg', height, width, class_id)
  94. tfrecord_writer.write(example.SerializeToString())
  95. sys.stdout.write('\n')
  96. sys.stdout.flush()
  97. def _clean_up_temporary_files(dataset_dir):
  98. """Removes temporary files used to create the dataset.
  99. Args:
  100. dataset_dir: The directory where the temporary files are stored.
  101. """
  102. filename = _DATA_URL.split('/')[-1]
  103. filepath = os.path.join(dataset_dir, filename)
  104. tf.gfile.Remove(filepath)
  105. # 将flower_photos改为eye_photos
  106. tmp_dir = os.path.join(dataset_dir, 'eye_photos')
  107. tf.gfile.DeleteRecursively(tmp_dir)
  108. def _dataset_exists(dataset_dir):
  109. for split_name in ['train', 'validation']:
  110. for shard_id in range(_NUM_SHARDS):
  111. output_filename = _get_dataset_filename(
  112. dataset_dir, split_name, shard_id)
  113. if not tf.gfile.Exists(output_filename):
  114. return False
  115. return True
  116. def run(dataset_dir):
  117. """Runs the download and conversion operation.
  118. Args:
  119. dataset_dir: The dataset directory where the dataset is stored.
  120. """
  121. if not tf.gfile.Exists(dataset_dir):
  122. tf.gfile.MakeDirs(dataset_dir)
  123. if _dataset_exists(dataset_dir):
  124. print('Dataset files already exist. Exiting without re-creating them.')
  125. return
  126. # 因为我们不需要下载,所以这行注释掉
  127. # dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
  128. photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
  129. class_names_to_ids = dict(zip(class_names, range(len(class_names))))
  130. # Divide into train and test:
  131. random.seed(_RANDOM_SEED)
  132. random.shuffle(photo_filenames)
  133. training_filenames = photo_filenames[_NUM_VALIDATION:]
  134. validation_filenames = photo_filenames[:_NUM_VALIDATION]
  135. # First, convert the training and validation sets.
  136. _convert_dataset('train', training_filenames, class_names_to_ids,
  137. dataset_dir)
  138. _convert_dataset('validation', validation_filenames, class_names_to_ids,
  139. dataset_dir)
  140. # Finally, write the labels file:
  141. labels_to_class_names = dict(zip(range(len(class_names)), class_names))
  142. dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
  143. # 将这行注释掉,要不然转换完以后,原始数据会被删除
  144. # _clean_up_temporary_files(dataset_dir)
  145. print('\nFinished converting the Flowers dataset!')


from datasets import convert_eye


elif FLAGS.dataset_name == 'mnist':


elif FLAGS.dataset_name == 'eye':



python download_and_convert_data.py --dataset_name=eye --dataset_dir=images_data/eye_open_and_close


Instructions for updating:

Use tf.gfile.GFile.

>> Converting image 143/4498 shard 0Traceback (most recent call last):

  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1335, in _do_call

    return fn(*args)

  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1320, in _run_fn

    options, feed_dict, fetch_list, target_list, run_metadata)

  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1408, in _call_tf_sessionrun


tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got unknown format starting with '\320\317\021\340\241\261\032\341\000\000\000\000\000\000\000\000'

 [[{{node DecodeJpeg}}]]



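tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got unknown format starting with

出错原因：报错中的文件头字节 \320\317\021\340\241\261\032\341（十六进制 D0 CF 11 E0 A1 B1 1A E1）是 OLE2 复合文档的签名，说明图片文件夹里混入了非图片文件，通常是 Windows 自动生成的 Thumbs.db。删除这类文件（或在读取时只保留 .jpg/.jpeg/.png 后缀的文件）后重新运行转换即可。报错时训练集共有 4498 个文件，去掉这类文件后为 4496，与下文 SPLITS_TO_SIZES 中的数值一致。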








继续修改代码：将 datasets/flowers.py 复制并重命名为 eye.py，然后在 eye.py 中做如下两处修改（每处先列出原代码，再列出修改后的代码）：

_FILE_PATTERN = 'flowers_%s_*.tfrecord'


_FILE_PATTERN = 'eye_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}


SPLITS_TO_SIZES = {'train': 4496, 'validation': 350}


  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Provides data for the flowers dataset.
  16. The dataset scripts used to create the dataset can be found at:
  17. tensorflow/models/research/slim/datasets/download_and_convert_flowers.py
  18. """
  19. from __future__ import absolute_import
  20. from __future__ import division
  21. from __future__ import print_function
  22. import os
  23. import tensorflow as tf
  24. from datasets import dataset_utils
# Short alias for the TF-Slim contrib package; get_split uses its
# tfexample_decoder and dataset modules.
slim = tf.contrib.slim
  26. _FILE_PATTERN = 'eye_%s_*.tfrecord'
  27. SPLITS_TO_SIZES = {'train': 4496, 'validation': 350}
  28. _NUM_CLASSES = 2
  30. 'image': 'A color image of varying size.',
  31. 'label': 'A single integer between 0 and 4',
  32. }
  33. def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  34. """Gets a dataset tuple with instructions for reading flowers.
  35. Args:
  36. split_name: A train/validation split name.
  37. dataset_dir: The base directory of the dataset sources.
  38. file_pattern: The file pattern to use when matching the dataset sources.
  39. It is assumed that the pattern contains a '%s' string so that the split
  40. name can be inserted.
  41. reader: The TensorFlow reader type.
  42. Returns:
  43. A `Dataset` namedtuple.
  44. Raises:
  45. ValueError: if `split_name` is not a valid train/validation split.
  46. """
  47. if split_name not in SPLITS_TO_SIZES:
  48. raise ValueError('split name %s was not recognized.' % split_name)
  49. if not file_pattern:
  50. file_pattern = _FILE_PATTERN
  51. file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
  52. # Allowing None in the signature so that dataset_factory can use the default.
  53. if reader is None:
  54. reader = tf.TFRecordReader
  55. keys_to_features = {
  56. 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
  57. 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
  58. 'image/class/label': tf.FixedLenFeature(
  59. [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  60. }
  61. items_to_handlers = {
  62. 'image': slim.tfexample_decoder.Image(),
  63. 'label': slim.tfexample_decoder.Tensor('image/class/label'),
  64. }
  65. decoder = slim.tfexample_decoder.TFExampleDecoder(
  66. keys_to_features, items_to_handlers)
  67. labels_to_names = None
  68. if dataset_utils.has_labels(dataset_dir):
  69. labels_to_names = dataset_utils.read_label_file(dataset_dir)
  70. return slim.dataset.Dataset(
  71. data_sources=file_pattern,
  72. reader=reader,
  73. decoder=decoder,
  74. num_samples=SPLITS_TO_SIZES[split_name],
  75. items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
  76. num_classes=_NUM_CLASSES,
  77. labels_to_names=labels_to_names)



from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,


from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import eye

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'eye': eye,


python train_image_classifier.py   --train_dir=saver/inv3_eye_open_and_close   --dataset_name=eye   --dataset_split_name=train   --dataset_dir=images_data/eye_open_and_close   --model_name=inception_v3       --learning_rate_decay_type=fixed   --save_interval_secs=60  --save_summaries_secs=60   --log_every_n_steps=10   --optimizer=rmsprop   --learning_rate=0.0001





python eval_image_classifier.py   --checkpoint_path=saver/inv3_eye_open_and_close/   --eval_dir=saver/inv3_eye_open_and_close/   --dataset_name=eye  --dataset_split_name=validation  --dataset_dir=images_data/eye_open_and_close  --model_name=inception_v3   --batch_size=64




