
Studying the PyTorch UNet code on Windows: random_split

sum of input lengths does not equal the length of the input dataset!

First, look at the call site:

train, val = random_split(dataset, [n_train, n_val])

 

def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.
    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

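Going by the docstring, the function randomly splits dataset into non-overlapping new datasets of the given lengths. Here the lengths are [n_train, n_val], so there are two groups: one with n_train samples and one with n_val samples. The samples are drawn at random from dataset without replacement, so no sample appears twice.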

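The first check verifies that the values in lengths add up to len(dataset). However many groups you split into, the total number of samples handed out is exactly len(dataset). If the sum differs, the ValueError shown in the title is raised.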
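This check is easy to trip when both lengths are computed by rounding independently. A minimal sketch, assuming a hypothetical helper `split_lengths` and a validation fraction `val_fraction` (neither appears in the original code): truncate one length and derive the other by subtraction, so the sum always matches.

```python
def split_lengths(n_total, val_fraction):
    """Return [n_train, n_val] whose sum is exactly n_total.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    n_val = int(n_total * val_fraction)   # truncate toward zero
    n_train = n_total - n_val             # remainder, so the sum is exact
    return [n_train, n_val]


# Rounding both parts independently can lose a sample:
n_total = 7
bad = [int(n_total * 0.5), int(n_total * 0.5)]   # [3, 3], sums to 6
good = split_lengths(n_total, 0.5)               # [4, 3], sums to 7

print(sum(bad) == n_total)    # False: random_split would raise ValueError
print(sum(good) == n_total)   # True
```

With lengths built this way, the sum check in random_split passes for any dataset size.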

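Next comes randperm. randperm(n) returns a random permutation of the integers from 0 to n-1, with no repeats. Here it produces a permutation of 0 to len(dataset)-1, which is then converted to a list.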
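As a quick check of what such a permutation contains (the values run from 0 to n-1), here is a small sketch that uses the standard library's `random.sample` as a stand-in for `randperm(n).tolist()`. The stand-in is an assumption made for illustration, chosen so the snippet runs without PyTorch.

```python
import random

n = 10
# Stand-in for torch.randperm(n).tolist(): a shuffled copy of 0 .. n-1.
indices = random.sample(range(n), n)

print(indices)                             # order varies between runs
print(sorted(indices) == list(range(n)))   # True: each of 0..n-1 appears once
print(min(indices), max(indices))          # 0 9
```

Because every value is a valid zero-based position, the list can be used directly to index the dataset.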

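Finally, the Subset objects are built. Look at the comprehension first:

Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)

The loop part is for offset, length in zip(...), so the next step is to see what _accumulate yields.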

Here is _accumulate:

def _accumulate(iterable, fn=lambda x, y: x + y):
    'Return running totals'
    # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
    # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
    it = iter(iterable)
    try:
        total = next(it)
    except StopIteration:
        return
    yield total
    for element in it:
        total = fn(total, element)
        yield total
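The two example comments in `_accumulate` can be checked with `itertools.accumulate` from the standard library, which produces the same running totals. Using it in place of the private helper is a substitution made for this sketch.

```python
from itertools import accumulate
import operator

sums = list(accumulate([1, 2, 3, 4, 5]))
products = list(accumulate([1, 2, 3, 4, 5], operator.mul))
empty = list(accumulate([]))

print(sums)       # [1, 3, 6, 10, 15]
print(products)   # [1, 2, 6, 24, 120]
print(empty)      # []
```

The empty case corresponds to the early return in the StopIteration branch: nothing is yielded.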

Judging from the example comments, offset takes the values n_train and n_train + n_val, while length takes the values n_train and n_val.

So the first slice is indices[0:n_train] and the second is indices[n_train:n_train + n_val], which splits indices exactly into two segments.
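The slicing arithmetic can be traced with concrete numbers. The sketch below assumes a 10-sample dataset with `n_train = 8` and `n_val = 2` (values chosen for illustration) and uses `itertools.accumulate` in place of `_accumulate`.

```python
from itertools import accumulate

n_train, n_val = 8, 2          # illustrative sizes, not from the original
lengths = [n_train, n_val]

# Each (offset, length) pair yields the slice indices[offset - length:offset].
bounds = [(offset - length, offset)
          for offset, length in zip(accumulate(lengths), lengths)]

print(bounds)   # [(0, 8), (8, 10)]
```

Each slice starts where the previous one ends, so the segments cannot overlap and together cover the whole index list.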

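In short, Subset is constructed twice, and the two resulting subsets are put in a list and returned. The call site unpacks that list into train and val.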

 

Next, look at Subset:

class Subset(Dataset):
    r"""
    Subset of a dataset at specified indices.
    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

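This part is simple: the constructor only assigns the dataset and the index list to attributes. It stores references, and no data is copied.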
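The excerpt shows only the constructor. For a subset to behave like a dataset, it also has to translate a position in the subset into a position in the parent dataset. A minimal sketch of that idea, using a hypothetical `ListSubset` class over a plain Python list rather than PyTorch's own class:

```python
class ListSubset:
    """Hypothetical stand-in for illustration: a view onto selected items."""

    def __init__(self, dataset, indices):
        self.dataset = dataset    # reference to the parent, not a copy
        self.indices = indices    # positions in the parent to expose

    def __getitem__(self, idx):
        # Map the subset position to the parent position, then look it up.
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


data = ["a", "b", "c", "d", "e"]
indices = [3, 0, 4, 1, 2]              # a fixed permutation for reproducibility
train = ListSubset(data, indices[0:3])
val = ListSubset(data, indices[3:5])

print(len(train), len(val))            # 3 2
print([train[i] for i in range(3)])    # ['d', 'a', 'e']
print([val[i] for i in range(2)])      # ['b', 'c']
```

Both subsets share the same parent list, so the split costs only the memory of the two index lists.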

 
