当前位置:   article > 正文

tensorflow学习笔记——获取训练数据集和测试数据集_如何赵某个物体的训练数据集

如何赵某个物体的训练数据集

训练神经网络模型之前,需要先获取训练数据集和测试数据集,本文介绍的获取数据集(get_data_train_test)的方法包括以下步骤:
1 在数据集文件夹中,不同类别图像分别放在以各自类别名称命名的文件夹中;
2 获取所有图像路径以及分类;
3 将分类转为字典格式;
4 将所有图像路径打乱;
5 将所有图像路径切分为训练部分和测试部分;
6 获取x部分
6.1 获取图像;
6.2 图像尺寸调整;
6.3 图像降维;
6.4 图像像素值取反;
6.5 图像像素值归一化;
7 获取y部分
7.1 获取图像的类别名称;
7.2 找到类别名称对应的id;
7.3 列表推到;

import os
import random
import math
import sys
import cv2
import numpy as np
from PIL import Image

#数据集路径
DATASET_TRAIN_TEST_DIR = 'D:/word/data_train_test'
DATASET_TEST_DIR = 'D:/word/data_test'
#随机种子
RANDOM_SEED = 0
#验证集数量
NUM_TEST = 20
#分类数量
NUM_CLASS = 10

#获取所有文件以及分类
def get_filenames_and_classes(dataset_dir):
	#数据目录
	directories = []
	#分类名称
	class_names = []
	for filename in os.listdir(dataset_dir):
		#合并文件路径
		path = os.path.join(dataset_dir, filename)
		#判断该路径是否为目录
		if os.path.isdir(path):
			#加入数据目录
			directories.append(path)
			#加入类别名称
			class_names.append(filename)

	photo_filenames = []
	#循环每个分类的文件夹
	for directory in directories:
		for filename in os.listdir(directory):
			path = os.path.join(directory, filename)
			#把图片加入图片列表
			photo_filenames.append(path)

	return photo_filenames, class_names

def get_xs(filenames):
	xs = []
	for i in range(len(filenames)):
		image = Image.open(filenames[i]).convert('L')
		blank = Image.new('L',[28,28],(255))
		max_length = np.max(image.size)
		w = int(image.size[0]*28/max_length)
		h = int(image.size[1]*28/max_length)
		#图像尺寸不超过28*28
		image = image.resize((w,h), Image.NEAREST)
		#图像尺寸调整为28*28
		blank.paste(image, ((28-w)//2, (28-h)//2))
		#图像尺寸调整为1*784
		x = blank.resize((1,784))
		#图像转换为数组
		x = np.array(x)
		#图像降维,如[[1],[2],[3]]变为[1,2,3]
		x = x.squeeze()
		#图像像素值取反
		x = np.full(784, 255) - x
		#图像像素值归一化
		max = np.max(x)
		x = x / np.full(784, max)
		#获取多幅图像数据
		xs.append(x)
	return xs

def get_ys(filenames, class_names_to_ids):
	ys = []
	for i in range(len(filenames)):
		#获得图片的类别名称
		class_name = os.path.basename(os.path.dirname(filenames[i]))
		#找到类别名称对应的id
		class_id = class_names_to_ids[class_name]
		#列表推到
		y=[1 if id==class_id else 0 for id in range(NUM_CLASS)]
		ys.append(y)
	return ys

def get_data_train_test():
	#获得所有图片路径以及分类
	photo_filenames, class_names = get_filenames_and_classes(DATASET_TRAIN_TEST_DIR)

	#把分类转为字典格式,类似于{'A':0, 'B':1, 'C':2}
	class_names_to_ids = dict(zip(class_names, range(len(class_names))))

	#把数据切分为训练集和测试集
	random.seed(RANDOM_SEED)
	random.shuffle(photo_filenames)
	training_filenames = photo_filenames[NUM_TEST:]
	testing_filenames = photo_filenames[:NUM_TEST]
	train_xs = get_xs(training_filenames)
	train_ys = get_ys(training_filenames, class_names_to_ids)
	test_xs = get_xs(testing_filenames)
	test_ys = get_ys(testing_filenames, class_names_to_ids)

	return train_xs, train_ys, test_xs, test_ys

def get_data_test():
	filenames = []
	for filename in os.listdir(DATASET_TEST_DIR):
		#合并文件路径
		path = os.path.join(DATASET_TEST_DIR, filename)
		filenames.append(path)
	xs = get_xs(filenames)
	return xs

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/123923
推荐阅读
相关标签
  

闽ICP备14008679号