赞
踩
import torch
from torch.utils.data import Dataset
from typing import Any
class CustomDataset(Dataset):
    """Toy linear-regression dataset: samples are generated on the fly.

    Each item is a random 2-d feature vector in [0, 10) paired with the
    scalar label ``w . x`` computed from the fixed weights (3.14, 4.27).
    The label is computed from the *clean* features; Gaussian noise
    (std 0.01) is only folded into the features that get returned.
    """

    def __init__(self, length) -> None:
        super().__init__()
        # Number of samples the dataset reports via __len__.
        self.length = length

    def __getitem__(self, index=None):
        # Fixed ground-truth weights of the linear model.
        w = torch.tensor([3.14, 4.27])
        feature = torch.rand(2) * 10
        noise = torch.randn_like(feature) * 0.01
        # t() on a 1-D tensor is a no-op, so this is just the dot product
        # w . feature, yielding a 0-d label tensor.
        label = torch.matmul(w, feature.t())
        # Noise goes into the returned features only, after the label
        # has been computed from the clean values.
        feature += noise
        return feature, label

    def __len__(self):
        return self.length


dataset = CustomDataset(4)
# Default collation: each (2,)-shaped feature stacks into a (batch, 2)
# tensor, and each 0-d label stacks into a (batch,) tensor.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
for batch_features, batch_labels in dataloader:
    print(batch_features.shape, batch_labels.shape)
下述展示了,默认的 DataLoader 的处理结果:
通过 torch.stack(feature)
,构建出 batch 数据;
torch.Size([2, 2]) torch.Size([2])
torch.Size([2, 2]) torch.Size([2])
标量直接拼接;
向量则会在前面添加一个 batch 维度;
collate_fn
:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成我们期望的数据格式;
如上述默认的输出结果所示:label.shape
为 torch.Size([2]),笔者想通过 collate_fn
修改 label.shape
为torch.Size([2, 1])
,下述代码实现这个功能:
def collate_fn(item):
    """Collate a batch of (feature, label) pairs, reshaping labels to (B, 1).

    ``item`` is the list of samples the DataLoader drew for one batch.
    """
    features, labels = zip(*item)
    batched_features = torch.stack(features)
    # stack() output is contiguous, so reshape is equivalent to view here.
    batched_labels = torch.stack(labels).reshape(-1, 1)
    return batched_features, batched_labels
# Same loader, but routed through the custom collate_fn so the labels
# come out with shape (batch, 1) instead of (batch,).
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
for batch_features, batch_labels in dataloader:
    print(batch_features.shape, batch_labels.shape)
输出如下:
torch.Size([2, 2]) torch.Size([2, 1])
torch.Size([2, 2]) torch.Size([2, 1])
在collate_fn(item)
,传入的item的数据为:
[(tensor([6.9436, 7.2040]), tensor(52.6007)), (tensor([7.1495, 2.8882]), tensor(34.7427))]
[(tensor([1.5311, 9.9278]), tensor(47.1995)), (tensor([4.9614, 8.6232]), tensor(52.3849))]
feature, label = zip(*item)
故通过zip(*item)
的方式,拆分出 feature 和 label 各自的数据,再借助torch.stack
方法将其拼接出 batch 形状的数据。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。