当前位置:   article > 正文

Pytorch实现CNN的通用框架

Pytorch实现CNN的通用框架

1.导入库函数

  1. import torch as t
  2. from torch.utils.data import Dataset,DataLoader
  3. from torch import optim
  4. from torchvision import transforms
  5. from torchvision.datasets import ImageFolder

2.数据处理

  1. device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
  2. transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(), # 转为Tensor
  3. transforms.Normalize((0.5,), (0.5,)), ])# 归一化
  4. train = ImageFolder('D:/py_files/dog_cat/data/dogcat',transform)
  5. valid = ImageFolder('D:/py_files/dog_cat/data/dogcat',transform)
  6. train_dataloader = t.utils.data.DataLoader(train,batch_size=32,num_workers=0,shuffle=True)
  7. test_dataloader = t.utils.data.DataLoader(valid,batch_size=32,num_workers=0,shuffle=False)

3.定义模型

自定义了VGG模型

  1. import torch.nn as nn
  2. import
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/691108
推荐阅读
相关标签
  

闽ICP备14008679号