当前位置:   article > 正文

自定义数据集的训练/测试划分 txt 文件生成

生成训练测试txt

在自定义数据集上按指定比例划分训练集和测试集,并将划分结果写入 txt 文件。文件每行格式为"图片序号 标记",其中 1 表示训练集,0 表示测试集。

import os
import numpy as np

# Placeholders to fill in before running the script:
#   abs_str -- absolute path + file name (not referenced by the code shown here)
#   dirname -- directory containing the source images, one sub-folder per class
abs_str, dirname = '', ''

import numpy as np


def get_file(dir):
    """Collect image paths and integer class labels from a class-per-folder tree.

    Each immediate subdirectory of ``dir`` is one class. Each regular file
    directly inside a class folder is one sample. Non-directory entries at
    the top level are ignored, and so are sub-folders inside a class folder.

    Args:
        dir: root directory of the dataset.

    Returns:
        ``(file_list, label_list)``.
        ``file_list`` holds paths relative to ``dir``, built as
        ``os.path.join(class_dir, file)``.
        ``label_list[i]`` is the 0-based class index of ``file_list[i]``.
        Class indices are contiguous because only class directories are
        counted. Ordering follows ``os.listdir``, which is arbitrary and
        filesystem-dependent.
    """
    file_list = []
    label_list = []
    root = os.path.abspath(dir)
    # Enumerate directories only: enumerating the raw listing would let stray
    # top-level files consume label numbers and leave gaps in the class ids.
    class_dirs = [item for item in os.listdir(root)
                  if os.path.isdir(os.path.join(root, item))]
    for label, item in enumerate(class_dirs):
        class_dir = os.path.join(root, item)
        for image in os.listdir(class_dir):
            if os.path.isfile(os.path.join(class_dir, image)):
                file_list.append(os.path.join(item, image))
                label_list.append(label)
    return file_list, label_list

def build_split_flags(num_images, train_rate=0.8, seed=None):
    """Return a shuffled 0/1 array marking each image as train (1) or test (0).

    Args:
        num_images: total number of images in the dataset.
        train_rate: fraction of images flagged as training; the count is
            truncated with ``int()``.
        seed: optional seed for a reproducible shuffle. ``None`` draws fresh
            entropy from the OS.

    Returns:
        ``numpy.ndarray`` of ints with length ``num_images``, containing
        exactly ``int(train_rate * num_images)`` ones.
    """
    flags = np.zeros(num_images, dtype=int)
    # Mark the first train_rate fraction as training (1) before shuffling.
    # The rest stay 0 (test), matching the 1 = train convention used when
    # counting the V2 split.
    flags[:int(train_rate * num_images)] = 1
    np.random.default_rng(seed).shuffle(flags)
    return flags


if __name__ == "__main__":
    # V1: split the whole dataset at random, ignoring class membership.
    # NOTE: this can leave a class with no train or no test image; the
    # per-class version further below addresses that.
    # image_class_labels.txt / images.txt can be written the same way from
    # get_file(dirname).
    # 5378 is presumably the dataset's image count -- adjust for your data.
    flags = build_split_flags(5378, train_rate=0.8)
    with open('train_test_split.txt', mode='w') as file_handle:
        # One line per image: "<1-based image id> <flag>", no trailing space.
        for i, flag in enumerate(flags, start=1):
            file_handle.write('{} {}\n'.format(i, flag))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43

上面的做法有一个问题:对全部图片统一随机划分,可能导致某个类别中完全没有训练数据或测试数据。因此需要遍历每个类别文件夹,在每个类别内部按比例划分。更新后的代码如下(注意:int(num * split_rate) 会向下取整,所以图片数少于 1/split_rate 的类别,例如比例为 0.2 时少于 5 张的类别,仍然不会分到测试数据):

