当前位置:   article > 正文

Pytorch的DataLoader和Dataset详解_train_test_split 和dataloader

train_test_split 和dataloader

Pytorch作为一个深度学习的框架,在学术界的使用已经很广泛了,我们现在来介绍一下如何把numpy的数据转化为Pytorch的Tensor,不懂Tensor的小伙伴可以自行百度一下。

        我这里简单介绍一下Tensor其实和Numpy的nparray基本上是一个事情,只不过我们的Tensor是我们Pytorch需要的数据类型,而且Pytorch所有的计算都是基于Tensor的,我们可以把它叫做张量,同时有的小伙伴也会把它叫做高维矩阵。这里不多重复。

        我们说一下重点,在Meachine Learning中,我们处理数据,我们一般都会在DataFrame,nparray这两个数据结构里面进行,当我们训练深度学习模型的时候才会用到Pytorch,这个时候我们需要把数据转化为Tensor,那么如何转化,当然直接使用

torch.from_numpy()

        这段代码是没有毛病的,而且我们多数时候也是这样转换的,但是当我们需要训练数据,通过每个min_batch的方式进行训练的时候我们这样转化的方式就是不对的,这样把所有的数据都转换为张量了我们还需要自己去切分Batch,当然你自己手动切分Batch是可以的,我也没说不行,但是,切分的代码,你都有可能写错,这样的话对后面的训练的影响是很大的,所以Pytorch官方推荐使用Dataset包装数据,使用DataLoader加载数据。

  1. from sklearn.model_selection import train_test_split
  2. import torch.nn as nn
  3. import torch
  4. from torch.utils.data import Dataset
  5. from torch.utils.data import DataLoader
  6. #切分训练数据和测试数据
  7. X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3,train_size=0.7,random_state=42)
  8. #使用torch.from_numpy 把数据转化为numpy
  9. x_train_tensor = torch.from_numpy(X_train).float()
  10. y_train_tensor = torch.from_numpy(y_train).float()
  11. #自己定义DataSet继承自torch.utils.data的DataSet
  12. class AirDataSet(Dataset):
  13. def __init__(self,x_tensor,y_tensor):
  14. self.x = x_tensor
  15. self.y = y_tensor
  16. def __getitem__(self, index):
  17. return(self.x[index], self.y[index])
  18. def __len__(self):
  19. return len(self.x)
  20. #放入数据 这样就是实现了数据的加载
  21. train_data = AirDataSet(x_train_tensor, y_train_tensor)
  22. #定义DataLoader加载器
  23. #有两个超参数需要定义,一个是batch_size一个是shuffle
  24. #batch_size 代表我们每一批次取出多少个进行训练
  25. #shuffle 代表我们是否需要重新洗牌 随机选取
  26. epochs = 180
  27. batch_size = 25
  28. train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

        接下来就是定义模型,训练数据,数据如何从DataLoader里面取出来,如下图所示,我们通过for直接就可以取出来每个Batch里面的数据。

  1. for epoch in range(epochs):
  2. train_loss = 0
  3. for batch_idx, (x_batch, y_batch) in enumerate(train_loader):

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/318259
推荐阅读
相关标签
  

闽ICP备14008679号