当前位置:   article > 正文

tensorflow读取自己的数据集_tensorflow 读取自己的数据集

tensorflow 读取自己的数据集
注意:
  • 我自己是用来训练医学图像生成新的医学图像的,所以我的label都是1(或者其他数值),因为在GAN训练的时候完全用不到,只有分类的问题才能用到。
  • 如果训练分类问题下文中的代码就没用了,不过可以参考这个链接
  • 我的图片是黑白的,所以channel为1,并且图片保存的格式为png的,所以大家理解的时候请注意image=tf.image.decode_png(image_contents,channels=1)
  • 代码的输入是png图片,通道数为1,输出为batch_size的tensor(shape=[batch_size, image_W, image_H, 1]),但是绝对不能用到feed_dict里面,因为feed_dict只接收numpy格式的数据格式,所以在接下来的博客中,本博主要在上面下功夫!
  • tensorflow中 tf.train.slice_input_producer函数可以参考这个链接,和这个链接讲的非常详细。
  • 初始化工作相当重要
输出其中一个样例:

这里写图片描述

# -*- coding: utf-8 -*-
"""
Created on Thu Jul 19 15:40:11 2018
E-mail: Eric2014_Lv@sjtu.edu.cn
@author: DidiLv
"""



import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt

#训练样本在本地磁盘中的地址
file_dir='D:\\CT_data\\Data_preprocessing\\' # 这里是输入数据的地址


def get_files(file_dir):
    lung_img = [];
    label_lung_img = [];
    for file in os.listdir(file_dir):
        lung_img.append( file_dir + file)
        label_lung_img.append(1)

    image_list = np.hstack((lung_img))

    label_list = np.hstack((label_lung_img))

    temp = np.array([lung_img, label_lung_img]).T
    #利用shuffle打乱数据
    np.random.shuffle(temp)
    image_list = list(temp[:,0])
    label_list = list(temp[:,1])
    label_list = [int(i) for i in label_list]
    return image_list, label_list
#       
#将上面生成的List传入get_batch() ,转换类型,产生一个输入队列queue,因为img和lab  
#是分开的,所以使用tf.train.slice_input_producer(),然后用tf.read_file()从队列中读取图像
def get_batch(image,label,batch_size):

    image_W, image_H = 221, 181 

    #将python.list类型转换成tf能够识别的格式
    image=tf.cast(image,tf.string)
    label=tf.cast(label,tf.int32)

    #产生一个输入队列queue
    epoch_num = 50 #防止无限循环
    input_queue=tf.train.slice_input_producer([image,label], num_epochs=epoch_num)

    label=input_queue[1]
    image_contents=tf.read_file(input_queue[0])
    #将图像解码,不同类型的图像不能混在一起,要么只用jpeg,要么只用png等。  
    image=tf.image.decode_png(image_contents,channels=1)

    #将数据预处理,对图像进行旋转、缩放、裁剪、归一化等操作,让计算出的模型更健壮。
    image=tf.image.resize_image_with_crop_or_pad(image,image_W,image_H)
    image=tf.image.per_image_standardization(image)

    #生成batch
    min_after_dequeue=1000
    capacity=min_after_dequeue+3*batch_size
    image_batch,label_batch=tf.train.shuffle_batch([image,label],batch_size=batch_size,num_threads=64,capacity=capacity,min_after_dequeue=min_after_dequeue)

    #重新排列标签,行数为[batch_size]
#    label_batch=tf.reshape(label_batch,[batch_size])
    image_batch = tf.reshape(image_batch,[batch_size,image_W,image_H,1])
    image_batch=tf.cast(image_batch,np.float32)

    return image_batch, label_batch

if __name__ == "__main__":
    image_list, label_list = get_files(file_dir)
    image_batch, label_batch = get_batch(image_list, label_list, 64)
    with tf.Session() as sess:
        ## 初始化工作,相当重要
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        i = 0
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            while not coord.should_stop() and i<1: # 加上i其实是强制终止线程,但是文件队列线程并没有结束,因为我们产生的文件队列结束为epoch_num个epoch

                img, label = sess.run([image_batch, label_batch])

                # just test one batch
                for j in np.arange(64):
                    print('label: %d' %label[j])
                    plt.imshow(img[j,:,:,0])
                    plt.show()
                i+=1

        except tf.errors.OutOfRangeError:
            print('done!')
        finally:
            coord.request_stop()
            print('-----------')
        coord.join(threads)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/123859
推荐阅读
相关标签
  

闽ICP备14008679号