当前位置:   article > 正文

tensorflow mnist实战笔记(二)制作和读取自己的数据集_mnist数据集制作

mnist数据集制作

这里面写的非常详细

http://www.itdadao.com/articles/c15a1401577p0.html

看了网上N多的教程,发现mnist的教程的数据都是官网已经制作好的,那么如果我们自己有数字图片,我们该怎么利用tensoeflow制作数据呢?

现在我有6万张训练集,1万张测试集,下载地址在这mnist图片数据下载:http://pan.baidu.com/s/1pLMV4Kz

首先我们需要有图片数据的txt表,以及对应的标签,如下所示,制作txt表在caffe中已经提到,传送门

mnist/train/5/00000.png 5
mnist/train/0/00001.png 0
mnist/train/4/00002.png 4
mnist/train/1/00003.png 1 下面这串代码就可以在原路径得到a.tfrecords文件

  1. import numpy as np
  2. import cv2
  3. import tensorflow as tf
  4. resize_height=28 #存储图片高度
  5. resize_width=28 #存储图片宽度
  6. train_file_root = '/home/hjxu/PycharmProjects/tf_examples/hjxu_mnist/mnist_img_data'
  7. train_file = train_file_root+'/train.txt' #trainfile是txt文件存放的目录
  8. def _int64_feature(value):#将value转化成int64字节属性,
  9. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  10. def _bytes_feature(value):#将value转化成bytes属性
  11. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  12. def load_file(examples_list_file):
  13. # type: (object) -> object
  14. lines = np.genfromtxt(examples_list_file, delimiter=" ", dtype=[('col1', 'S120'), ('col2', 'i8')])
  15. examples = []
  16. labels = []
  17. for example, label in lines:
  18. examples.append(example)
  19. labels.append(label)
  20. return np.asarray(examples), np.asarray(labels), len(lines)
  21. ##load_file函数返回的examples,labels,lines,比如examples[0]指的是mnist/train/4/00002.png,也就是 txt的路径 labels[0]返回的是对应的label值是4
  22. def extract_image(filename, resize_height, resize_width): #这边调用cv2.imread()来读取图像,由于cv2读取是BGR格式,需要转换成RGB格式
  23. image = cv2.imread(filename)
  24. image = cv2.resize(image, (resize_height, resize_width))
  25. b,g,r = cv2.split(image)
  26. rgb_image = cv2.merge([r,g,b]) # this is suitable
  27. rgb_image = rgb_image / 255.
  28. rgb_image = rgb_image.astype(np.float32)
  29. return rgb_image
  30. examples, labels, examples_num = load_file(train_file)
  31. writer = tf.python_io.TFRecordWriter('/home/hjxu/PycharmProjects/tf_examples/hjxu_mnist/a.tfrecords')
  32. # root = train_file_root + '/' + examples[0]
  33. for i, [example, label] in enumerate(zip(examples, labels)):
  34. print('No.%d' % (i))
  35. root = train_file_root + '/' + examples[i]
  36. image = extract_image(root, resize_height, resize_width)
  37. a = image.shape
  38. print(root)
  39. print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label))
  40. image_raw = image.tostring() #将Image转化成字符
  41. example = tf.train.Example(features=tf.train.Features(feature={
  42. 'image_raw': _bytes_feature(image_raw),
  43. 'height': _int64_feature(image.shape[0]),
  44. 'width': _int64_feature(image.shape[1]),
  45. 'depth': _int64_feature(image.shape[2]),
  46. 'label': _int64_feature(label)
  47. }))
  48. writer.write(example.SerializeToString())
  49. writer.close()
上面代码最重要的是

  1. example = tf.train.Example(features=tf.train.Features(feature={
  2. 'image_raw': _bytes_feature(image_raw),
  3. 'height': _int64_feature(image.shape[0]),
  4. 'width': _int64_feature(image.shape[1]),
  5. 'depth': _int64_feature(image.shape[2]),
  6. 'label': _int64_feature(label)
  7. }))
