当前位置:   article > 正文

DataLoader 的 collate_fn 解释与示例教程

DataLoader 的 collate_fn 解释与示例教程

导包

import torch
from torch.utils.data import Dataset
from typing import Any
  • 1
  • 2
  • 3

数据

class CustomDataset(Dataset):
    
    def __init__(self, length) -> None:
        super().__init__()
        self.length = length
    
    def __getitem__(self, index=None):
        w1 = 3.14
        w2 = 4.27
        w = torch.tensor([w1, w2])
        feature = torch.rand(2) * 10
        noise = torch.randn_like(feature) * 0.01
        label = torch.matmul(w, feature.t())
        feature += noise
        # return feature, label.view(1)
        return feature, label
    
    def __len__(self):
        return self.length

dataset = CustomDataset(4)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

Dataloader

dataloader = torch.utils.data.DataLoader(
                    			dataset, 
                    			batch_size=2, 
								)

for feature, label in dataloader:
    print(feature.shape, label.shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

下述展示了,默认的 Dataload 的处理结果:
通过 torch.stack(feature),构建出 batch 数据;

torch.Size([2, 2]) torch.Size([2])
torch.Size([2, 2]) torch.Size([2])

常量直接拼接;
向量则会在前面添加一个 batch 纬度;

collate_fn

collate_fn:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成我们期望的数据格式;

如上述默认的输出结果所示:label.shape 为 torch.Size([2]),笔者想通过 collate_fn 修改 label.shapetorch.Size([2, 1]),下述代码实现这个功能:

def collate_fn(item):
    feature, label = zip(*item)
    feature = torch.stack(feature)
    label = torch.stack(label)
    label = label.view(-1, 1)
    return feature, label
    
dataloader = torch.utils.data.DataLoader(
                    			dataset, 
                    			batch_size=2, 
                    			collate_fn=collate_fn
								)

for feature, label in dataloader:
    print(feature.shape, label.shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

输出如下:

torch.Size([2, 2]) torch.Size([2, 1])
torch.Size([2, 2]) torch.Size([2, 1])

collate_fn(item),传入的item的数据为:

[(tensor([[6.9436, 7.2040]]), tensor([[52.6007]])), (tensor([[7.1495, 2.8882]]), tensor([[34.7427]]))]
[(tensor([[1.5311, 9.9278]]), tensor([[47.1995]])), (tensor([[4.9614, 8.6232]]), tensor([[52.3849]]))]
  • 1
  • 2

feature, label = zip(*item) 故通过zip(*item)的方式,拆分出 feature 和 label 各自的数据,再借助torch.stack方法将其拼接出 batch 形状的数据。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/391992
推荐阅读
相关标签
  

闽ICP备14008679号