当前位置:   article > 正文

Tensorflow框架 —— 训练数据读取的几种方式_训练数据 流式读取

训练数据 流式读取

1. 概述

深度学习训练的数据往往非常巨大,如何保证高效的加载数据,也是提升训练速度的关键。本文主要介绍如下几种训练数据加载方式:

  • 迭代器(iter() and next())
  • 多线程+队列
  • TFRecord

2. 迭代器

class IterTest():
    def __init__(self, data=1):
        self.data = data

    def __iter__(self):
 		# 表明 IterTest是可迭代类
        return self
        
    def __next__(self):
    	# 迭代器的具体实现
        if self.data > 5:
            raise StopIteration
        else:
            self.data += 1
            return self.data

for epoch in range(4):
    s = IterTest(3)
    for item in s:
        print(item)
# output:
	4
	5
	6
	==
	4
	5
	6
	==
	4
	5
	6
	==
	4
	5
	6
	==
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

for … in …: 该循环实现两件事,第一件事是获得一个可迭代器,即调用了__iter__()函数;第二件事是循环的过程,循环调用__next__()函数。
对于 IterTest 类来说,它定义了__iter__和__next__函数,所以是一个可迭代的类,也可以说是一个可迭代的对象(Python中一切皆对象)。

参考链接:https://blog.csdn.net/liweibin1994/article/details/77374854

3. 多线程+队列

通常神经网络的训练过程中,CPU由于加载数据和数据的预处理,GPU通常用于梯度计算和参数更新。鉴于CPU和GPU处理数据上的速度差别很大。GPU可以快速处理每一批次(batch)的数据,但是CPU的速度往往是无法与GPU的速度匹配,导致单位时间内GPU处于空闲状态。为了使的GPU处于最大利用状态,考虑使用多线程+队列的方式提升CPU处理数据的时间。

基本的处理流程如下:

在这里插入图片描述

代码实现

  数据处理类,用于CPU读取数据,数据增强等操作。

import random
import numpy as np
import cv2
from PIL import Image, ImageEnhance
from cv_rotation import *


class DataLoader:
    def __init__(self, file):
        # read text file: save train name list
        self.name_list = []

        data = open(file, 'r')
        for line in data:
            line = line.strip()
            self.name_list.append(line)
        random.shuffle(self.name_list)
	
	# 读取名字列表的具体实现,即target
    def name_queue_(self, name_queue):
        count = 0
        random.shuffle(self.name_list)
        while True:
            if count >= len(self.name_list):
                count = 0
                random.shuffle(self.name_list)
                continue

            name_queue.put(self.name_list[count])
            # print(self.name_list[count])
            count = count + 1
            # if name_queue.full():
            #     print('队列满')
            #     print('count: ', count)
	
	# 数据颜色类特征增强
    def image_enhance(self, img):
        p = random.randint(1, 3)
        a1 = random.uniform(0.8, 2)
        a2 = random.uniform(0.8, 1.4)
        a3 = random.uniform(0.8, 1.7)
        a4 = random.uniform(0.8, 2.5)
        img = Image.fromarray(img)

        img = ImageEnhance.Color(img).enhance(a1) if p == 0 else img
        img = ImageEnhance.Brightness(img).enhance(a2) if p == 1 else img
        img = ImageEnhance.Contrast(img).enhance(a3) if p == 2 else img
        img = ImageEnhance.Sharpness(img).enhance(a4) if p == 3 else img
        img = np.array(img)

        return img
	
	# 图像翻转
    def flip_img(self, img):
        flipped = (np.random.random() < 0.5)

        if flipped:
            img = img[:, ::-1, :]

        return img

    @staticmethod
    def show_image(name, data):
        cv2.imshow(name, data)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
	
	# 图像随机旋转
    def pose_rotation(self, img):
        w, h, c = img.shape
        deg = random.uniform(-15.0, 15.0)
        M_rotate = affine_rotation_matrix(angle=deg)
        transform_matrix = transform_matrix_offset_center(M_rotate, x=w, y=h)

        img_result = affine_transform_cv2(img, transform_matrix)

        return img_result
	
	# 读取数据,每次存储一个batch,用于直接传入GPU训练
    def load_data(self, human_data, batch, queue, thread, name_queue):
        image = []
        label = []
        data_name = []
        thread_name = []
		
		# 无限循环读取数据
        while 1:
        	# 从名字列表队列中取数据,数据的名字
            data = name_queue.get()
            # print('data: ', data)
            d1 = data.split(' ')
			
			# 根据获取的名字,opencv读取图像,标签
			# 数据增强
            if len(d1) == 2:
                data_image = '../roc_0716/' + d1[0] + '/correction.roc_0.bmp'
                if float(d1[-1]) > 10:
                    data_label = 1
                # elif float(d1[-1]) > 12:
                #     continue
                else:
                    data_label = 0
            else:
                ss = ' '.join(d1[:-1])
                # print(ss)
                data_image = '../roc_0716/' + ss + '/correction.roc_0.bmp'

                if float(d1[-1]) > 10:
                    data_label = 1
                # elif float(d1[-1]) > 12:
                #     continue
                else:
                    data_label = 0

            img = cv2.imread(data_image)
            # human_data.show_image('ori image', img)

            # 数据增强
            img = human_data.image_enhance(img)
            # human_data.show_image('enhance', img)
            img = human_data.flip_img(img)
            # human_data.show_image('flip', img)

            img = human_data.pose_rotation(img)
            img = cv2.resize(img, (320, 480))
            # human_data.show_image('resize', img)
			
			# 归一化,数据处理的常见操作
            img = img.astype(np.float32)
            img = (img - np.mean(img, axis=(0, 1))) / (np.std(img, axis=(0, 1)) + 1e-8)
			
			# 将处理完的数据放入列表
            data_name.append(data_image)
            # thread_name.append(thread)
            image.append(img)
            label.append(data_label)
			
			# 每次读满一个Batch才会存储数据到队列
            if len(image) != batch:
                continue
			
            queue.put([data_name, thread_name, np.array(image), np.array(label)])
            # print('name+++: ', data_image)
            # print('thread++: ', thread)

            image = []
            label = []
            data_name = []
            thread_name = []

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150

  下面的代码片段为多线程的初始化,以及分配队列长度等。线程具体分配多少?这个需要参考电脑的线程的数量,不能占用全部的线程。队列的长度影响数据的预存储数量,也间接决定GPU的效率。所以,队列的长度除了考虑电脑的内存大小,还要考虑GPU的使用情况。那么,队列长度和线程数量分配的标准是保证GPU满负荷状态训练。