这一段,我们可以看出,a.tfrecords里面其实对应的是一些字典,比如Image_raw对应的是图像矩阵本身保存的字节文件,height则是则是对应的高,其实height什么的不写进去也没事,但label一定要写。

现在我们可以得到a.tfrecords这个文件,我们该怎么解析里面的内容呢?或者我们该怎么将tfrecords里面的二进制文件转换成我们可以可视化的数字图片呢

下面这串代码可以得出

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import tensorflow as tf
  4. tfrecord_list_file = '/home/hjxu/PycharmProjects/tf_examples/hjxu_mnist/a.tfrecords'
  5. def read_and_decode(filename_queue,shuffle_batch=True):
  6. reader = tf.TFRecordReader()
  7. _, serialized_example = reader.read(filename_queue)
  8. features = tf.parse_single_example(serialized_example, features={
  9. 'image_raw': tf.FixedLenFeature([], tf.string),
  10. 'label': tf.FixedLenFeature([], tf.int64)
  11. })
  12. image = tf.decode_raw(features['image_raw'], tf.float32)
  13. image = tf.reshape(image, [28, 28, 3])
  14. image = image * 255.0
  15. labels = features['label']
  16. if shuffle_batch:
  17. images, labels = tf.train.shuffle_batch(
  18. [image,labels],
  19. batch_size=4,
  20. capacity=8000,
  21. num_threads=4,
  22. min_after_dequeue=2000)
  23. else:
  24. images,labels = tf.train.batch([image,labels],
  25. batch_size=4,
  26. capacity=8000,
  27. num_threads=4)
  28. return images,labels
  29. def test_run(tfrecord_filename):
  30. filename_queue = tf.train.string_input_producer([tfrecord_filename],
  31. num_epochs=3)
  32. images,labs = read_and_decode(filename_queue)
  33. init_op = tf.group(tf.global_variables_initializer(),
  34. tf.local_variables_initializer())
  35. # meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
  36. # meanvalue = meanfile['mean'] #如果在制作数据时减去的均值,则需要加上来
  37. with tf.Session() as sess:
  38. sess.run(init_op)
  39. coord = tf.train.Coordinator()
  40. threads = tf.train.start_queue_runners(coord=coord)
  41. for i in range(1):
  42. imgs,labs = sess.run([images,labs])
  43. print 'batch' + str(i) + ': '
  44. #print type(imgs[0])
  45. for j in range(4):
  46. print str(labs[j])
  47. img = np.uint8(imgs[j] )
  48. plt.subplot(4, 2, j * 2 + 1)
  49. plt.imshow(img)
  50. plt.show()
  51. coord.request_stop()
  52. coord.join(threads) #注意,要关闭文件
  53. test_run('/home/hjxu/PycharmProjects/tf_examples/hjxu_mnist/a.tfrecords')
  54. print ("has done")

主要用到tf.decode_raw,这个内置函数的意思是解析 tfrecords文件里的二进制数据,我的read_and_decode只返回图像和label,所以只需要用到tfrecords里面的image_raw和label

  1. image = tf.decode_raw(features['image_raw'], tf.float32) #解析image_raw数据,注意,tf.float32是数据类型,一定要和制作数据时用的类型一样
  2. image = tf.reshape(image, [28, 28, 3])
  3. image = image * 255.0 #我在制作数据时除了255,这里可以补回来或者不补
  4. labels = features['label'] #label则是对应的标签

目前了解的也就这么多,基本都是从其他博客整理得到的,下面是参考博客

cv2.imread()和caffe.io.loadimage的区别

TFRecords 文件的生成和读取

【TensorFlow动手玩】数据导入2

由浅入深之Tensorflow(3)----数据读取之TFRecords

TensorFlow的reshape操作 tf.reshape

Tensorflow之构建自己的图片数据集TFrecords


http://blog.csdn.net/u010358677/article/details/70544241



声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/123712
推荐阅读
相关标签
  

闽ICP备14008679号