赞
踩
最近在学习pyTorch, 在阅读pytorch教程的时候,发现有一个简单的卷积神经网络,之前搞明白过这个过程,时间太久,都忘的差不多了, 正好写个笔记记录总结一下
代码如下:
#! usr/bin/env python3 # -*- coding:utf-8 -*- """ @Author:MaCan @Time:2019/10/29 19:59 @File:torch_net.py @Mail:ma_cancan@163.com """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 10) def forward(self, x): x = self.conv1(x) print('x1 {}'.format(x.size())) x = F.max_pool2d(F.relu(x), (2, 2)) print('x2 {}'.format(x.size())) x = self.conv2(x) print('x3 {}'.format(x.size())) x = F.max_pool2d(F.relu(x), 2) print('x4: {}'.format(x.size())) x = x.view(-1, self.num_flat_features(x)) print('x5: {}'.format(x.size())) x = F.relu(self.fc1(x)) print('x6: {}'.format(x.size())) x = F.relu(self.fc2(x)) print('x7: {}'.format(x.size())) x = self.fc3(x) print('x8: {}'.format(x.size())) return x def num_flat_features(self, x): size = x.size()[1:] # 除了batch 外的其他纬度值 print('size: {}'.format(size)) num_features = 1 for s in size: num_features *= s return num_features if __name__ == '__main__': net = Net() print(net) input = torch.randn(1, 1, 32, 32) out = net(input) print(out)
运行这个代码,输出如下:
Net( (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1)) (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) (fc1): Linear(in_features=400, out_features=128, bias=True) (fc2): Linear(in_features=128, out_features=64, bias=True) (fc3): Linear(in_features=64, out_features=10, bias=True) ) x1 torch.Size([1, 6, 28, 28]) x2 torch.Size([1, 6, 14, 14]) x3 torch.Size([1, 16, 10, 10]) x4: torch.Size([1, 16, 5, 5]) size: torch.Size([16, 5, 5]) x5: torch.Size([1, 400]) x6: torch.Size([1, 128]) x7: torch.Size([1, 64]) x8: torch.Size([1, 10]) tensor([[ 0.0684, 0.0224, -0.0527, -0.1091, 0.0603, -0.0389, -0.0848, -0.0689, 0.0107, -0.0398]], grad_fn=<AddmmBackward>)
为了观察每个过程的维度变化,我写了一些print操作,其实维度变化已经很明显了,下面来具体计算一下每个维度是怎么计算的到的
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。