当前位置:   article > 正文

python自定义数据集并展示_python 生成 datasets

python 生成 datasets

 首先,我自己创建的彩色图像数据集是这样的:

 标签是这样的:

 

  1. #本文引荐了文章:http://t.csdn.cn/gkVNC;并作了注释与修改
  2. #导入库
  3. import os
  4. import torch
  5. import pandas as pd
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. from torchvision.io import read_image
  9. from torch.utils.data import Dataset
  10. from torch.utils.data import DataLoader
  11. from torchvision import transforms
  12. #创建自定义数据集类
  13. class Custom_Dataset(Dataset):
  14. #函数,设置图像集路径索引、图像标签文件读取
  15. def __init__(self, img_dir, img_label_dir, transform=None):
  16. super().__init__()
  17. self.img_dir = img_dir
  18. self.img_labels = pd.read_csv(img_label_dir)
  19. self.transform = transform
  20. #函数,设置数据集长度
  21. def __len__(self):
  22. return len(self.img_labels)
  23. #函数,设置指定图像读取、指定图像标签索引
  24. def __getitem__(self, index):
  25. #'所在文件路径+指定图像名'
  26. img_path = os.path.join(self.img_dir + self.img_labels.iloc[index, 1])
  27. #读指定图像
  28. image=plt.imread(img_path)
  29. #'指定图像标签'
  30. label = self.img_labels.iloc[index, 0]
  31. return image, label
  32. # 把图片对应的tensor调整维度,并显示
  33. def tensorToimg(img_tensor):
  34. img=img_tensor
  35. plt.imshow(img)
  36. #python3.X必须加下行
  37. plt.show()
  38. #标签指示含义
  39. label_dic = {1: '膏岩', 2: '灰岩', 3: '灰质膏岩'}
  40. #图像集及标签路径
  41. label_path = 'C:/Users/yeahamen/AppData/Local/Programs/Python/Python310/label.csv'
  42. img_root_path = 'C:/Users/yeahamen/Desktop/自定义数据集/image/'
  43. #加载图像集与标签路径到函数
  44. #实例化类
  45. dataset = Custom_Dataset(img_root_path, label_path)
  46. #索引指定位置的图像及标签
  47. image, label = dataset.__getitem__(18)
  48. #展示图片及其形状(tensor)
  49. print(image.shape)
  50. print(label_dic[label])
  51. #tensorToimg(image)
  52. #批量输出
  53. dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
  54. #查看一批图像的形状
  55. for imgs, labels in dataloader:
  56. print(imgs.shape)#一批图像形状:torch.Size([5, 3456, 5184, 3])
  57. print(labels)#标签:tensor([3, 2, 3, 3, 1])
  58. break
  59. showimages=[]
  60. showlabels=[]
  61. for imgs, labels in dataloader:
  62. c = torch.squeeze(imgs, 0)#减去一维数据形成图片固定三参数
  63. d = torch.squeeze(labels,0)
  64. showimages.append(c)
  65. showlabels.append(d)
  66. def show_image(nrow, ncol, sharex, sharey):
  67. fig, axs = plt.subplots(nrow, ncol, sharex=sharex, sharey=sharey, figsize=(10, 10))
  68. for i in range(0,nrow):
  69. for j in range(0,ncol):
  70. axs[i,j].imshow(showimages[i*4+j])
  71. axs[i,j].set_title('Label={}'.format(showlabels[i*4+j]))
  72. plt.show()
  73. plt.tight_layout()
  74. #给定参数
  75. show_image(2, 4, False, False)

上面代码注释非常详细了。

最后通过读取、展示得到的结果:

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/455084
推荐阅读
相关标签
  

闽ICP备14008679号