当前位置:   article > 正文

【pytorch基础】torch.utils.data.random_split()划分数据集

torch.utils.data.random_split

torch.utils.data.random_split()划分数据集

小白看代码的时候发现这个函数,查了很多资料,在此做一下笔记

random_split()函数说明:这个函数的作用是划分数据集,我们不用自己划分数据集,pytorch已经帮我们封装好了,划分数据集就用torch.utils.data.random_split()

跳转到pytorch封装的random_split()函数里面,函数的具体定义是这样的:

def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

参数说明:

dataset (Dataset): 划分的数据集
lengths (sequence): 被划分数据集的长度
注:可选择固定生成器以获得可复现的结果(效果同设置随机种子

示例

import torch
from torch.utils.data import random_split
dataset = range(13)
torch.manual_seed(0)
train_dataset,test_dataset = random_split(
    dataset = dataset,
    lengths = [6,7],
)
print(list(train_dataset))
print(list(test_dataset))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

结果

[3, 4, 0, 2, 11, 8]
[12, 5, 10, 6, 9, 1, 7]
  • 1
  • 2
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/149538
推荐阅读
相关标签
  

闽ICP备14008679号