当前位置:   article > 正文

pytorch学习二:模型定义方法_torch sequential 子类

torch sequential 子类

模型构造

继承nn.Module类来构造模型

一、nn.Module的三个子类

Torch封装了三个子类Sequential(),ModuleList(),ModuleDict()。可以更加方便的定义模型。

1.Sequential()

类似于keras的Sequential()。内部的模块要求按顺序排序,并且要求维度匹配。

使用如下:

  1. #1 Sequential()
  2. net1 = nn.Sequential(
  3. nn.Linear(784,128),
  4. nn.ReLU(),
  5. nn.Linear(128,10)
  6. )
  7. print(net1)
  8. x = torch.rand(size=(32,784))
  9. print(net1(x))

输出结果:

  1. MyModule(
  2. (linears): ModuleList(
  3. (0): Linear(in_features=10, out_features=10, bias=True)
  4. (1): Linear(in_features=10, out_features=10, bias=True)
  5. (2): Linear(in_features=10, out_features=10, bias=True)
  6. (3): Linear(in_features=10, out_features=10, bias=True)
  7. (4): Linear(in_features=10, out_features=10, bias=True)
  8. )
  9. )

 2.ModileList()类

  1. net = nn.ModuleList([nn.Linear(784,256),nn.ReLU(),nn.Linear(256,10)])
  2. net.append(nn.Linear(10,2))
  3. print(net)

(1)ModuleList()个Sequential()的区别

sequential()和ModuleList()都接受网络层列表作为参数。区别在于:

  • Sequential()内部的网络层需要按顺序排列,且维度要匹配,内部forward()函数已实现。(类似于keras 的Sequential())
  • ModeleList()各模块之间没有联系也没有顺序。不要求维度匹配。内部forward()函数没有实现。ModuleList只是让定义网络前向传播更加方便,
  1. class MyModule(nn.Module):
  2. def __init__(self):
  3. super(MyModule,self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(5)])
  5. def forward(self,x):
  6. for i,L in enumerate(self.linears):
  7. y = self.linears[i](x) + L(x)
  8. return y
  9. net2 = MyModule()
  10. print(net2)
  11. x = torch.rand(size=(24,10))
  12. print(net2(x))

(2)ModuleList()和Python List的区别

先来看在定义模型时,所有网络层参数会自动加到net.parameters()中。

  1. class Module_test(nn.Module):
  2. def __init__(self):
  3. super(Module_test,self).__init__()
  4. self.linears = nn.Linear(10,10)
  5. net5 = Module_test()
  6. for i in net5.parameters():
  7. print(i.size())

运行结果:

  1. net5
  2. torch.Size([10, 10])
  3. torch.Size([10])

ModuleList不同于一般的python list。ModuleList里所有模型参数都会自动加到模型参数中。

  1. class Module_ModuleList(nn.Module):
  2. def __init__(self):
  3. super(Module_ModuleList,self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(5)])
  5. class Module_List(nn.Module):
  6. def __init__(self):
  7. super(Module_List,self).__init__()
  8. self.linears = [nn.Linear(10,10) for i in range(5)]
  9. net3 = Module_ModuleList()
  10. net4 = Module_List()
  11. net5 = Module_test()
  12. print('net3')
  13. for i in net3.parameters():
  14. print(i.size())
  15. print('net4')
  16. for i in net4.parameters():
  17. print(i.size())

运行结果:

  1. net3
  2. torch.Size([10, 10])
  3. torch.Size([10])
  4. torch.Size([10, 10])
  5. torch.Size([10])
  6. torch.Size([10, 10])
  7. torch.Size([10])
  8. torch.Size([10, 10])
  9. torch.Size([10])
  10. torch.Size([10, 10])
  11. torch.Size([10])
  12. net4

3.ModuleDict()

ModuleDict()类似于ModuleList(),但是结果模块字典作为输入。

  1. net = nn.ModuleDict({
  2. 'linear': nn.Linear(10,10),
  3. 'act': nn.ReLU()}
  4. )
  5. print(net)
)输出结果:
  1. ModuleDict(
  2. (act): ReLU()
  3. (linear): Linear(in_features=10, out_features=10, bias=True)
  4. )

二 继承nn.Module类构建复杂模型

上述子类在构建模型更简单,但不能用于构建复杂模型。例如多输入模型。也就是其封装的更高级,但缺乏灵活性。

可以直接继承nn.Module()类构建更在复杂的模型。例如模块的复用,也可以使用python循环或者控制流。

  1. class FancyMLP(nn.Module):
  2. def __init__(self,**kwargs):
  3. super(FancyMLP,self).__init__()
  4. self.rand_weight = torch.rand((20,20),requires_grad=False)
  5. self.linear = nn.Linear(20,20)
  6. def forward(self,x):
  7. x = self.linear(x)
  8. x = nn.functional.relu(torch.mm(x,self.rand_weight.data)+1)
  9. #复用网络层
  10. x = self.linear(x)
  11. #可用python的控制流或循环控制
  12. while x.norm().item() > 1:
  13. x /=2
  14. if x.norm().item() < 0.8:
  15. x *= 10
  16. return x.sum()
  17. net6 = FancyMLP()
  18. print(net6)
  19. for i in net6.parameters():
  20. print(i.size())

结果输出:

  1. FancyMLP(
  2. (linear): Linear(in_features=20, out_features=20, bias=True)
  3. )
  4. torch.Size([20, 20])
  5. torch.Size([20])

总结:

  1. Sequential()可按照模块训练构建模型,forward函数已实现。
  2. ModuleList()和ModuleDict()只是将多个模型放到一起。并没有实现forward()函数。一般用在继承nn.Module定义模型的初始化函数__init__中。
  3. 可以像操作python列表或字典一样操作ModuleList和ModuleDict。例如ModuleList.append()。ModuleList[i]。
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/653118
推荐阅读
相关标签
  

闽ICP备14008679号