赞
踩
在训练一个深度学习模型之前,我们会将数据集划分为训练集、验证集和测试集。在训练的时候,我们往往会将训练集打乱,划分成多个batch来进行训练。一般情况下,我们可以使用tf.data.Dataset
或者tf.TFRecordReader()
来实现。如果不使用这两个方法,我们利用numpy
也可以实现这个功能。
下面我将定义一个类,来简单实现这个功能:
import numpy as np class Model(object): def __init__(self, X, y): # 测试集 self.X_test = X[:10000] self.y_test = y[:10000] # 验证集 self.X_val = X[10000:20000] self.y_val = y[10000:20000] # 训练集 self.X_train = X[20000:] self.y_train = y[20000:] def train(self, iters, batch_size=8): """ 定义网络及操作 ...... """ with tf.Session() as sess: # 初始化全局变量 sess.run(init_op) # 开始训练 for epoch in range(iters): # 迭代轮数 # 打乱数据 permutation = np.random.permutation(len(self.y_train)) # 遍历所有的batch loops = len(self.y_train) // batch_size for i in range(loops-1): sess.run(train_op, feed_dict={x: self.X_train[permutation[i*batch_size: (i+1)*batch_size]], y: self.y_train[permutation[i*batch_size: (i+1)*batch_size]]})
在初始化阶段,我将数据集进行了划分, 然后我利用numpy中的随机函数中的permutation函数来实现打乱数据集。这个函数的功能是生成由0到n之间整数的随机排列:
import numpy as np
permuation = numpy.random.permutation(10)
print(permutation)
"""
[out]: array([1, 0, 4, 9, 3, 5, 7, 8, 6, 2])
"""
通过这种方式,其实是将数据集的index打乱,然后通过这些打乱后的index去训练集中取一个batch的数据,直到取完所有数据。
当然,你也可以先全部打乱所有数据,如:
self.X_train = self.X_train[permutation]
self.y_train = self.y_train[permutation]
但是不推荐这种办法,因为如果数据量很大的时候,这种操作会造成内存不足的问题。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。