当前位置:   article > 正文

实例7:将图片文件制作成Dataset数据集_sklearn将图片转成数据集

sklearn将图片转成数据集

实例7:将图片文件制作成Dataset数据集

在图片训练过程中,一个变形丰富的数据集会使模型的精度与泛化性成倍提升

1. 代码实现:读取样本文件的目录与标签

定义load_sample函数,用来将样本图片的目录名称与对应的标签读入内存。

import os
import tensorflow as tf
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.utils import shuffle

def load_sample(sample_dir, shuffle):
    print("loading sample dataset...")
    lfilenames = []
    labelsnames = []
    for (dirpath, dirnames, filenames) in os.walk(sample_dir):
        for filename in filenames :
            filename_path = os.sep.join([dirpath, filename])
            lfilenames.append(filename_path)
            labelsnames.append(dirpath.split('\\')[-1])
    lab = list(sorted(set(labelsnames)))
    labdict = dict(zip(lab, list(range(len(lab)))))
    labels = [labdict[i] for i in labelsnames]
    if shuffle == True:
        return shuffle(np.asarray(lfilenames), np.asarray(labels)), np.asarray(lab)
    else:
        return np.asarray(lfilenames), np.asarray(labels), np.asarray(lab)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

2. 代码实现:定义函数,实现函数转换操作

定义函数_distorted_image,用TensorFlow自带的API实现单一图片的变换处理

def _distorted_image(image, size, ch=1, shuffleflag=False, cropflag=False, brightnessflag=False, contrastflag=False):
    distorted_image =tf.image.random_flip_left_right(image)

    if cropflag == True:                                                #随机裁剪
        s = tf.random_uniform((1,2),int(size[0]*0.8),size[0],tf.int32)
        distorted_image = tf.random_crop(distorted_image, [s[0][0],s[0][0],ch])

    distorted_image = tf.image.random_flip_up_down(distorted_image)#上下随机翻转
    if brightnessflag == True:#随机变化亮度
        distorted_image = tf.image.random_brightness(distorted_image,max_delta=10)
    if contrastflag == True:   #随机变化对比度
        distorted_image = tf.image.random_contrast(distorted_image,lower=0.2, upper=1.8)
    if shuffleflag==True:
        distorted_image = tf.random_shuffle(distorted_image)#沿着第0维乱序
    return distorted_image
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

3. 代码实现:用自定义函数实现图片归一化

def _norm_image(image,size,ch=1,flattenflag = False):    #定义函数,实现归一化,并且拍平
    image_decoded = image/255.0
    if flattenflag==True:
        image_decoded = tf.reshape(image_decoded, [size[0]*size[1]*ch])
    return image_decoded
  • 1
  • 2
  • 3
  • 4
  • 5

本实例将图片的值域变成0~1之间的小数,实际开发中,也可以将图片的值域变成-1-1之间的小数

4. 代码实现:用第三方函数将图片旋转30度

定义函数random_rotated30实现图片旋转功能,用skimage库函数将图片旋转30度
在整个数据集的处理流程中,对图片的操作丢失基于张量进行的,所以第三方函数无法操作TensorFlow中张量,所以需要额外的封装
用tf.py_function函数可以将第三方 库函数成一个TensorFlow的中操作符(op)

from skimage import transform
def _random_rotated30(image, label): #定义函数实现图片随机旋转操作
    
    def _rotated(image):                #封装好的skimage模块,来进行图片旋转30度
        shift_y, shift_x = np.array(image.shape.as_list()[:2],np.float32) / 2.
        tf_rotate = transform.SimilarityTransform(rotation=np.deg2rad(30))
        tf_shift = transform.SimilarityTransform(translation=[-shift_x, -shift_y])
        tf_shift_inv,image.size = transform.SimilarityTransform(translation=[shift_x, shift_y]),image.shape#兼容transform函数
        image_rotated = transform.warp(image, (tf_shift + (tf_rotate + tf_shift_inv)).inverse)
        return image_rotated

    def _rotatedwrap():
        image_rotated = tf.py_function( _rotated,[image],[tf.float64])   #调用第三方函数
        return tf.cast(image_rotated,tf.float32)[0]

    a = tf.random_uniform([1],0,2,tf.int32)#实现随机功能
    image_decoded = tf.cond(tf.equal(tf.constant(0),a[0]),lambda: image,_rotatedwrap)

    return image_decoded, label
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

使用TensorFlow中的tf.cond方法,用来根据随机条件判断是否需要对本次图片进行旋转

5. 代码实现:定义函数,生成Dataset对象

咋dataset函数转给你,用内置函数_parseone将所有文件名转化为具体的图片内容,并返回Dataset队形

def dataset(directory,size,batchsize,random_rotated=False):#定义函数,创建数据集
    (filenames,labels),_ =load_sample(directory,shuffleflag=False) #载入文件名称与标签
    def _parseone(filename, label):                         #解析一个图片文件
        image_string = tf.read_file(filename)         #读取整个文件
        image_decoded = tf.image.decode_image(image_string)
        image_decoded.set_shape([None, None, None])    # 必须有这句,不然下面会转化失败
        image_decoded = _distorted_image(image_decoded,size)#对图片做扭曲变化
        image_decoded = tf.image.resize(image_decoded, size)  #变化尺寸
        image_decoded = _norm_image(image_decoded,size)#归一化
        image_decoded = tf.cast(image_decoded,dtype=tf.float32)
        label = tf.cast(  tf.reshape(label, []) ,dtype=tf.int32  )#将label 转为张量
        return image_decoded, label

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))#生成Dataset对象
    dataset = dataset.map(_parseone)   #有图片内容的数据集

    if random_rotated == True:
        dataset = dataset.map(_random_rotated30)

    dataset = dataset.batch(batchsize) #批次划分数据集

    return dataset
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

5. 代码实现:建立会话,输出数据

def showresult(subplot,title,thisimg):          #显示单个图片
    p =plt.subplot(subplot)
    p.axis('off')
    p.imshow(thisimg)
    p.set_title(title)

def showimg(index,label,img,ntop):   #显示
    plt.figure(figsize=(20,10))     #定义显示图片的宽、高
    plt.axis('off')
    ntop = min(ntop,9)
    print(index)
    for i in range (ntop):
        showresult(100+10*ntop+1+i,label[i],img[i])
    plt.show()

def getone(dataset):
    iterator = dataset.make_one_shot_iterator()			#生成一个迭代器
    one_element = iterator.get_next()					#从iterator里取出一个元素
    return one_element

sample_dir="man_woman"
size = [96,96]
batchsize = 10
tdataset = dataset(sample_dir,size,batchsize)
tdataset2 = dataset(sample_dir,size,batchsize,True)
print(tdataset.output_types)  #打印数据集的输出信息
print(tdataset.output_shapes)

one_element1 = getone(tdataset)				#从tdataset里取出一个元素
one_element2 = getone(tdataset2)				#从tdataset2里取出一个元素


with tf.Session() as sess:	# 建立会话(session)
    sess.run(tf.global_variables_initializer())  #初始化

    try:
        for step in np.arange(1):
            value = sess.run(one_element1)
            value2 = sess.run(one_element2)

            showimg(step,value[1],np.asarray( value[0]*255,np.uint8),10)       #显示图片
            #showimg(step,value2[1],np.asarray( value2[0]*255,np.uint8),10)       #显示图片


    except tf.errors.OutOfRangeError:           #捕获异常
        print("Done!!!")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/372565
推荐阅读
相关标签
  

闽ICP备14008679号