赞
踩
首先,我自己创建的彩色图像数据集是这样的:
标签是这样的:
- #本文引荐了文章:http://t.csdn.cn/gkVNC;并作了注释与修改
- #导入库
- import os
- import torch
- import pandas as pd
- import matplotlib.pyplot as plt
- import numpy as np
- from torchvision.io import read_image
- from torch.utils.data import Dataset
- from torch.utils.data import DataLoader
- from torchvision import transforms
-
- #创建自定义数据集类
- class Custom_Dataset(Dataset):
- #函数,设置图像集路径索引、图像标签文件读取
- def __init__(self, img_dir, img_label_dir, transform=None):
- super().__init__()
- self.img_dir = img_dir
- self.img_labels = pd.read_csv(img_label_dir)
- self.transform = transform
-
- #函数,设置数据集长度
- def __len__(self):
- return len(self.img_labels)
-
- #函数,设置指定图像读取、指定图像标签索引
- def __getitem__(self, index):
- #'所在文件路径+指定图像名'
- img_path = os.path.join(self.img_dir + self.img_labels.iloc[index, 1])
- #读指定图像
- image=plt.imread(img_path)
- #'指定图像标签'
- label = self.img_labels.iloc[index, 0]
- return image, label
-
- # 把图片对应的tensor调整维度,并显示
- def tensorToimg(img_tensor):
- img=img_tensor
- plt.imshow(img)
- #python3.X必须加下行
- plt.show()
-
- #标签指示含义
- label_dic = {1: '膏岩', 2: '灰岩', 3: '灰质膏岩'}
-
- #图像集及标签路径
- label_path = 'C:/Users/yeahamen/AppData/Local/Programs/Python/Python310/label.csv'
- img_root_path = 'C:/Users/yeahamen/Desktop/自定义数据集/image/'
-
- #加载图像集与标签路径到函数
- #实例化类
- dataset = Custom_Dataset(img_root_path, label_path)
- #索引指定位置的图像及标签
- image, label = dataset.__getitem__(18)
-
- #展示图片及其形状(tensor)
- print(image.shape)
- print(label_dic[label])
- #tensorToimg(image)
-
-
- #批量输出
- dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
- #查看一批图像的形状
- for imgs, labels in dataloader:
- print(imgs.shape)#一批图像形状:torch.Size([5, 3456, 5184, 3])
- print(labels)#标签:tensor([3, 2, 3, 3, 1])
- break
-
- showimages=[]
- showlabels=[]
- for imgs, labels in dataloader:
- c = torch.squeeze(imgs, 0)#减去一维数据形成图片固定三参数
- d = torch.squeeze(labels,0)
- showimages.append(c)
- showlabels.append(d)
- def show_image(nrow, ncol, sharex, sharey):
- fig, axs = plt.subplots(nrow, ncol, sharex=sharex, sharey=sharey, figsize=(10, 10))
- for i in range(0,nrow):
- for j in range(0,ncol):
- axs[i,j].imshow(showimages[i*4+j])
- axs[i,j].set_title('Label={}'.format(showlabels[i*4+j]))
- plt.show()
- plt.tight_layout()
- #给定参数
- show_image(2, 4, False, False)
上面代码注释非常详细了。
最后通过读取、展示得到的结果:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。