当前位置:   article > 正文

Pytorch搭建U-Net网络_pytroch官网上有u-net网络训练过程吗

pytroch官网上有u-net网络训练过程吗

1、Pytorch

原来常用keras搭建网络模型,后来发现keras的训练模型速度和测试速度都较慢,因此转向使用pytorch,其实两者使用难度差不多,都是高层的深度学习框架,适合研究深度学习。

2、U_Net网络介绍

U_Net网络已经提出很早,常被用在图像语义分割领域。模型的主要结构如下图所示,包括下采样和上采样两个过程。为了保证上采样得到的特征图具有较强的语义信息、提高分割的精准度。会在上采样过程中进行通道拼接再卷积。

3、Pytorch代码

3.1、导入包

导入torch,至于为什么导入numpy,显然是因为喜欢。

  1. import torch
  2. from torch import nn
  3. import numpy as np

3.2、下采样模块

下采样模块本文采用了通用卷积进行搭建,当然在语义分割中用的较多的空洞卷积,以及残差结构对网络性能都是有提升效果的。BN层的作用当然是网络节点输出更加稳定,一定程度上能够缓解梯度爆炸和梯度消失问题。激活函数这里使用了Relu6,同样也是考虑到了数据分布,因为通常图片数据在进入网络模型前会进行标准化处理。

  1. class block_down(nn.Module):
  2. def __init__(self,inp_channel,out_channel):
  3. super(block_down,self).__init__()
  4. self.conv1=nn.Conv2d(inp_channel,out_channel,3,padding=1)
  5. self.conv2=nn.Conv2d(out_channel,out_channel,3,padding=1)
  6. self.bn=nn.BatchNorm2d(out_channel)
  7. self.relu=nn.ReLU6(inplace=True)
  8. def forward(self,x):
  9. x=self.conv1(x)
  10. x=self.bn(x)
  11. x=self.relu(x)
  12. x=self.conv2(x)
  13. x=self.bn(x)
  14. x=self.relu(x)
  15. return x

3.3、上采样模块

上采样模块先使用转置卷积进行~~额-->上采样。

常规卷积的输入和输出尺寸关系是:

out_size=(inp_size-f+2p)/stride +1

转置卷积为:

out_size=(inp_size-1)*stride +f

式中f是卷积核(kernel)的尺寸,stride是卷积核滑动步长。

转置卷积的作用显而易见是~~额-->回到过去。

  1. class block_up(nn.Module):
  2. def __init__(self,inp_channel,out_channel,y):
  3. super(block_up,self).__init__()
  4. self.up=nn.ConvTranspose2d(inp_channel,out_channel,2,stride=2)
  5. self.conv1=nn.Conv2d(inp_channel,out_channel,3,padding=1)
  6. self.conv2=nn.Conv2d(out_channel,out_channel,3,padding=1)
  7. self.bn=nn.BatchNorm2d(out_channel)
  8. self.relu=nn.ReLU6(inplace=True)
  9. self.y=y
  10. def forward(self,x):
  11. x=self.up(x)
  12. x=torch.cat([x,self.y],dim=1)
  13. x=self.conv1(x)
  14. x=self.bn(x)
  15. x=self.relu(x)
  16. x=self.conv2(x)
  17. x=self.bn(x)
  18. x=self.relu(x)
  19. return x

3.4、用模块搭建整体网络

至于为什么写block再搭建会显得有点麻烦,原因一是结构清楚,二是超参数更易修改,三是显而易见是我喜欢。

  1. class U_net(nn.Module):
  2. def __init__(self,out_channel):
  3. super(U_net,self).__init__()
  4. self.out=nn.Conv2d(64,out_channel,1)
  5. self.maxpool=nn.MaxPool2d(2)
  6. def forward(self,x):
  7. block1=block_down(3,64)
  8. x1_use=block1(x)
  9. x1=self.maxpool(x1_use)
  10. block2=block_down(64,128)
  11. x2_use=block2(x1)
  12. x2=self.maxpool(x2_use)
  13. block3=block_down(128,256)
  14. x3_use=block3(x2)
  15. x3=self.maxpool(x3_use)
  16. block4=block_down(256,512)
  17. x4_use=block4(x3)
  18. x4=self.maxpool(x4_use)
  19. block5=block_down(512,1024)
  20. x5=block5(x4)
  21. block6=block_up(1024,512,x4_use)
  22. x6=block6(x5)
  23. block7=block_up(512,256,x3_use)
  24. x7=block7(x6)
  25. block8=block_up(256,128,x2_use)
  26. x8=block8(x7)
  27. block9=block_up(128,64,x1_use)
  28. x9=block9(x8)
  29. x10=self.out(x9)
  30. out=nn.Softmax2d()(x10)
  31. return out

4、测试

输入形状和输出相同,完成搭建。

input_size: torch.Size([1, 3, 480, 640])
output_size: torch.Size([1, 3, 480, 640])
  1. if __name__=="__main__":
  2. test_input=torch.rand(1, 3, 480, 640)
  3. print("input_size:",test_input.size())
  4. model=U_net(out_channel=3)
  5. ouput=model(test_input)
  6. print("output_size:",ouput.size())

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/87690
推荐阅读
相关标签
  

闽ICP备14008679号