赞
踩
在自己数据集上根据指定比例,生成测试集和训练集,并写入txt文件
import os

import numpy as np

abs_str = ''  # absolute path + file name
dirname = ''  # directory that contains the source files (one sub-directory per class)

NUM_IMAGES = 5378  # total number of images in the data set
TRAIN_RATIO = 0.8  # fraction of the images flagged as training data


def get_file(dir):
    """Collect the image paths and class indices below ``dir``.

    Every sub-directory of ``dir`` is treated as one class and every regular
    file inside it as one image.

    Args:
        dir: Directory whose sub-directories each hold the images of a class.

    Returns:
        A ``(file_list, label_list)`` tuple of equal length. ``file_list``
        holds ``<class dir>/<file name>`` paths relative to ``dir`` and
        ``label_list`` holds the 0-based position of the class directory
        among the sorted entries of ``dir``.
    """
    root = os.path.abspath(dir)
    file_list = []
    label_list = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # class indices and the file order reproducible between runs.
    for index, item in enumerate(sorted(os.listdir(dir))):
        class_dir = os.path.join(root, item)
        if not os.path.isdir(class_dir):
            continue
        for image in sorted(os.listdir(class_dir)):
            if os.path.isfile(os.path.join(class_dir, image)):
                file_list.append(os.path.join(item, image))
                label_list.append(index)
    return file_list, label_list


def build_split_flags(total, train_ratio=TRAIN_RATIO):
    """Return ``total`` shuffled flags, 1 for training and 0 for testing.

    Args:
        total: Number of samples to flag.
        train_ratio: Fraction (0..1) of the samples flagged with 1.

    Returns:
        A shuffled integer numpy array with ``int(train_ratio * total)`` ones
        and zeros elsewhere.

    Raises:
        ValueError: If ``train_ratio`` is outside the range 0..1.
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError('train_ratio must be between 0 and 1')
    flags = np.zeros(total, dtype=int)
    flags[:int(train_ratio * total)] = 1
    np.random.shuffle(flags)
    return flags


def write_indexed_file(path, values):
    """Write one ``"<1-based index> <value> "`` line per entry of ``values``.

    Args:
        path: Output file; it is created or overwritten.
        values: Iterable of values to write, one per line.
    """
    with open(path, mode='w') as file_handle:
        file_handle.writelines(
            '{} {} \n'.format(number, value)
            for number, value in enumerate(values, start=1)
        )


if __name__ == "__main__":
    # The same writer can dump the lists returned by get_file(), e.g.:
    #   file_list, label_list = get_file(dirname)
    #   write_indexed_file('image_class_labels.txt', [j + 1 for j in label_list])
    #   write_indexed_file('images.txt', file_list)

    # ------------------------ train/test split V1
    # Purely random split over all samples; it does not look at the classes.
    nums = build_split_flags(NUM_IMAGES, TRAIN_RATIO)
    write_indexed_file('train_test_split.txt', nums)
后来发现一个问题:不能简单地用随机数对全部样本进行划分,因为这样可能导致某个类别完全没有取到训练数据或测试数据。因此需要遍历每一个类别文件夹,在每个类别内部按比例划分,更新后的代码如下:
import os
import random

import numpy as np

abs_str = ''  # absolute path + file name
dirname = ''  # directory that contains the source files (one sub-directory per class)


def get_file(dir):
    """Collect the image paths and class indices below ``dir``.

    Every sub-directory of ``dir`` is treated as one class and every regular
    file inside it as one image.

    Args:
        dir: Directory whose sub-directories each hold the images of a class.

    Returns:
        A ``(file_list, label_list)`` tuple of equal length. ``file_list``
        holds ``<class dir>/<file name>`` paths relative to ``dir`` and
        ``label_list`` holds the 0-based position of the class directory
        among the sorted entries of ``dir``.
    """
    root = os.path.abspath(dir)
    file_list = []
    label_list = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # class indices and the file order reproducible between runs.
    for index, item in enumerate(sorted(os.listdir(dir))):
        class_dir = os.path.join(root, item)
        if not os.path.isdir(class_dir):
            continue
        for image in sorted(os.listdir(class_dir)):
            if os.path.isfile(os.path.join(class_dir, image)):
                file_list.append(os.path.join(item, image))
                label_list.append(index)
    return file_list, label_list


def get_train_test(dir, split_rate):
    """Build a per-class train/eval flag list for the images below ``dir``.

    Class directories and their files are visited in the same sorted order
    as in ``get_file``, and only regular files are counted, so the i-th flag
    belongs to the i-th entry of ``get_file(dir)[0]``.

    Args:
        dir: Directory whose sub-directories each hold the images of a class.
        split_rate: Fraction (0..1) of each class assigned to the eval set.

    Returns:
        A list with one flag per image: 0 for eval, 1 for training.
    """
    root = os.path.abspath(dir)
    train_test_list = []
    for item in sorted(os.listdir(dir)):
        class_dir = os.path.join(root, item)
        if not os.path.isdir(class_dir):
            continue
        print('imagDir', class_dir)
        images = [
            name for name in sorted(os.listdir(class_dir))
            if os.path.isfile(os.path.join(class_dir, name))
        ]
        eval_count = int(len(images) * split_rate)
        # A small class would otherwise round down to zero eval samples,
        # which is exactly the situation the per-class split is meant to
        # avoid. Keep at least one eval image when one can be spared.
        if eval_count == 0 and split_rate > 0 and len(images) > 1:
            eval_count = 1
        # A set gives O(1) membership tests in the loop below.
        eval_names = set(random.sample(images, k=eval_count))
        train_test_list.extend(
            0 if name in eval_names else 1 for name in images
        )
    print()
    return train_test_list


def write_indexed_file(path, values):
    """Write one ``"<1-based index> <value> "`` line per entry of ``values``.

    Args:
        path: Output file; it is created or overwritten.
        values: Iterable of values to write, one per line.
    """
    with open(path, mode='w') as file_handle:
        file_handle.writelines(
            '{} {} \n'.format(number, value)
            for number, value in enumerate(values, start=1)
        )


if __name__ == "__main__":
    # The same writer can dump the lists returned by get_file(), e.g.:
    #   file_list, label_list = get_file(dirname)
    #   write_indexed_file('image_class_labels.txt', [j + 1 for j in label_list])
    #   write_indexed_file('images.txt', file_list)

    # ------------------------ train/test split V2
    # Split inside every class directory so that each class contributes to
    # both subsets (20% eval, 80% training).
    nums = get_train_test(dirname, 0.2)
    write_indexed_file('train_test_split.txt', nums)

    n = nums.count(1)
    m = nums.count(0)
    print('nums', nums)
    print('train', n)
    print('value', m)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。