当前位置:   article > 正文

深度学习——01pytorch自定义网络模型_torch深度学习定义网络结构

torch深度学习定义网络结构

参考:pytorch教程之nn.Module类详解——使用Module类来自定义模型

一、简单的模型模板

1、定义网络结构

class MyNet(nn.Module):
	# 初始化函数  __init__(self):
	# 定义了具体网络有那些层,但并没有决定网络的结构。
    def __init__(self) -> None:
        super().__init__()
	# 前向传播  forward():
	# 函数定义了网络的的顺序
    def forward(self, input):
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

2、实验一下
前向网络中对传入的值加1

import torch.nn as nn
import torch

class MyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input):
    # input输入,output输出
        output = input + 1
        return output

x = torch.tensor(1.0)
# 初始化网络
MyNet = MyNet()
output = MyNet(x)
print(output)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

输出:
在这里插入图片描述
二、自定义网络模型

class MyNet(nn.Module):
    def __init__(self) -> None:
        super(MyNet, self).__init__()
        self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
    	# 卷积层中stride默认为1,池化层中stride默认为kernel_size的大小

    def forward(self, x):
        x = self.conv1(x)
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/79200
推荐阅读
相关标签
  

闽ICP备14008679号