赞
踩
参考:pytorch教程之nn.Module类详解——使用Module类来自定义模型
一、简单的模型模板
1、定义网络结构
class MyNet(nn.Module):
# 初始化函数 __init__(self):
# 定义了具体网络有那些层,但并没有决定网络的结构。
def __init__(self) -> None:
super().__init__()
# 前向传播 forward():
# 函数定义了网络的的顺序
def forward(self, input):
2、实验一下
前向网络中对传入的值加1
import torch.nn as nn import torch class MyNet(nn.Module): def __init__(self) -> None: super().__init__() def forward(self, input): # input输入,output输出 output = input + 1 return output x = torch.tensor(1.0) # 初始化网络 MyNet = MyNet() output = MyNet(x) print(output)
输出:
二、自定义网络模型
class MyNet(nn.Module):
def __init__(self) -> None:
super(MyNet, self).__init__()
self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
# 卷积层中stride默认为1,池化层中stride默认为kernel_size的大小
def forward(self, x):
x = self.conv1(x)
return x
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。