First, look at how it is called:

    train, val = random_split(dataset, [n_train, n_val])

and at the source of `random_split`:
    # imports this excerpt needs when run on its own (PyTorch 1.x)
    from torch import randperm
    from torch._utils import _accumulate
    from torch.utils.data import Subset

    def random_split(dataset, lengths):
        r"""
        Randomly split a dataset into non-overlapping new datasets of given lengths.
        Arguments:
            dataset (Dataset): Dataset to be split
            lengths (sequence): lengths of splits to be produced
        """
        if sum(lengths) != len(dataset):
            raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

        indices = randperm(sum(lengths)).tolist()
        return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
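The docstring says the function randomly splits `dataset` into non-overlapping new datasets of the given lengths. Here `[n_train, n_val]` produces two groups, one with n_train samples and one with n_val samples. The samples are drawn from `dataset` at random and without replacement, so no sample ends up in both groups.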
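To make this behaviour concrete, here is a minimal pure-Python sketch of the same idea that uses only the standard library instead of PyTorch. `random.shuffle` stands in for `randperm`, `itertools.accumulate` stands in for `_accumulate`, and plain index lists stand in for `Subset` objects. The function name `random_split_indices` is hypothetical and is not a PyTorch API.

```python
import random
from itertools import accumulate


def random_split_indices(n, lengths, seed=None):
    """Hypothetical helper: split range(n) into random, non-overlapping index lists."""
    if sum(lengths) != n:
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    rng = random.Random(seed)
    indices = list(range(n))
    rng.shuffle(indices)  # random permutation of 0 .. n-1, playing the role of randperm(n).tolist()
    # Each (offset, length) pair selects the slice indices[offset - length:offset]
    return [indices[offset - length:offset]
            for offset, length in zip(accumulate(lengths), lengths)]


train_idx, val_idx = random_split_indices(10, [7, 3], seed=0)
print(len(train_idx), len(val_idx))                     # 7 3
print(set(train_idx).isdisjoint(val_idx))               # True
print(sorted(train_idx + val_idx) == list(range(10)))   # True
```

The two index lists never overlap, and together they cover every position exactly once. That is the "non-overlapping, without replacement" property the docstring describes.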
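The first check verifies that the values in `lengths` add up to `len(dataset)`. However many groups you split into, together they must contain exactly `len(dataset)` samples.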
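Because of this check, the caller has to compute lengths whose sum is exactly the dataset size. One common pattern is to take a fraction for training and assign whatever remains to validation, so rounding can never break the sum. This pattern is an assumption for illustration and does not come from the source. `split_lengths` is a hypothetical helper.

```python
def split_lengths(n, train_frac=0.8):
    """Hypothetical helper: lengths for random_split that always sum to n."""
    n_train = int(n * train_frac)  # int() truncates toward zero
    n_val = n - n_train            # the remainder, so n_train + n_val == n
    return [n_train, n_val]


print(split_lengths(1001))  # [800, 201]
# Rounding both parts independently loses a sample and would fail the check:
print(int(1001 * 0.8) + int(1001 * 0.2))  # 1000
```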
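Next is `randperm`. `randperm(n)` returns a random permutation of the integers 0 through n-1 with no repeats (the values start at 0, not 1). Here the argument is `sum(lengths)`, which equals `len(dataset)` after the check. So the call shuffles the indices 0 through len(dataset)-1, and `.tolist()` converts the resulting tensor into a Python list.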
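A pure-Python stand-in for `torch.randperm(n).tolist()` is `random.sample(range(n), n)`, which draws all n values without replacement. In the check below, the permutation contains exactly 0 through n-1, which is the range needed for 0-based dataset indices. The name `randperm_list` is hypothetical.

```python
import random


def randperm_list(n, seed=None):
    """Hypothetical pure-Python stand-in for torch.randperm(n).tolist()."""
    rng = random.Random(seed)
    return rng.sample(range(n), n)  # n distinct values, drawn without replacement


perm = randperm_list(5, seed=42)
print(sorted(perm))  # [0, 1, 2, 3, 4]
```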
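Finally, the `Subset` objects are built in a list comprehension. Look at the loop first:

    [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

It iterates `for offset, length in zip(...)`, so the key question is what `_accumulate` produces. Here is `_accumulate`: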
    def _accumulate(iterable, fn=lambda x, y: x + y):
        'Return running totals'
        # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
        # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
        it = iter(iterable)
        try:
            total = next(it)
        except StopIteration:
            return
        yield total
        for element in it:
            total = fn(total, element)
            yield total
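`_accumulate` is a private PyTorch helper, but the standard library's `itertools.accumulate` computes the same running totals. You can therefore reproduce the examples from the comments without PyTorch. Both functions are generators, so `list(...)` is used here to collect the lazily produced values.

```python
import operator
from itertools import accumulate

# The same examples as in the comments of _accumulate
print(list(accumulate([1, 2, 3, 4, 5])))                # [1, 3, 6, 10, 15]
print(list(accumulate([1, 2, 3, 4, 5], operator.mul)))  # [1, 2, 6, 24, 120]
# An empty input yields nothing, like _accumulate's early return on StopIteration
print(list(accumulate([])))                             # []
```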
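The examples show that `_accumulate([n_train, n_val])` yields the running totals n_train and n_train + n_val. So `offset` takes the values n_train and n_train + n_val, while `length` takes the values n_train and n_val.
The slices are therefore `indices[0:n_train]` for the first subset and `indices[n_train:n_train + n_val]` for the second, which splits `indices` into exactly two consecutive pieces.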
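Here is the same step with concrete numbers, n_train = 7 and n_val = 3. A fixed `indices` list stands in for the output of `randperm`, so the printed output is deterministic. `itertools.accumulate` plays the role of `_accumulate`.

```python
from itertools import accumulate

lengths = [7, 3]                          # [n_train, n_val]
indices = [4, 9, 0, 7, 2, 8, 1, 5, 3, 6]  # fixed stand-in for randperm(10).tolist()

pairs = list(zip(accumulate(lengths), lengths))
print(pairs)   # [(7, 7), (10, 3)]

slices = [indices[offset - length:offset] for offset, length in pairs]
print(slices)  # [[4, 9, 0, 7, 2, 8, 1], [5, 3, 6]]
```

`offset - length` is the start of each slice and `offset` is its end. Consecutive slices therefore touch without overlapping, and together they cover the whole list.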
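In short, the comprehension calls `Subset` twice, creating two `Subset` objects that are collected in a list and returned. That is why the call above can unpack the result into `train` and `val`.
Now look at `Subset`: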
    from torch.utils.data import Dataset

    class Subset(Dataset):
        r"""
        Subset of a dataset at specified indices.
        Arguments:
            dataset (Dataset): The whole Dataset
            indices (sequence): Indices in the whole set selected for subset
        """
        def __init__(self, dataset, indices):
            self.dataset = dataset
            self.indices = indices
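This part is simple: the constructor only stores a reference to the original dataset and the list of indices. No data is copied. Indexing into a `Subset` is forwarded to the original dataset through these indices.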
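The excerpt above shows only `__init__`. What makes `Subset` behave like a dataset is the index mapping in `__getitem__`, where PyTorch's own `Subset` returns `self.dataset[self.indices[idx]]`. The class below is a minimal pure-Python sketch of that idea. The name `SubsetSketch` is hypothetical, and a plain list stands in for a `Dataset`. Changing the underlying list shows that the subset is a view and not a copy.

```python
class SubsetSketch:
    """Hypothetical stand-in for torch.utils.data.Subset: a view, not a copy."""

    def __init__(self, dataset, indices):
        self.dataset = dataset   # reference to the full dataset; nothing is copied
        self.indices = indices   # positions in the full dataset that belong to this subset

    def __getitem__(self, idx):
        # Map the subset position idx to the corresponding position in the full dataset
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


data = ["a", "b", "c", "d", "e"]
sub = SubsetSketch(data, [3, 0])
print(len(sub), sub[0], sub[1])  # 2 d a
data[3] = "D"                    # modify the underlying dataset
print(sub[0])                    # D  (the subset sees the change because nothing was copied)
```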
Copyright © 2003-2013 www.wpsshop.cn. All rights reserved.