当前位置:   article > 正文

Tensorflow之构建自己的图片数据集TFrecords_build_dataset_from_tfrecord

build_dataset_from_tfrecord

  学习谷歌的深度学习终于有点眉目了,给大家分享我的Tensorflow学习历程。

   tensorflow的官方中文文档比较生涩,数据集一直采用的MNIST二进制数据集。并没有过多讲述怎么构建自己的图片数据集tfrecords。

   先贴我的转化代码将图片文件夹下的图片转存tfrecords的数据集。

[python]  view plain  copy
  1. ############################################################################################  
  2. #!/usr/bin/python2.7  
  3. # -*- coding: utf-8 -*-  
  4. #Author  : zhaoqinghui  
  5. #Date    : 2016.5.10  
  6. #Function: image convert to tfrecords   
  7. #############################################################################################  
  8.   
  9. import tensorflow as tf  
  10. import numpy as np  
  11. import cv2  
  12. import os  
  13. import os.path  
  14. from PIL import Image  
  15.   
  16. #参数设置  
  17. ###############################################################################################  
  18. train_file = 'train.txt' #训练图片  
  19. name='train'      #生成train.tfrecords  
  20. output_directory='./tfrecords'  
  21. resize_height=32 #存储图片高度  
  22. resize_width=32 #存储图片宽度  
  23. ###############################################################################################  
  24. def _int64_feature(value):  
  25.     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))  
  26.   
  27. def _bytes_feature(value):  
  28.     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))  
  29.   
  30. def load_file(examples_list_file):  
  31.     lines = np.genfromtxt(examples_list_file, delimiter=" ", dtype=[('col1''S120'), ('col2''i8')])  
  32.     examples = []  
  33.     labels = []  
  34.     for example, label in lines:  
  35.         examples.append(example)  
  36.         labels.append(label)  
  37.     return np.asarray(examples), np.asarray(labels), len(lines)  
  38.   
  39. def extract_image(filename,  resize_height, resize_width):  
  40.     image = cv2.imread(filename)  
  41.     image = cv2.resize(image, (resize_height, resize_width))  
  42.     b,g,r = cv2.split(image)         
  43.     rgb_image = cv2.merge([r,g,b])       
  44.     return rgb_image  
  45.   
  46. def transform2tfrecord(train_file, name, output_directory, resize_height, resize_width):  
  47.     if not os.path.exists(output_directory) or os.path.isfile(output_directory):  
  48.         os.makedirs(output_directory)  
  49.     _examples, _labels, examples_num = load_file(train_file)  
  50.     filename = output_directory + "/" + name + '.tfrecords'  
  51.     writer = tf.python_io.TFRecordWriter(filename)  
  52.     for i, [example, label] in enumerate(zip(_examples, _labels)):  
  53.         print('No.%d' % (i))  
  54.         image = extract_image(example, resize_height, resize_width)  
  55.         print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label))  
  56.         image_raw = image.tostring()  
  57.         example = tf.train.Example(features=tf.train.Features(feature={  
  58.             'image_raw': _bytes_feature(image_raw),  
  59.             'height': _int64_feature(image.shape[0]),  
  60.             'width': _int64_feature(image.shape[1]),  
  61.             'depth': _int64_feature(image.shape[2]),  
  62.             'label': _int64_feature(label)  
  63.         }))  
  64.         writer.write(example.SerializeToString())  
  65.     writer.close()  
  66.   
  67. def disp_tfrecords(tfrecord_list_file):  
  68.     filename_queue = tf.train.string_input_producer([tfrecord_list_file])  
  69.     reader = tf.TFRecordReader()  
  70.     _, serialized_example = reader.read(filename_queue)  
  71.     features = tf.parse_single_example(  
  72.         serialized_example,  
  73.  features={  
  74.           'image_raw': tf.FixedLenFeature([], tf.string),  
  75.           'height': tf.FixedLenFeature([], tf.int64),  
  76.           'width': tf.FixedLenFeature([], tf.int64),  
  77.           'depth': tf.FixedLenFeature([], tf.int64),  
  78.           'label': tf.FixedLenFeature([], tf.int64)  
  79.       }  
  80.     )  
  81.     image = tf.decode_raw(features['image_raw'], tf.uint8)  
  82.     #print(repr(image))  
  83.     height = features['height']  
  84.     width = features['width']  
  85.     depth = features['depth']  
  86.     label = tf.cast(features['label'], tf.int32)  
  87.     init_op = tf.initialize_all_variables()  
  88.     resultImg=[]  
  89.     resultLabel=[]  
  90.     with tf.Session() as sess:  
  91.         sess.run(init_op)  
  92.         coord = tf.train.Coordinator()  
  93.         threads = tf.train.start_queue_runners(sess=sess, coord=coord)  
  94.         for i in range(21):  
  95.             image_eval = image.eval()  
  96.             resultLabel.append(label.eval())  
  97.             image_eval_reshape = image_eval.reshape([height.eval(), width.eval(), depth.eval()])  
  98.             resultImg.append(image_eval_reshape)  
  99.             pilimg = Image.fromarray(np.asarray(image_eval_reshape))  
  100.             pilimg.show()  
  101.         coord.request_stop()  
  102.         coord.join(threads)  
  103.         sess.close()  
  104.     return resultImg,resultLabel  
  105.   
  106. def read_tfrecord(filename_queuetemp):  
  107.     filename_queue = tf.train.string_input_producer([filename_queuetemp])  
  108.     reader = tf.TFRecordReader()  
  109.     _, serialized_example = reader.read(filename_queue)  
  110.     features = tf.parse_single_example(  
  111.         serialized_example,  
  112.         features={  
  113.           'image_raw': tf.FixedLenFeature([], tf.string),  
  114.           'width': tf.FixedLenFeature([], tf.int64),  
  115.           'depth': tf.FixedLenFeature([], tf.int64),  
  116.           'label': tf.FixedLenFeature([], tf.int64)  
  117.       }  
  118.     )  
  119.     image = tf.decode_raw(features['image_raw'], tf.uint8)  
  120.     # image  
  121.     tf.reshape(image, [2562563])  
  122.     # normalize  
  123.     image = tf.cast(image, tf.float32) * (1. /255) - 0.5  
  124.     # label  
  125.     label = tf.cast(features['label'], tf.int32)  
  126.     return image, label  
  127.   
  128. def test():  
  129.     transform2tfrecord(train_file, name , output_directory,  resize_height, resize_width) #转化函数     
  130.     img,label=disp_tfrecords(output_directory+'/'+name+'.tfrecords'#显示函数  
  131.     img,label=read_tfrecord(output_directory+'/'+name+'.tfrecords'#读取函数  
  132.     print label  
  133.   
  134. if __name__ == '__main__':  
  135.     test()  


这样就可以得到自己专属的数据集.tfrecords了  ,它可以直接用于tensorflow的数据集。


tensorflow之使用shell脚本定义自己的图片标签

深度学习tensorflow一直认为给自己的数据库图片打上分类标签是一种费力事。

所以就尝试着写一些shell脚本来进行数据的脚本操作。以下代码是我为深度学习自动标签分类的代码。

[plain]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. ############################################  
  2. #! /bin/bash -  
  3. #Author:zhaoqinghui  
  4. #Data  :2016.05.11   
  5. #============ get the file name ===========   
  6. INIT_PATH="./image";  
  7. #==========================================  
  8.   
  9. function ergodic(){  
  10.   for file in `ls $1`  
  11.   do  
  12.     if [ -d $1"/"$file ]  
  13.     then  
  14.       ergodic $1"/"$file  
  15.     else  
  16.       local path=$1"/"$file   
  17.       local name=$file        
  18.       local size=`du --max-depth=1 $path|awk '{print $1}'`   
  19.       #echo $name  $size $path   
  20.       label=`echo ${path%/*}`  
  21.       #label=$`echo ${label%/*}`  
  22.       #label=$`echo ${label#*/}`  
  23.       label=`echo ${label#*/}`  
  24.       label=`echo ${label#*/}`  
  25.       echo  $path $label >> mytrain.txt  
  26.       echo 'save succuss'  $path  
  27.     fi  
  28.   done  
  29. }  
  30. IFS=$'\n' #防止有空格时出错  
  31. ergodic $INIT_PATH  


声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/123749
推荐阅读
相关标签
  

闽ICP备14008679号