def main():
    tf.set_random_seed(-1)

    # ****************************************************************** #
    #                   1. Python多线程数据读取与数据增强                   #
    # ****************************************************************** #
    train_file = "./data/train.txt"  # 保存训练集的名字列表
    human_data_train = DataLoader(train_file)
    print("num of train data: ", len(human_data_train.name_list))

    # 单线程读取,存入队列,读取训练集名字
    train_name_queue = Queue(cfg.Train.Train_Num)  # len(human_data_train.name_list)
    name_process = Process(target=human_data_train.name_queue_, args=(train_name_queue, ))
    name_process.start()

    # # create queue and read train data
    cache_train_data = 300
    train_thread_num = 3
	
	# 初始化训练集存储队列
    q = Queue(cache_train_data)
    for thread in range(train_thread_num):
    	# target: 队列读取数据的方法或者实现
    	# args: 方法或者实现的参数
        p_train = Process(target=human_data_train.load_data,
                          args=(human_data_train, cfg.Train.Batch_Size, q, thread, train_name_queue))
        p_train.start()

    # load valid data
    valid_file = "./data/valid.txt"  # 保存验证集的名字列表
    human_data_valid = DataLoader(valid_file)
    print("num of valid data: ", len(human_data_valid.name_list))

    # 单线程读取数据并存入队列,读取验证集名字
    valid_name_queue = Queue(cfg.Train.Valid_Num)  # len(human_data_valid.name_list)
    valid_name_process = Process(target=human_data_valid.name_queue_, args=(valid_name_queue,))
    valid_name_process.start()

    # create queue and read valid data
    cache_valid_data = 50
    valid_thread_num = 1
	
	# 初始化存储验证集数据的队列
    valid_queue = Queue(cache_valid_data)
    for thread in range(valid_thread_num):
        p_valid = Process(target=human_data_train.load_data,
                          args=(human_data_valid, cfg.Train.Batch_Size, valid_queue, thread, valid_name_queue))
        p_valid.start()
	
	# 模型加载
	model = Model()
	... ...
	
	# 开始训练
	with tf.Session() as sess:
		# train process
		 _, _, image, label = q.get()
		sess.run([train_op], feed_dict={input:image, label_c:label})
		
		# valid process
		_, _, valid_image, valid_label = valid_queue.get()
		sess.run([train_op], feed_dict={valid_image, label_c:valid_label })
		... ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/123933
推荐阅读
相关标签
  

闽ICP备14008679号