Model Construction
Building a model by subclassing nn.Module
PyTorch provides three container classes, Sequential(), ModuleList(), and ModuleDict(), which make defining a model more convenient.
Sequential() is similar to Keras's Sequential(): its sub-modules are applied in the order given, so the output dimension of each module must match the input dimension of the next.
Usage:
```python
import torch
import torch.nn as nn

# 1 Sequential()
net1 = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
print(net1)
x = torch.rand(size=(32, 784))
print(net1(x))
```
Output of print(net1) (the second print shows a tensor of shape (32, 10)):

```
Sequential(
  (0): Linear(in_features=784, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=10, bias=True)
)
```
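Sequential() can also be given an OrderedDict, so that each sub-module gets a readable name instead of a numeric index; a minimal sketch (layer names are my own):

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Naming the layers via an OrderedDict; the names replace the
# default numeric indices.
net = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(784, 128)),
    ('act', nn.ReLU()),
    ('fc2', nn.Linear(128, 10)),
]))

print(net.fc1)  # sub-modules are reachable by name
print(net(torch.rand(32, 784)).shape)  # torch.Size([32, 10])
```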
ModuleList() takes a list of sub-modules and supports list operations such as append():

```python
# 2 ModuleList()
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)])
net.append(nn.Linear(10, 2))
print(net)
```
(1) The difference between ModuleList() and Sequential()
Both Sequential() and ModuleList() accept a list of layers. The difference is that Sequential() implements forward() for you, passing data through its modules in order (which is why adjacent dimensions must match), whereas ModuleList() is only a container: it defines no forward(), and you decide in your own forward() how its modules are connected. For example:
```python
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(5)])

    def forward(self, x):
        # Decide manually how the modules in the ModuleList are used
        for i, L in enumerate(self.linears):
            y = self.linears[i](x) + L(x)
        return y

net2 = MyModule()
print(net2)
x = torch.rand(size=(24, 10))
print(net2(x))
```
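The practical consequence: a Sequential can be called directly, while a bare ModuleList cannot, because it has no forward(). A small sketch:

```python
import torch
import torch.nn as nn

seq = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
lst = nn.ModuleList([nn.Linear(10, 10), nn.ReLU()])

x = torch.rand(4, 10)
print(seq(x).shape)  # Sequential implements forward(): torch.Size([4, 10])

try:
    lst(x)  # ModuleList defines no forward()
except NotImplementedError:
    print('a ModuleList cannot be called directly')
```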
(2) The difference between ModuleList() and a Python list
First, note that when a layer is assigned as an attribute of an nn.Module, its parameters are added to net.parameters() automatically:
```python
class Module_test(nn.Module):
    def __init__(self):
        super(Module_test, self).__init__()
        self.linears = nn.Linear(10, 10)

net5 = Module_test()
print('net5')
for i in net5.parameters():
    print(i.size())
```
Output:

```
net5
torch.Size([10, 10])
torch.Size([10])
```
A ModuleList is different from an ordinary Python list: the parameters of every module it holds are registered as model parameters automatically, while modules kept in a plain list are not:
```python
class Module_ModuleList(nn.Module):
    def __init__(self):
        super(Module_ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(5)])

class Module_List(nn.Module):
    def __init__(self):
        super(Module_List, self).__init__()
        self.linears = [nn.Linear(10, 10) for i in range(5)]

net3 = Module_ModuleList()
net4 = Module_List()
print('net3')
for i in net3.parameters():
    print(i.size())
print('net4')
for i in net4.parameters():
    print(i.size())
```
Output (net4's plain-list layers contribute no parameters, so nothing is printed after 'net4'):

```
net3
torch.Size([10, 10])
torch.Size([10])
torch.Size([10, 10])
torch.Size([10])
torch.Size([10, 10])
torch.Size([10])
torch.Size([10, 10])
torch.Size([10])
torch.Size([10, 10])
torch.Size([10])
net4
```
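The same registration gap also affects saving: state_dict() only records registered parameters, so layers hidden in a plain Python list are silently dropped when the model is saved. A sketch using minimal versions of the two classes above (class names here are my own):

```python
import torch
import torch.nn as nn

# One class registers its layers through ModuleList, the other
# keeps them in a plain Python list.
class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(2)])

class WithPyList(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = [nn.Linear(10, 10) for _ in range(2)]

# state_dict() contains only registered parameters: two layers give
# four entries (weight + bias each); the plain list gives none.
print(len(WithModuleList().state_dict()))  # 4
print(len(WithPyList().state_dict()))      # 0
```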
ModuleDict() is similar to ModuleList(), but takes a dictionary of modules as input.
```python
net = nn.ModuleDict({
    'linear': nn.Linear(10, 10),
    'act': nn.ReLU()
})
print(net)
```
Output:

```
ModuleDict(
  (act): ReLU()
  (linear): Linear(in_features=10, out_features=10, bias=True)
)
```
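ModuleDict() members can be read by key or by attribute, and new modules can be added by key; like ModuleList(), it defines no forward(), so the layers must be wired up manually. A small sketch:

```python
import torch
import torch.nn as nn

net = nn.ModuleDict({
    'linear': nn.Linear(10, 10),
    'act': nn.ReLU(),
})

net['output'] = nn.Linear(10, 2)  # add a module by key
print(net.linear)                 # attribute access also works
print(net['act'])                 # key access

# ModuleDict has no forward(), so connect the modules explicitly.
x = torch.rand(4, 10)
y = net['output'](net['act'](net['linear'](x)))
print(y.shape)  # torch.Size([4, 2])
```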
The container classes above make model building simpler, but they cannot express complex architectures such as models with multiple inputs: the higher-level encapsulation trades away flexibility.
By subclassing nn.Module directly you can build more complex models, for example reusing a module several times, or using Python loops and control flow inside forward():
```python
class FancyMLP(nn.Module):
    def __init__(self, **kwargs):
        super(FancyMLP, self).__init__()
        # Constant weight: not a Parameter, so it is not trained
        self.rand_weight = torch.rand((20, 20), requires_grad=False)
        self.linear = nn.Linear(20, 20)

    def forward(self, x):
        x = self.linear(x)
        x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)

        # Reuse a layer: both calls share the same parameters
        x = self.linear(x)

        # Python control flow and loops work inside forward()
        while x.norm().item() > 1:
            x /= 2
        if x.norm().item() < 0.8:
            x *= 10
        return x.sum()

net6 = FancyMLP()
print(net6)
for i in net6.parameters():
    print(i.size())
```
Output (rand_weight is not a Parameter, so only the Linear layer's weight and bias appear):

```
FancyMLP(
  (linear): Linear(in_features=20, out_features=20, bias=True)
)
torch.Size([20, 20])
torch.Size([20])
```
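As an illustration of what the container classes cannot express, here is a hypothetical two-input model (all names are my own) built by subclassing nn.Module directly:

```python
import torch
import torch.nn as nn

class TwoInputNet(nn.Module):
    """Hypothetical model with two inputs, fused by concatenation."""
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(16, 8)
        self.branch_b = nn.Linear(32, 8)
        self.head = nn.Linear(16, 2)

    def forward(self, a, b):
        # Each input goes through its own branch; the features are
        # then concatenated and passed to a shared head.
        ha = torch.relu(self.branch_a(a))
        hb = torch.relu(self.branch_b(b))
        return self.head(torch.cat([ha, hb], dim=1))

net = TwoInputNet()
out = net(torch.rand(4, 16), torch.rand(4, 32))
print(out.shape)  # torch.Size([4, 2])
```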
Summary: the container classes Sequential(), ModuleList(), and ModuleDict() make simple, regular models easy to define, while subclassing nn.Module directly provides the flexibility needed for module reuse, multiple inputs, and Python control flow.