当前位置:   article > 正文

笔记1:cifar10数据集获取及pytorch批量处理

笔记1:cifar10数据集获取及pytorch批量处理

(1)cifar10数据集预处理

CIFAR-10是一个广泛使用的图像数据集,它由10个类别的共60000张32x32彩色图像组成,每个类别有6000张图像。
CIFAR-10官网
以下为CIFAR-10数据集data_batch_*表示训练集数据,test_batch表示测试集数据
在这里插入图片描述
预处理结果(将CIFAR-10保存为图片格式)
在这里插入图片描述

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: LIFEI
@time: 2024/5/8 15:00 
@file: 加载cifar10数据.py
@project: 深度学习(4):深度神经网络(DNN)
@describe: TEXT
@# ------------------------------------------(one)--------------------------------------
@# ------------------------------------------(two)--------------------------------------
"""
import glob
import pickle
import numpy as np
import cv2 as
import os
#%% md
cifar10官网处理函数:
#%%
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
#%% md
利用上面的函数进行读取数据:
#%%
label = ["airplane","automobile", "bird","cat", 'deer',"dog","frog","horse","ship","truck"]  #标签矩阵
filepath = glob.glob("../../test_doucments/cifar-10-batches-py/data_batch_*") # 获取当前文件的路径,返回路径矩阵,获取test数据集时将data_batch——*改为test_batch*
write_path =["./train","./test"] #
print(filepath)
for file in filepath:
    if not file:
        print("空集出错")
    else:
        # print(file)
        data_dic = unpickle(file) # 将二进制表示形式转换回 Python 对象的反序列化过程,结果为字节型数据
        # print(data_dic.keys()) #此处的keys主要有b"data",b"labels",b"filenames"
        index = 0
        for im_data in data_dic[b"data"]:  # 遍历影像矩阵数据
            im_label = data_dic[b"labels"][index] # 赋值标签数据
            im_filename = data_dic[b"filenames"][index] # 赋值影像名字
            index +=1
            # print(f"图像的文件名为:{im_filename}\n",f"图像的所属标签为:{im_label}\n",f"图像的矩阵数据为:{im_data}\n")

            #开始存放数据
            im_label_name = label[im_label]
            im_data_data = np.reshape(im_data,(3,32,32)) # 将影像矩阵数据转换为图像形式

            # 由于需要opencv进行写出图像,因此需要转化通道
            im_data_data = np.transpose(im_data_data,(1,2,0))
            imgname = f"当前图像名称{im_label},所属标签{im_label_name}"
            cv.imshow(str( im_label_name),cv.resize(im_data_data,(500,500))) # 将显示时的图像变大,图像数据本身大小不变
            cv.waitKey(0)
            cv.destroyAllWindows()

            #创建文件夹
            for path in write_path:
                if not os.path.exists("{}/{}".format(path,im_label_name)): #查看存储路径中的文件夹是否存在
                    os.mkdir("{}/{}".format(path,im_label_name)) # 没有就创建文件
                else:
                    break
            cv.imwrite("{}/{}/{}".format(write_path[0],im_label_name,str(im_filename,'utf-8')),im_data_data)
            # #write_path[1]写出测试数据的时候将write_path[0]改为write_path[1]
#%% md
将cifar10数据转为图片格式并保存
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65

(2)利用pytorch将图像转为张量数据

或是批量读取训练集和测试集数据
在这里插入图片描述

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: LIFEI
@time: 2024/5/8 15:00 
@file: 加载cifar10数据.py
@project: 深度学习(4):深度神经网络(DNN)
@describe: TEXT
@# ------------------------------------------(one)--------------------------------------
@# ------------------------------------------(two)--------------------------------------
"""
# 导入库
import glob
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
import cv2 as cv
# DataLoader参考网址https://blog.csdn.net/sazass/article/details/116641511

from PIL import Image

label_name = ["airplane","automobile", "bird","cat", 'deer',"dog","frog","horse","ship","truck"]
label_list = {} # 创建一个字典用于存储标签和下标
index = 0
for name in label_name:  # 也可以采用for index,name in enumerate(label_name)
    label_list[name] = index # 字典的常规赋值操作
    index += 1

def default_loder(path):
    return Image.open(path).convert("RGB") 
    
    # 也可采用opencv读取,但是建议不要使用,因为后面的train因为不识别会报错
    # img = cv.imread(path)
    # return cv.cvtColor(img,cv.COLOR_BGR2RGB)


# 定义训练集数据的增强   下面的Compose表示拼接需要增强的操作
train_transform = transforms.Compose([
    transforms.RandomCrop(28,28), #进行随机裁剪为28*28大小
    transforms.RandomHorizontalFlip(), #垂直方向翻转
    transforms.RandomVerticalFlip(), #水平方向的翻转
    transforms.RandomRotation(90), #随机旋转90度
    transforms.RandomGrayscale(0.1), #灰度转化
    transforms.ColorJitter(0.3,0.3,0.3,0.3), #随机颜色增强
    transforms.ToTensor() #将数据转化为张量数据
])

# 定义pytorh的dataset类
class MyData(Dataset):
    def __init__(self,im_list,
                 transform = None,
                 loder = default_loder):     #初始化函数
        super(MyData,self).__init__() #初始化这个类

        # 获取图片的路径以及标签号
        images = []
        for item_data in im_list:
            # 注意下面这一步,split("\\")根据不同的操作系统会不相同,有的是"/"
            img_label_name = item_data.split("\\")[-2] #通过遍历每一个路径进行获取当前图片的文字标签
            images.append([item_data,label_list[img_label_name]])

        self.images = images
        self.tranform =transform
        self.loder = loder

    def __getitem__(self, index_num): # 此处的index_num是在训练的时候反复传进来的值
        img_path , img_label = self.images[index_num] #这里的
        img_data = self.loder(img_path)  # 这里用到了self.loder(path)==>default_loder(path)外置函数

        if self.tranform is not None: # 判断数据是否增强
            img_data = self.tranform(img_data)
        return img_data,img_label

    def __len__(self):
         return len(self.images)

train_list = glob.glob("./train/*/*.png") # glob.glob 获取改路径下的所有文件路径并返回为列表
test_list = glob.glob("./test/*/*.png")

train_dataset = MyData(train_list,transform = train_transform)
test_dataset = MyData(test_list,transform = transforms.ToTensor()) #测试集无需进行图像增强操作,直接转为张量

train_data_loder = DataLoader(dataset =train_dataset,
                              batch_size=6,
                              shuffle=True,
                              num_workers=4) #4线程
test_data_loder = DataLoader(dataset =test_dataset,
                              batch_size=6,
                              shuffle=False,
                              num_workers=4)
print(f"训练集的大小:{len(train_dataset)}")
print(f"测试集的大小:{len(test_dataset)}")

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92

注:以上代码非原创,仅供个人记录学习笔记,若有侵权,请联系我删除

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/582720
推荐阅读
相关标签
  

闽ICP备14008679号