
[Hyperspectral] Loading Hyperspectral Data with PyTorch's DataLoader


Part of the experimental code in this article is adapted from Hyperspectral-Classification: https://github.com/eecn/Hyperspectral-Classification

If you are unclear about how the DataLoader works internally, see [Pytorch] DataSet and DataLoader explained line by line: https://blog.csdn.net/weixin_37878740/article/details/129350390?spm=1001.2014.3001.5501

1. How It Works

Hyperspectral datasets commonly come in .mat format and consist of a data file and a gt (ground-truth) file, i.e., the image data and the label data. Taking the Indian Pines dataset as an example, the image data has size 145*145*200 and the label data has size 145*145.
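If you want to verify those shapes yourself, a minimal inspection with scipy.io looks like the following (the paths and the Dataset/IndianPines layout are assumptions matching the code later in this article):

from scipy import io

# Quick look at the raw .mat files (paths assumed)
mat = io.loadmat('Dataset/IndianPines/Indian_pines_corrected.mat')
gt_mat = io.loadmat('Dataset/IndianPines/Indian_pines_gt.mat')
print(mat['indian_pines_corrected'].shape)  # (145, 145, 200)
print(gt_mat['indian_pines_gt'].shape)      # (145, 145)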

The main ideas of the experimental code in this article are as follows:

① load the hyperspectral dataset and its gt label set;

② split the dataset into training, test, and validation sets at given ratios;

③ wrap the training and validation sets in DataLoaders.

2. Loading the Hyperspectral Data

# Imports used throughout this article
import os
import imageio
import numpy as np
import sklearn.model_selection
import spectral
import spectral as spy  # spy.imshow is used for visualization later
import torch
from scipy import io
from sklearn import preprocessing
from torch.utils.data import DataLoader

# Parse a hyperspectral dataset
def get_dataset(target_folder, dataset_name):
    palette = None
    # Build the folder path
    folder = target_folder + '/' + dataset_name
    # Open the data files
    if dataset_name == 'IndianPines':
        img = open_file(folder + '/Indian_pines_corrected.mat')
        img = img['indian_pines_corrected']  # select the data matrix
        rgb_bands = (43, 21, 11)  # AVIRIS sensor
        gt = open_file(folder + '/Indian_pines_gt.mat')['indian_pines_gt']
        # Class names; index 0 is the undefined/background class
        label_values = ["Undefined", "Alfalfa", "Corn-notill", "Corn-mintill",
                        "Corn", "Grass-pasture", "Grass-trees",
                        "Grass-pasture-mowed", "Hay-windrowed", "Oats",
                        "Soybean-notill", "Soybean-mintill", "Soybean-clean",
                        "Wheat", "Woods", "Buildings-Grass-Trees-Drives",
                        "Stone-Steel-Towers"]
        ignored_labels = [0]
    # Zero out NaN pixels and mark them as background
    nan_mask = np.isnan(img.sum(axis=-1))
    img[nan_mask] = 0
    gt[nan_mask] = 0
    ignored_labels.append(0)
    # Convert and normalize the data
    ignored_labels = list(set(ignored_labels))
    img = np.asarray(img, dtype='float32')
    data = img.reshape(np.prod(img.shape[:2]), np.prod(img.shape[2:]))
    data = preprocessing.minmax_scale(data)
    img = data.reshape(img.shape)
    return img, gt, label_values, ignored_labels, rgb_bands, palette

Only the Indian Pines branch is implemented here; if you need other datasets, modify the paths, .mat keys, and label list inside the function accordingly.
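For instance, a Pavia University branch might look like this (a sketch modeled on the referenced repository; the .mat keys, RGB bands, and class names are assumptions to check against your own files):

elif dataset_name == 'PaviaU':
    img = open_file(folder + '/PaviaU.mat')['paviaU']  # key assumed
    rgb_bands = (55, 41, 12)  # ROSIS sensor (assumed)
    gt = open_file(folder + '/PaviaU_gt.mat')['paviaU_gt']  # key assumed
    label_values = ["Undefined", "Asphalt", "Meadows", "Gravel", "Trees",
                    "Painted metal sheets", "Bare Soil", "Bitumen",
                    "Self-Blocking Bricks", "Shadows"]
    ignored_labels = [0]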

The function reads the image and the gt from the .mat files and returns them along with the related metadata. The helper that actually reads files is open_file(.):

# Open a hyperspectral file
def open_file(dataset):
    _, ext = os.path.splitext(dataset)
    ext = ext.lower()
    # Dispatch on the file extension
    if ext == '.mat':
        return io.loadmat(dataset)
    elif ext == '.tif' or ext == '.tiff':
        return imageio.imread(dataset)
    elif ext == '.hdr':
        img = spectral.open_image(dataset)
        return img.load()
    else:
        raise ValueError("Unknown file format: {}".format(ext))

It is called in the main function as follows:

DataSetName = 'IndianPines'
target_folder = 'Dataset'
img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(target_folder, DataSetName)
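A quick sanity check on the returned values (a hypothetical snippet, not part of the original project) should reproduce the shapes quoted earlier:

print(img.shape)          # (145, 145, 200), min-max scaled float32
print(gt.shape)           # (145, 145)
print(len(LABEL_VALUES))  # 17, i.e. 16 classes plus "Undefined"
print(IGNORED_LABELS)     # [0]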

3. The Dataset Class

Before loading the data through the Dataset class, we need to split it randomly. Here we directly call the original project's sample_gt(.) function to split the gt.

def sample_gt(gt, train_size, mode='random'):
    indices = np.nonzero(gt)
    X = list(zip(*indices))  # (x, y) coordinates of labeled pixels
    y = gt[indices].ravel()  # class of each labeled pixel
    train_gt = np.zeros_like(gt)
    test_gt = np.zeros_like(gt)
    if train_size > 1:
        # An absolute number of samples rather than a proportion
        train_size = int(train_size)
    if mode == 'random':
        # Stratified random split over all labeled pixels
        train_indices, test_indices = sklearn.model_selection.train_test_split(
            X, train_size=train_size, stratify=y)
        train_indices = [list(t) for t in zip(*train_indices)]
        test_indices = [list(t) for t in zip(*test_indices)]
        train_gt[tuple(train_indices)] = gt[tuple(train_indices)]
        test_gt[tuple(test_indices)] = gt[tuple(test_indices)]
    elif mode == 'fixed':
        # Per-class split with a fixed train size
        print("Sampling {} with train size = {}".format(mode, train_size))
        train_indices, test_indices = [], []
        for c in np.unique(gt):
            if c == 0:
                continue
            indices = np.nonzero(gt == c)
            X = list(zip(*indices))  # (x, y) coordinates of this class
            train, test = sklearn.model_selection.train_test_split(X, train_size=train_size)
            train_indices += train
            test_indices += test
        train_indices = [list(t) for t in zip(*train_indices)]
        test_indices = [list(t) for t in zip(*test_indices)]
        train_gt[tuple(train_indices)] = gt[tuple(train_indices)]
        test_gt[tuple(test_indices)] = gt[tuple(test_indices)]
    elif mode == 'disjoint':
        # Split each class along the x axis into spatially disjoint parts
        train_gt = np.copy(gt)
        test_gt = np.copy(gt)
        for c in np.unique(gt):
            mask = gt == c
            for x in range(gt.shape[0]):
                first_half_count = np.count_nonzero(mask[:x, :])
                second_half_count = np.count_nonzero(mask[x:, :])
                try:
                    ratio = first_half_count / second_half_count
                    if ratio > 0.9 * train_size and ratio < 1.1 * train_size:
                        break
                except ZeroDivisionError:
                    continue
            mask[:x, :] = 0
            train_gt[mask] = 0
        test_gt[train_gt > 0] = 0
    else:
        raise ValueError("{} sampling is not implemented yet.".format(mode))
    return train_gt, test_gt

It is called in the main function as follows:

# Proportion of samples used for training
SAMPLE_PERCENTAGE = 0.1
# Split the dataset: 10% train+val and 90% test, then keep 95% of the training part for training and 5% for validation
train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode='random')
train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')
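Since sample_gt only zeroes out pixels, the labeled pixels of the three splits should exactly partition those of gt. A hypothetical check, continuing from the code above:

n_total = np.count_nonzero(gt)
n_train = np.count_nonzero(train_gt)
print(n_train / n_total)  # roughly 0.1 * 0.95
print(n_train + np.count_nonzero(val_gt) + np.count_nonzero(test_gt) == n_total)  # True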

The split gt maps are then handed to the Dataset class, whose constructor takes nine parameters:

1. data: the hyperspectral data cube;
2. gt: the label map;
3. patch_size: the neighborhood size (i.e. the receptive field, which determines the size of each extracted patch; see the sketch after this list);
4. ignored_labels: classes to ignore;
5. flip_augmentation: whether to apply random flipping;
6. radiation_augmentation: whether to add random noise;
7. mixture_augmentation: whether to randomly mix spectra;
8. center_pixel: set to True to use only the label of the center pixel;
9. supervision: training mode, either 'full' (fully supervised) or 'semi' (semi-supervised).
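To make patch_size concrete: each sample is a square window centered on a labeled pixel, computed exactly as in __getitem__ below. An illustrative calculation (the center coordinates are arbitrary):

patch_size = 9
x, y = 72, 72  # an arbitrary labeled center pixel
x1, y1 = x - patch_size // 2, y - patch_size // 2  # (68, 68)
x2, y2 = x1 + patch_size, y1 + patch_size          # (77, 77)
# the patch is data[68:77, 68:77], a 9x9 neighborhood around (72, 72)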

The Dataset class is as follows:

# Hyperspectral Dataset class
class HyperX(torch.utils.data.Dataset):
    def __init__(self, data, gt, patch_size, ignored_labels, flip_augmentation,
                 radiation_augmentation, mixture_augmentation, center_pixel, supervision):
        super().__init__()
        self.data = data
        self.label = gt
        self.patch_size = patch_size
        self.ignored_labels = ignored_labels
        self.flip_augmentation = flip_augmentation
        self.radiation_augmentation = radiation_augmentation
        self.mixture_augmentation = mixture_augmentation
        self.center_pixel = center_pixel
        # Fully supervised mode: only keep pixels whose label is not ignored
        if supervision == 'full':
            mask = np.ones_like(gt)
            for l in self.ignored_labels:
                mask[gt == l] = 0
        # Semi-supervised mode: keep all pixels
        elif supervision == 'semi':
            mask = np.ones_like(gt)
        x_pos, y_pos = np.nonzero(mask)
        p = self.patch_size // 2
        # Keep only pixels far enough from the border to extract a full patch
        self.indices = np.array([(x, y) for x, y in zip(x_pos, y_pos)
                                 if x > p - 1 and x < data.shape[0] - p
                                 and y > p - 1 and y < data.shape[1] - p])
        np.random.shuffle(self.indices)
        # Compute labels after shuffling so they stay aligned with the indices;
        # use an array so that element-wise comparison works in mixture_noise
        self.labels = np.asarray([self.label[x, y] for x, y in self.indices])

    @staticmethod
    def flip(*arrays):
        horizontal = np.random.random() > 0.5
        vertical = np.random.random() > 0.5
        if horizontal:
            arrays = [np.fliplr(arr) for arr in arrays]
        if vertical:
            arrays = [np.flipud(arr) for arr in arrays]
        return arrays

    @staticmethod
    def radiation_noise(data, alpha_range=(0.9, 1.1), beta=1/25):
        alpha = np.random.uniform(*alpha_range)
        noise = np.random.normal(loc=0., scale=1.0, size=data.shape)
        return alpha * data + beta * noise

    def mixture_noise(self, data, label, beta=1/25):
        alpha1, alpha2 = np.random.uniform(0.01, 1., size=2)
        noise = np.random.normal(loc=0., scale=1.0, size=data.shape)
        data2 = np.zeros_like(data)
        for idx, value in np.ndenumerate(label):
            if value not in self.ignored_labels:
                # Pick another pixel of the same class and mix its spectrum in
                l_indices = np.nonzero(self.labels == value)[0]
                l_indice = np.random.choice(l_indices)
                assert self.labels[l_indice] == value
                x, y = self.indices[l_indice]
                data2[idx] = self.data[x, y]
        return (alpha1 * data + alpha2 * data2) / (alpha1 + alpha2) + beta * noise

    # Number of samples
    def __len__(self):
        return len(self.indices)

    # Fetch one patch and its label
    def __getitem__(self, i):
        x, y = self.indices[i]
        x1, y1 = x - self.patch_size // 2, y - self.patch_size // 2
        x2, y2 = x1 + self.patch_size, y1 + self.patch_size
        data = self.data[x1:x2, y1:y2]
        label = self.label[x1:x2, y1:y2]
        # Apply the selected data augmentations
        if self.flip_augmentation and self.patch_size > 1:
            data, label = self.flip(data, label)
        if self.radiation_augmentation and np.random.random() < 0.1:
            data = self.radiation_noise(data)
        if self.mixture_augmentation and np.random.random() < 0.2:
            data = self.mixture_noise(data, label)
        # mat -> numpy (CHW) -> tensor
        data = np.asarray(np.copy(data).transpose((2, 0, 1)), dtype='float32')
        label = np.asarray(np.copy(label), dtype='int64')
        data = torch.from_numpy(data)
        label = torch.from_numpy(label)
        # Keep only the label of the center pixel
        if self.center_pixel and self.patch_size > 1:
            label = label[self.patch_size // 2, self.patch_size // 2]
        # Remove unused dimensions when working with individual spectra
        elif self.patch_size == 1:
            data = data[:, 0, 0]
            label = label[0, 0]
        # Add a channel dimension for 3D convolution
        if self.patch_size > 1:
            data = data.unsqueeze(0)
        return data, label

The collate function (the DataLoaders below use PyTorch's default collate; a usage note follows the code):

def HyperX_collate(batch):
    datas = []
    labels = []
    for data, label in batch:
        datas.append(data)
        labels.append(label)
    # Stack the per-sample tensors into batched tensors
    datas = torch.stack(datas)
    labels = torch.stack(labels)
    return datas, labels
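To actually use it, pass it to the DataLoader through the collate_fn argument (optional; shown here as a sketch):

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          collate_fn=HyperX_collate, pin_memory=True, shuffle=True)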

The datasets and loaders are built in the main function as follows:

# Build the datasets (patch_size and batch_size are hyperparameters set elsewhere; patch_size = 9 in this article's demo)
train_dataset = HyperX(img, train_gt, patch_size, IGNORED_LABELS, True, True, True, True, 'full')
val_dataset = HyperX(img, val_gt, patch_size, IGNORED_LABELS, True, True, True, True, 'full')
# Build the dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
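With patch_size = 9 on Indian Pines, one batch should have the shapes below (a hedged smoke test; batch_size is whatever you configured):

imgs, labels = next(iter(train_loader))
print(imgs.shape)    # torch.Size([batch_size, 1, 200, 9, 9])
print(labels.shape)  # torch.Size([batch_size]) since center_pixel=True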

4. Visualizing the Data

# Visualize samples from the training set
for item in train_dataset:
    img, label = item
    img = torch.squeeze(img, 0)  # drop dimension 0
    img = img.permute(1, 2, 0)   # move channels last
    print('tensor size: {}'.format(img.shape))
    img = img.numpy()            # convert to numpy
    view1 = spy.imshow(data=img, bands=RGB_BANDS, title="train")  # display the patch
    print('label id: {}'.format(label.numpy()))

With the neighborhood size patch_size set to 9, running the code gives the following result:

(Screenshot: the spy.imshow view of one 9*9 training patch, together with the printed tensor size and label id.)

5. Mock Training

  1. print("模拟训练")
  2. for epoch in range(3):
  3. step = 0
  4. for data in train_loader:
  5. imgs, labels = data
  6. print(imgs.shape)
  7. print(labels.shape)
  8. img = imgs[0]
  9. img = torch.squeeze(img,0).permute(1,2,0).numpy() #通道调整和numpy转换
  10. view1 = spy.imshow(data=img, bands=RGB_BANDS, title="train") # 图像显示
  11. step=step+1
  12. input("按任意键继续")

Running this prints the tensor shapes of each batch and displays its first patch (result screenshot omitted).
