赞
踩
有一篇文献说 正则化可以加快训练速度
可以看到参数中需要自己填的不错,主要num_Features就是输入的 channel
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
文档例子:
>>> m = nn.BatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm2d(100, affine=False)
>>> input = torch.randn(20, 100, 35, 45)可以看到batchsize=20,channel=100
>>> output = m(input)
线性叠加中的
k 即是 权重
b即是偏置
他们来源于分布u
以一个典型的网络结构
VGG 16 model为例
**
**
1、老代码复习
import torchvision
from torch.utils.data import DataLoader
dataset=torchvision.datasets.CIFAR10("./P_21_data",train=False,transform=torchvision.transforms.ToTensor(),
download=True)
dataloader=DataLoader(dataset,batch_size=64)
for data in dataloader:
imgs,targets=data
print(imgs.shape)
输出:
Extracting ./P_21_data\cifar-10-python.tar.gz to ./P_21_data
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
2、使用reshape,变换维度,我们想让高是1,宽自定
等下务必注意下他的维度改变逻辑
import torchvision
from torch.utils.data import DataLoader
import torch
dataset=torchvision.datasets.CIFAR10("./P_21_data",train=False,transform=torchvision.transforms.ToTensor(),
download=True)
dataloader=DataLoader(dataset,batch_size=64)
for data in dataloader:
imgs,targets=data
print(imgs.shape)
output=torch.reshape(imgs,(1,1,1,-1)) //因为宽自定,所以是-1//
print(output.shape)
输出:
torch.Size([16, 3, 32, 32])
torch.Size([1, 1, 1, 49152])
torch.Size([64, 3, 32, 32])
torch.Size([1, 1, 1, 196608])
维度逻辑:
我们发现16×3×32×32=49152
64×3×32×32=19608
所以他变换维度只是让四个维度的积保持不变
3、使用class nn.moudule
创造linar实例:
注意linear需要的输入输出
实例化
import torchvision from torch.utils.data import DataLoader import torch from torch import nn from torch.nn import Linear dataset=torchvision.datasets.CIFAR10("./P_21_data",train=False,transform=torchvision.transforms.ToTensor(), download=True) dataloader=DataLoader(dataset,batch_size=64) class Tudui(nn.Module): def __init__(self): super(Tudui, self).__init__() self.linar1=Linear(196608,10) //定义Tudui实例,继承nn.moudule,在init定义函数linar// //这里Linea需要写input,output// def forward(self,input): output=self.linar1(input) return output //定义forward 这样继承Tudui后参数会直接传入forward// for data in dataloader: imgs,targets=data print(imgs.shape) output=torch.reshape(imgs,(1,1,1,-1)) print(output.shape)
4、应用自己的tudui
import torchvision from torch.utils.data import DataLoader import torch from torch import nn from torch.nn import Linear dataset=torchvision.datasets.CIFAR10("./P_21_data",train=False,transform=torchvision.transforms.ToTensor(), download=True) dataloader=DataLoader(dataset,batch_size=64) class Tudui(nn.Module): def __init__(self): super(Tudui, self).__init__() self.linar1=Linear(196608,10) def forward(self,input): output=self.linar1(input) return output tudui=Tudui() for data in dataloader: imgs,targets=data print(imgs.shape) output=torch.reshape(imgs,(1,1,1,-1)) print(output.shape) output=tudui(output) print(output.shape)
看输出结果:
torch.Size([64, 3, 32, 32])
torch.Size([1, 1, 1, 196608])
torch.Size([1, 1, 1, 10])
我们发现:成功把196608的宽变为了10
后面微微报错:
这是因为我们抓拍最后一组抓的不足64,只有16
即便上面用了reshape:
他也是从torch.Size([16, 3, 32, 32])
变成了torch.Size([1, 1, 1, 49152])
而我们是这么定义的
5、修改方法1
直接不要最后一落牌:droplast=true
dataloader=DataLoader(dataset,batch_size=64,drop_last=True)
完事~没报错
修改方法2:不用drop——last,在class TUDUI中加一行判断 input的shape(我真聪明)
import torchvision from torch.utils.data import DataLoader import torch from torch import nn from torch.nn import Linear dataset=torchvision.datasets.CIFAR10("./P_21_data",train=False,transform=torchvision.transforms.ToTensor(), download=True) dataloader=DataLoader(dataset,batch_size=64) class Tudui(nn.Module): def __init__(self): super(Tudui, self).__init__() self.linar1=Linear(196608,10) self.linar2=Linear(49152,10) //因为前面我们输出过input.reshape发现是tensor类型的列表// //前面抓牌是196608,最后一摞是49152,所以我们重新定义一个函数linar2即可// def forward(self,input): if input.shape[3]==196608: //此处加一行判断,因为input的shape形式是【batchsize,channel,h,w】// //其中宽显然在列表的第四个// //所以判断宽如果是196608,那就用linar1// //否则就linar2// output=self.linar1(input) else: output=self.linar2(input) return output tudui=Tudui() for data in dataloader: imgs,targets=data print(imgs.shape) output=torch.reshape(imgs,(1,1,1,-1)) print(output.shape) output=tudui(output) print(output.shape)
最后果然没报错
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。