当前位置:   article > 正文

使用tensorflow训练自己的数据集(一)——制作数据集

tensorflow训练自己的数据集

使用tensorflow训练自己的数据集—制作数据集

想记录一下自己制作训练集并训练的过、希望踩过的坑能帮助后面入坑的人。
本次使用的训练集的是kaggle中经典的猫狗大战数据集(提取码:ufz5)。因为本人笔记本配置很差还不是N卡所以把train的数据分成了训练集和测试集并没有使用原数据集中的test。在tensorflow中使用TFRecord格式喂给神经网络但是现在官方推荐使用tf.data但这个API还没看所以还是使用了TFRecord。
文件目录
代码注释还挺清楚就直接上代码了。

import os
import tensorflow as tf
from PIL import Image
# 源数据地址
cwd = 'C:/Users/Qigq/Desktop/P_Data/kaggle/train'
# 生成record路径及文件名
train_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/train.tfrecords"
test_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/test.tfrecords"
# 分类
classes = {'cat','dog'}

def _byteslist(value):
    """二进制属性"""
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))

def _int64list(value):
    """整数属性"""
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))

def create_train_record():
    """创建训练集tfrecord"""
    writer = tf.python_io.TFRecordWriter(train_record_path)     # 创建一个writer
    NUM = 1                                     # 显示创建过程(计数)
    for index, name in enumerate(classes):
        class_path = cwd + "/" + name + '/'
        l = int(len(os.listdir(class_path)) * 0.7)      # 取前70%创建训练集
        for img_name in os.listdir(class_path)[:l]:
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((128, 128))                # resize图片大小
            img_raw = img.tobytes()                     # 将图片转化为原生bytes
            example = tf.train.Example(                 # 封装到Example中
                features=tf.train.Features(feature={
                    "label":_int64list(index),          # label必须为整数类型属性
                    'img_raw':_byteslist(img_raw)       # 图片必须为二进制属性
                }))
            writer.write(example.SerializeToString())
            print('Creating train record in ',NUM)
            NUM += 1
    writer.close()                                      # 关闭writer
    print("Create train_record successful!")

def create_test_record():
    """创建测试tfrecord"""
    writer = tf.python_io.TFRecordWriter(test_record_path)
    NUM = 1
    for index, name in enumerate(classes):
        class_path = cwd + '/' + name + '/'
        l = int(len(os.listdir(class_path)) * 0.7)
        for img_name in os.listdir(class_path)[l:]:     # 剩余30%作为测试集
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((128, 128))
            img_raw = img.tobytes()  # 将图片转化为原生bytes
            # print(index,img_raw)
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    "label":_int64list(index),
                    'img_raw':_byteslist(img_raw)
                }))
            writer.write(example.SerializeToString())
            print('Creating test record in ',NUM)
            NUM += 1
    writer.close()
    print("Create test_record successful!")

def read_record(filename):
    """读取tfrecord"""
    filename_queue = tf.train.string_input_producer([filename])     # 创建文件队列
    reader = tf.TFRecordReader()                                    # 创建reader
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)
        }
    )
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [128, 128, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5       # 归一化
    label = tf.cast(label, tf.int32)
    return img, label

def get_batch_record(filename,batch_size):
    """获取batch"""
    image,label = read_record(filename)
    image_batch,label_batch = tf.train.shuffle_batch([image,label],         # 随机抽取batch size个image、label
                                                     batch_size=batch_size,
                                                     capacity=2000,
                                                     min_after_dequeue=1000)
    return image_batch,label_batch

def main():
    create_train_record()
    create_test_record()
if __name__ == '__main__':
    main()

                                ### 调用示例 ###
# create_train_record(cwd,classes)
# create_test_record(cwd,classes)
# image_batch,label_batch = get_batch_record(filename,32)
# init = tf.initialize_all_variables()
#
# with tf.Session() as sess:
#     sess.run(init)
#
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess=sess,coord=coord)
#
#     for i in range(1):
#         image,label = sess.run([image_batch,label_batch])
#         print(image.shape,1)
#
#
#     coord.request_stop()
#     coord.join(threads)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120

下一篇将介绍定义神经网络
如有错误望多多指教~~

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/123721
推荐阅读
相关标签
  

闽ICP备14008679号