1. Import libraries
- import torch as t
- from torch.utils.data import Dataset, DataLoader
- from torch import optim
- from torchvision import transforms
- from torchvision.datasets import ImageFolder
2. Data processing
- device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
-
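- # Preprocessing: resize every image to 224x224, convert it to a tensor, then normalize each RGB channel
- transform = transforms.Compose([
-     transforms.Resize((224, 224)),
-     transforms.ToTensor(),                                   # convert to Tensor (values in [0, 1])
-     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize to [-1, 1] via (x - 0.5) / 0.5
- ])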
-
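- # Load the dogcat folder once (one subfolder per class), then split it into disjoint
- # training and validation subsets; the 80/20 ratio is an assumption
- full = ImageFolder('D:/py_files/dog_cat/data/dogcat', transform)
- n_train = int(0.8 * len(full))
- train, valid = t.utils.data.random_split(full, [n_train, len(full) - n_train])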
-
- # Shuffle the training batches every epoch; keep the validation order fixed
- train_dataloader = DataLoader(train, batch_size=32, num_workers=0, shuffle=True)
- test_dataloader = DataLoader(valid, batch_size=32, num_workers=0, shuffle=False)
3. Define the model
The network is a custom implementation of VGG.
- import torch.nn as nn
- import
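The model listing breaks off after its imports, so the author's VGG definition is not available here. For orientation only, a hypothetical VGG-style network for this two-class task could look like the following sketch. `SmallVGG`, its `CFG` channel list, and the SGD settings are all assumptions. The channel widths are far smaller than VGG-16 so the sketch runs quickly on a CPU. It follows the VGG pattern of stacked 3×3 convolutions with ReLU, 2×2 max-pooling, and fully connected layers, and it performs one training step on a random batch.

```python
import torch as t
import torch.nn as nn
from torch import optim

# Hypothetical layer configuration: integers are conv output channels, "M" is a 2x2 max-pool
CFG = [16, "M", 32, "M", 64, 64, "M", 128, 128, "M", 128, 128, "M"]


class SmallVGG(nn.Module):
    """VGG-style classifier: 3x3 conv + ReLU blocks, max-pooling, then fully connected layers."""

    def __init__(self, cfg=CFG, num_classes=2):
        super().__init__()
        layers, in_ch = [], 3  # RGB input
        for v in cfg:
            if v == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))  # halve height and width
            else:
                layers += [nn.Conv2d(in_ch, v, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
                in_ch = v
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d((7, 7))  # 224 / 2**5 = 7, so this is an identity for 224x224 input
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_ch * 7 * 7, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),  # one logit per class (cat, dog)
        )

    def forward(self, x):
        return self.classifier(self.pool(self.features(x)))


device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
model = SmallVGG().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # assumed hyperparameters

# One training step on a random stand-in batch of 4 normalized 224x224 images
inputs = t.randn(4, 3, 224, 224, device=device)
targets = t.tensor([0, 1, 0, 1], device=device)
logits = model(inputs)
loss = criterion(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(logits.shape)  # torch.Size([4, 2])
```

The five pooling stages reduce a 224×224 input to 7×7, which is why the first linear layer expects `128 * 7 * 7` inputs. In a real training loop, `inputs` and `targets` would come from `train_dataloader` and be moved to `device` batch by batch.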