import os
import numpy as np
import random
# Placeholders to fill in before running the script:
#   abs_str -- absolute path + file name (not referenced by the code shown here)
#   dirname -- dataset root with one sub-folder per class; left as '' the
#              os.listdir call in get_train_test will fail.
abs_str, dirname = '', ''


def get_file(dir):
    """Collect image paths and integer class labels from a class-per-folder tree.

    Each immediate subdirectory of ``dir`` is one class. Each regular file
    directly inside a class folder is one sample. Non-directory entries at
    the top level are ignored, and so are sub-folders inside a class folder.

    Args:
        dir: root directory of the dataset.

    Returns:
        ``(file_list, label_list)``.
        ``file_list`` holds paths relative to ``dir``, built as
        ``os.path.join(class_dir, file)``.
        ``label_list[i]`` is the 0-based class index of ``file_list[i]``.
        Class indices are contiguous because only class directories are
        counted. Ordering follows ``os.listdir``, which is arbitrary and
        filesystem-dependent.
    """
    file_list = []
    label_list = []
    root = os.path.abspath(dir)
    # Enumerate directories only: enumerating the raw listing would let stray
    # top-level files consume label numbers and leave gaps in the class ids.
    class_dirs = [item for item in os.listdir(root)
                  if os.path.isdir(os.path.join(root, item))]
    for label, item in enumerate(class_dirs):
        class_dir = os.path.join(root, item)
        for image in os.listdir(class_dir):
            if os.path.isfile(os.path.join(class_dir, image)):
                file_list.append(os.path.join(item, image))
                label_list.append(label)
    return file_list, label_list

def get_train_test(dir, split_rate):
    """Build per-image train/test flags, sampling inside each class folder.

    Splitting each class separately, instead of the whole dataset at once,
    avoids a class ending up entirely in one subset. The exception is a class
    so small that ``int(num * split_rate)`` is 0, which gets no test image.

    Args:
        dir: dataset root; each immediate subdirectory is one class.
        split_rate: fraction of each class's images assigned to the test set.

    Returns:
        A list with one entry per regular image file, where ``0`` = test and
        ``1`` = train. Folders and files are visited in ``os.listdir`` order,
        counting regular files only, exactly as ``get_file`` visits them. The
        result therefore lines up index-by-index with ``get_file``'s output.
    """
    train_test_list = []
    root = os.path.abspath(dir)
    for item in os.listdir(root):
        class_dir = os.path.join(root, item)
        if not os.path.isdir(class_dir):
            continue
        print('imagDir', class_dir)
        # Same filter as get_file: only regular files are samples. Counting
        # sub-folders here would desynchronise the two per-image lists.
        images = [image for image in os.listdir(class_dir)
                  if os.path.isfile(os.path.join(class_dir, image))]
        # int() truncates, so classes with < 1/split_rate images get no test sample.
        num_test = int(len(images) * split_rate)
        # A set makes each membership test O(1) instead of scanning a list.
        test_images = set(random.sample(images, k=num_test))
        train_test_list.extend(0 if image in test_images else 1 for image in images)
        print()
    return train_test_list
if __name__ == "__main__":
    # V2: split inside every class folder, sending 20% of each class to the
    # test set (flag 0) and the rest to training (flag 1).
    # images.txt / image_class_labels.txt can be produced from
    # get_file(dirname), which walks the class folders in the same
    # os.listdir order.
    nums = get_train_test(dirname, 0.2)
    with open('train_test_split.txt', mode='w') as file_handle:
        # One line per image: "<1-based image id> <flag>", no trailing space.
        for i, j in enumerate(nums, start=1):
            file_handle.write('{} {}\n'.format(i, j))
    # Count list elements directly. str(nums).count('1') counted characters of
    # the printed list and only worked because every element is 0 or 1.
    n = nums.count(1)
    m = nums.count(0)
    print('nums', nums)
    print('train', n)
    print('value', m)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/899936
推荐阅读
相关标签
  

闽ICP备14008679号