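As a beginner reading some code, I came across this function and dug through quite a few references, so I'm writing my notes down here.
Jumping into the random_split() function that PyTorch provides, its definition looks like this: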
```python
def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
```
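Parameters:
dataset (Dataset): the dataset to split
lengths (sequence): the lengths of the splits to produce; they must add up to len(dataset), otherwise a ValueError is raised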
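To make the last line of the function concrete: randperm shuffles all indices once, and _accumulate(lengths) yields running totals, so each split takes the slice of the shuffled indices that ends at its running total. The sketch below mirrors that logic with the standard library only. It is an illustration, not PyTorch's code: `random.Random(seed).shuffle` stands in for `torch.randperm`, `itertools.accumulate` stands in for the internal `_accumulate`, and `split_indices` / `lengths_from_fractions` are hypothetical helper names.

```python
import random
from itertools import accumulate


def split_indices(n_total, lengths, seed=0):
    """Hypothetical helper: mirror random_split's index logic using only the stdlib."""
    if sum(lengths) != n_total:
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    # Stand-in for randperm(sum(lengths)).tolist(): a seeded shuffle of 0..n_total-1
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)
    # accumulate([6, 7]) yields 6, 13, so the slices are indices[0:6] and indices[6:13]
    return [indices[offset - length:offset]
            for offset, length in zip(accumulate(lengths), lengths)]


def lengths_from_fractions(n_total, fractions):
    """Hypothetical helper: turn fractions like [0.8, 0.2] into integer lengths summing to n_total."""
    lengths = [int(n_total * f) for f in fractions[:-1]]
    lengths.append(n_total - sum(lengths))  # the last split takes the remainder
    return lengths


train_idx, test_idx = split_indices(13, [6, 7])
print(len(train_idx), len(test_idx))           # 6 7
print(lengths_from_fractions(13, [0.8, 0.2]))  # [10, 3]
```

Because every split is a slice of one shuffled permutation, the splits can never overlap and together cover every index. Giving the remainder to the last split, as `lengths_from_fractions` does, is one simple way to pass the `sum(lengths) == len(dataset)` check.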
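Note: you can optionally fix the random number generator to get reproducible splits (the effect is the same as setting a random seed). The version quoted above only uses the global seed; newer PyTorch versions of random_split also accept a `generator` argument for this purpose.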
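A minimal sketch of the generator approach, assuming a PyTorch version whose `random_split` accepts the `generator` keyword (the older source quoted above does not). Two generators seeded with the same value should produce the same split without touching the global seed; no particular index order is claimed here.

```python
import torch
from torch.utils.data import random_split

dataset = range(13)

# Two independent generators with the same seed (assumes random_split accepts `generator`)
g1 = torch.Generator().manual_seed(0)
g2 = torch.Generator().manual_seed(0)

split_a = random_split(dataset, [6, 7], generator=g1)
split_b = random_split(dataset, [6, 7], generator=g2)

# Same seed, same generator state -> same elements in every split
print([list(s) for s in split_a] == [list(s) for s in split_b])  # True
```

Using a dedicated generator keeps the split reproducible even if other code consumes random numbers from the global generator in between.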
```python
import torch
from torch.utils.data import random_split

dataset = range(13)
torch.manual_seed(0)  # fix the global seed so the split is reproducible
train_dataset, test_dataset = random_split(
    dataset=dataset,
    lengths=[6, 7],
)
print(list(train_dataset))
print(list(test_dataset))
```

Output:

```
[3, 4, 0, 2, 11, 8]
[12, 5, 10, 6, 9, 1, 7]
```
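Since the dataset is range(13), each element equals its own index, so the printed lists are the shuffled indices themselves. A quick sanity check on those printed values confirms the "non-overlapping" property from the docstring: the two lists share no element and together contain every index from 0 to 12. This only reuses the output shown above; it does not rerun random_split.

```python
# Values copied from the printed output above
train = [3, 4, 0, 2, 11, 8]
test = [12, 5, 10, 6, 9, 1, 7]

print(set(train) & set(test))                   # set()
print(sorted(train + test) == list(range(13)))  # True
```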