当前位置:   article > 正文

手动实现打乱训练集并生成一个batch的简单方法_打乱多个batch数据

打乱多个batch数据

​ 在训练一个深度学习模型之前,我们会将数据集划分为训练集、验证集和测试集。在训练的时候,我们往往会将训练集打乱,划分成多个batch来进行训练。一般情况下,我们可以使用tf.data.Dataset或者tf.TFRecordReader()来实现。如果不使用这两个方法,我们利用numpy也可以实现这个功能。

​ 下面我将定义一个类,来简单实现这个功能:

import numpy as np

class Model(object):
    def __init__(self, X, y):
        # 测试集
		self.X_test = X[:10000]
        self.y_test = y[:10000]
        # 验证集
        self.X_val = X[10000:20000]
        self.y_val = y[10000:20000]
        # 训练集
        self.X_train = X[20000:]
        self.y_train = y[20000:]
	
    def train(self, iters, batch_size=8):
        """
        定义网络及操作
        ......
        """
        with tf.Session() as sess:
            # 初始化全局变量
            sess.run(init_op)
            # 开始训练
            for epoch in range(iters):	# 迭代轮数
                # 打乱数据
                permutation = np.random.permutation(len(self.y_train))
                # 遍历所有的batch
                loops = len(self.y_train) // batch_size
                for i in range(loops-1):
                    sess.run(train_op, feed_dict={x: self.X_train[permutation[i*batch_size: (i+1)*batch_size]],
                                                 y: self.y_train[permutation[i*batch_size: (i+1)*batch_size]]})
        
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

在初始化阶段,我将数据集进行了划分, 然后我利用numpy中的随机函数中的permutation函数来实现打乱数据集。这个函数的功能是生成由0到n之间整数的随机排列:

import numpy as np

permuation = numpy.random.permutation(10)
print(permutation)

"""
[out]: array([1, 0, 4, 9, 3, 5, 7, 8, 6, 2])
"""
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

通过这种方式,其实是将数据集的index打乱,然后通过这些打乱后的index去训练集中取一个batch的数据,直到取完所有数据。

​ 当然,你也可以先全部打乱所有数据,如:

self.X_train = self.X_train[permutation]
self.y_train = self.y_train[permutation]
  • 1
  • 2

但是不推荐这种办法,因为如果数据量很大的时候,这种操作会造成内存不足的问题。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/521368
推荐阅读
相关标签
  

闽ICP备14008679号