""" @file: codes.py @Time : 2023/1/12 @Author : Peinuan qin """ import numpy as np import torch from torchvision import transforms from torchvision.datasets import MNIST import random from torch.utils.data import Subset, Dataset DATA_ROOT = "./data" MEAN = (0.1307,) STD = (0.3081,) class MyDataset(Dataset): def __init__(self, dataset, ratio=0.2, add_noise=True): self.dataset = dataset self.add_noise = add_noise if ratio: random_indexs = random.sample(range(len(dataset)), int(ratio * len(dataset))) self.dataset = Subset(dataset, random_indexs) print(f"using a small dataset with ratio: {ratio}") def __len__(self): return len(self.dataset) def __getitem__(self, item): # noise image as the encoder-decoder input, and the clean image as the groundtruth label if self.add_noise: return self.make_noise(self.dataset[item][0]), self.dataset[item][0] else: return self.dataset[item][0], self.dataset[item][0] def make_noise(self, x): """ generate gaussian noise to make noised data for encoder :param x: :return: """ noise = np.random.normal(0, 1, size=x.size()) noise = torch.from_numpy(noise) x += noise return x DATASET_RATIO = 0.2 trainset = MNIST(DATA_ROOT , train=True , transform=transforms.Compose([transforms.ToTensor() , transforms.Normalize(MEAN, STD)]) , download=True) valset = MNIST(DATA_ROOT , train=False , transform=transforms.Compose([transforms.ToTensor() , transforms.Normalize(MEAN, STD)]) , download=False) # only use 0.2 of the raw data for training and validation train_set = MyDataset(trainset, ratio=DATASET_RATIO) val_set = MyDataset(valset, ratio=DATASET_RATIO)
""" @file: codes.py @Time : 2023/1/12 @Author : Peinuan qin """ from torchvision import transforms DATA_ROOT = "./data" MEAN = (0.1307,) STD = (0.3081,) from torchvision.datasets import MNIST from torch.utils.data import ConcatDataset, Subset, random_split trainset = MNIST(DATA_ROOT , train=True , transform=transforms.Compose([transforms.ToTensor() , transforms.Normalize(MEAN, STD)]) , download=True) valset = MNIST(DATA_ROOT , train=False , transform=transforms.Compose([transforms.ToTensor() , transforms.Normalize(MEAN, STD)]) , download=False) complete_set = ConcatDataset([trainset, valset])
This can be done with the Subset class; for a concrete implementation, refer to the part of the first code snippet that uses Subset: dataset = Subset(dataset, random_indexs).
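As a minimal standalone sketch of that idea (the index count of 1,000 is an illustrative choice, and complete_set is the merged dataset from the snippet above):

import random
from torch.utils.data import Subset

# pick 1,000 random indices from the merged set and keep only those samples
indices = random.sample(range(len(complete_set)), 1000)
small_set = Subset(complete_set, indices)
print(len(small_set))  # 1000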
""" @file: codes.py @Time : 2023/1/12 @Author : Peinuan qin """ import random from collections import Counter from copy import deepcopy import numpy as np from matplotlib import pyplot as plt from torch.utils.data import Dataset from torchvision import transforms from tqdm import tqdm from torchvision.datasets import MNIST from torch.utils.data import ConcatDataset DATA_ROOT = "./data" MEAN = (0.1307,) STD = (0.3081,) CLS_NUM = 10 BATCHSIZE=64 SPLIT_NUM = 2 class MyDataset(Dataset): def __init__(self, dataset, transform=None): super(MyDataset, self).__init__() self.dataset = dataset self.x, self.y = self.get_x_y() self.transform = transform def __len__(self): return len(self.dataset) def get_x_y(self): x = [] y = [] for i in range(len(self.dataset)): x.append(self.dataset[i][0]) y.append(self.dataset[i][1]) return x, y def get_dict(self): dict = {} for i in tqdm(range(len(self.x))): if self.y[i] not in dict: dict[self.y[i]] = [] dict[self.y[i]].append(self.x[i]) else: dict[self.y[i]].append(self.x[i]) return dict def get_y_lst(self): return self.y def plot_distribution(self): plt.hist(self.y) plt.show() def __getitem__(self, item): img = self.dataset[item][0] label = self.dataset[item][1] if self.transform: img = self.transform(img) return img, label class ClassDict: def __init__(self, label, x_lst): self.label = label self.x_lst = x_lst self.dict = {i: x_lst[i] for i in range(len(x_lst))} self.copy_dict = deepcopy(self.dict) def sample(self, num): num = min(num, len(self.dict)) sample_indexs = random.sample(list(self.dict.keys()), num) x_lst = [self.dict.pop(idx) for idx in sample_indexs] print(f"label: {self.label}, remaining samples: {len(self.dict)}") print(f"label: {self.label}, sampling lst length: {len(x_lst)}") return x_lst def remain(self): x_lst = [v for k, v in self.dict.items()] return x_lst class NormalSampler: def __init__(self, class_dicts): self.class_dicts = class_dicts def sample(self, mean, std, num): label_float_lst = np.random.normal(mean, std, (num,)) label_int_lst = list(map(lambda x: int(x), label_float_lst)) label_count_dict = Counter(label_int_lst) print(label_count_dict) for k in dict(label_count_dict).keys(): if k not in range(len(self.class_dicts)): label_count_dict.pop(k) all_x_lst = [] all_y_lst = [] for label, count in label_count_dict.items(): class_dic = self.class_dicts[label] class_x_lst = class_dic.sample(count) class_y_lst = [label for _ in range(len(class_x_lst))] all_x_lst.extend(class_x_lst) all_y_lst.extend(class_y_lst) return all_x_lst, all_y_lst def remain(self): all_x_lst = [] all_y_lst = [] for i in range(len(self.class_dicts)): class_dic = self.class_dicts[i] label = class_dic.label class_x_lst = class_dic.remain() class_y_lst = [label for _ in range(len(class_x_lst))] all_x_lst.extend(class_x_lst) all_y_lst.extend(class_y_lst) return all_x_lst, all_y_lst class SubDataset(Dataset): def __init__(self, x, y): super(SubDataset, self).__init__() self.x = x self.y = y def __len__(self): return len(self.x) def plot_distribution(self): plt.hist(self.y) plt.show() def __getitem__(self, item): return self.x[item], self.y[item] trainset = MNIST(DATA_ROOT , train=True , transform=transforms.Compose([transforms.ToTensor() , transforms.Normalize(MEAN, STD)]) , download=True) valset = MNIST(DATA_ROOT , train=False , transform=transforms.Compose([transforms.ToTensor() , transforms.Normalize(MEAN, STD)]) , download=False) complete_set = ConcatDataset([trainset, valset]) complete_set = MyDataset(complete_set, transform=None) # class_dicts 包含了 N 个 dict,买个 dict 
存放了当前类中所有的 x 样本 classes_dict = complete_set.get_dict() class_dicts = [ClassDict(i, classes_dict[i]) for i in range(CLS_NUM)] # 构造正态分布取样器 normal_sampler = NormalSampler(class_dicts) # 每个 split 中分的基础样本数量(最后一个split) 可能会多余 base_sample_size basic_sample_size = len(complete_set) // SPLIT_NUM subsets = [] for i in range(SPLIT_NUM): # 这时候可以用 sampler.sample() 方法来取样 if i != SPLIT_NUM-1: x, y = normal_sampler.sample(CLS_NUM // SPLIT_NUM, 3, basic_sample_size) subset = SubDataset(x, y) # subset.plot_distribution() # 最后一次采样必须包含其他部分所有的样本 else: x, y = normal_sampler.remain() subset = SubDataset(x, y) # subset.plot_distribution() subsets.append(subset) for i in range(len(subsets)): print("=" * 35) print(f"subset {i}") subset = MyDataset(subsets[i]) # subset = subsets[i] print(f"subset length: {len(subset)}") x, y = subset.get_x_y() subset = MyDataset(subset, transforms.Compose( [ transforms.RandomHorizontalFlip(), transforms.ToTensor(), ] )) # 分层抽样,对每个 subset 保证数据的 train, val 是同分布的 for train_idxs, val_idxs in StratifiedShuffleSplit(n_splits=1 , train_size=0.75 , test_size=0.25 , random_state=1024).split(x, y): train_sampler = SubsetRandomSampler(train_idxs) val_sampler = SubsetRandomSampler(val_idxs) fold_train_loader = DataLoader(subset , batch_size=BATCHSIZE # , shuffle=True , sampler=train_sampler , num_workers=4 , pin_memory=True) fold_val_loader = DataLoader(subset , batch_size=BATCHSIZE # , shuffle=False , sampler=val_sampler , num_workers=4 , pin_memory=True)
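To show how the two loaders would then be consumed, here is a hedged sketch of one training pass and one validation pass per subset. The model, loss, and optimizer are placeholders chosen for illustration, not part of the original post.

import torch
import torch.nn as nn

# placeholder model/loss/optimizer just to demonstrate the loader usage
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, CLS_NUM))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for imgs, labels in fold_train_loader:
    optimizer.zero_grad()
    loss = criterion(model(imgs), labels)
    loss.backward()
    optimizer.step()

model.eval()
correct = 0
with torch.no_grad():
    for imgs, labels in fold_val_loader:
        correct += (model(imgs).argmax(dim=1) == labels).sum().item()
print(f"val accuracy: {correct / len(fold_val_loader.sampler):.3f}")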