赞
踩
PyTorch 提供了一个可用于划分 Dataset 的简单接口:torch.utils.data.random_split。
其实现如下:
def random_split(dataset, lengths, generator=default_generator):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced; every entry
            must be non-negative and they must sum to ``len(dataset)``
        generator (Generator): Generator used for the random permutation.

    Returns:
        list: one ``Subset`` of ``dataset`` per entry of ``lengths``, in the
        same order as ``lengths``.

    Raises:
        ValueError: if the lengths do not sum to ``len(dataset)``, or if any
            length is negative.
    """
    # Materialize once: ``lengths`` is traversed several times below, and a
    # one-shot iterable would be exhausted after the first pass.
    lengths = list(lengths)
    total = sum(lengths)
    if total != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    # A negative entry can still satisfy the sum check, but the slices below
    # would then have the wrong sizes.
    if any(length < 0 for length in lengths):
        raise ValueError("Lengths of splits must be non-negative!")

    indices = randperm(total, generator=generator).tolist()
    # ``offset`` is the exclusive end of each split inside the permutation.
    return [
        Subset(dataset, indices[offset - length:offset])
        for offset, length in zip(_accumulate(lengths), lengths)
    ]
实践示例:按约 6:2:2 的比例将数据集划分为训练集、测试集和验证集。
class My_Dataset(Dataset):
    """Dataset over two paired NumPy arrays, indexed along the first axis.

    Both arrays are converted to ``torch.long`` tensors once, at construction
    time. NOTE(review): the cast to ``long`` assumes ``x`` holds integer data
    (e.g. token ids); float features would be truncated -- confirm.
    """

    def __init__(self, x, y):
        """Store ``x`` and ``y`` as ``torch.long`` tensors.

        Args:
            x: NumPy array of samples; the first axis indexes samples.
            y: NumPy array of labels; the first axis indexes samples.

        Raises:
            ValueError: if ``x`` and ``y`` differ in length along the first
                axis.
        """
        # Fail early: a mismatch would otherwise only surface later as an
        # index error, or leave surplus labels silently unused.
        if x.shape[0] != y.shape[0]:
            raise ValueError(
                f"x and y must have the same number of samples, "
                f"got {x.shape[0]} and {y.shape[0]}"
            )
        self.x = torch.from_numpy(x).to(torch.long)
        self.y = torch.from_numpy(y).to(torch.long)

    def __len__(self):
        """Return the number of samples (size of the first axis of ``x``)."""
        return self.x.shape[0]

    def __getitem__(self, index):
        """Return the ``(x[index], y[index])`` pair."""
        return self.x[index], self.y[index]


# ``data_path`` must be defined before this point; the loaded object is
# expected to expose the arrays under the keys "x" and "y".
data = np.load(data_path)
x = data["x"]
y = data["y"]

# split dataset
full_dataset = My_Dataset(x, y)
# 20% each for test and validation; the training split takes the remainder,
# so the three sizes always sum to the dataset length that random_split
# requires.
test_size = int(x.shape[0] * 0.2)
train_size = x.shape[0] - test_size * 2
train_dataset, test_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, test_size, test_size]
)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。