当前位置:   article > 正文

torch怎么生成小批量数据?_torch data 批量构建

torch data 批量构建

为了将x_datay_data打乱数据并转换成小批量数据,可以使用PyTorch中的DataLoaderrandom库。下面是一个示例代码:

import torch
from torch.utils.data import DataLoader, TensorDataset

x_data = torch.randn(10, 3)
y_data = torch.randn(10, 1)

for i, j in zip(x_data ,y_data):
    print(i, j)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

输出

tensor([-1.3064, -1.1474, -0.4826]) tensor([-2.0181])
tensor([-0.7043,  0.4129, -0.7812]) tensor([0.2593])
tensor([ 0.8225,  0.4909, -0.9564]) tensor([0.1052])
tensor([ 0.8489,  0.7734, -0.5316]) tensor([-0.1681])
tensor([ 2.6069,  0.3360, -1.2510]) tensor([-1.5229])
tensor([-0.2588,  0.1903, -1.1847]) tensor([-0.1975])
tensor([-2.6685,  2.1388, -0.7719]) tensor([0.8189])
tensor([-0.4615, -1.3020,  0.9347]) tensor([0.1780])
tensor([-0.6927, -0.1758,  0.0818]) tensor([-0.4284])
tensor([-0.7713,  0.0360,  0.3797]) tensor([-0.4796])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

小批次,打乱数据

# import random

# 创建TensorDataset对象
dataset = TensorDataset(x_data, y_data)

# 创建DataLoader对象,并指定batch_size和是否要进行打乱
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 遍历每个小批量数据
for batch_x, batch_y in dataloader:
    # 在这里执行训练或评估操作
    print(batch_x, batch_y)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

输出

tensor([[ 0.8225,  0.4909, -0.9564],
        [-0.7043,  0.4129, -0.7812]]) tensor([[0.1052],
        [0.2593]])
        
tensor([[-0.7713,  0.0360,  0.3797],
        [-2.6685,  2.1388, -0.7719]]) tensor([[-0.4796],
        [ 0.8189]])
        
tensor([[-0.2588,  0.1903, -1.1847],
        [ 0.8489,  0.7734, -0.5316]]) tensor([[-0.1975],
        [-0.1681]])
        
tensor([[-0.4615, -1.3020,  0.9347],
        [-0.6927, -0.1758,  0.0818]]) tensor([[ 0.1780],
        [-0.4284]])
        
tensor([[ 2.6069,  0.3360, -1.2510],
        [-1.3064, -1.1474, -0.4826]]) tensor([[-1.5229],
        [-2.0181]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

其中,使用TensorDatasetx_datay_data合并到一个数据集中。

然后,使用DataLoader创建一个迭代器,以便逐个处理每个小批量数据。

在这里,batch_size设置为2,这意味着每个小批量将包含2个样本。

shuffle参数设置为True,表示要对数据进行随机打乱。

在遍历每个小批量数据时,可以在循环体内执行训练或评估操作。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/455070
推荐阅读
相关标签
  

闽ICP备14008679号