赞
踩
功能:指定维度上的元素重复n次。
例:
a = torch.rand(12,512,1,64)
b = a.repeat(1,1,32,1)
表示第2维上的元素重复32次,其他维度为1表示重复1次,
也就是这维的元素不变动
这样b的维度就是(12,512,32,64)
如果现在这样:
b = a.repeat(1,2,32,2)
表示第0维不变,第1维所有元素重复2次,第2维元素重复32次,第3维元素重复2次
b
的维度:(12,1024,32,128)
b
对应维度元素个数=a上该维度元素个数×重复的次数。。。
即:b
的维度(12×1,512×2,1×32,64×2) = (12,1024,32,128)
repeat()函数在深度学习网络之间的维度拼接中发挥很大的作用,当两个特征的维度不一致时,但需要将两个特征融合在一起时,就可以通过repeat()函数将两者的维度化为一致,再融合。
例:
a = torch.rand(12,512,32,64)
b = torch.rand(12,512,1,64)
a和b的第三维表示的是特征,现在需要将a和b第三维特征进行拼接,但由于第2维的维度不一致,则无法拼接。因此引入repeat()
函数
c = torch.cat([a, b.repeat(1, 1, 32, 1)], dim=-1)
b.repeat(1, 1, 32, 1)
→
\to
→ b
的维度(12,512,32,64)
dim = -1
表示对最后一维的元素进行拼接。
结果c的维度为:(12,512,32,128)
torch.randint(low, high, size)
→ Tensor
生成维度为size
大小的张量,其中的数据的范围为[low, high]
。
例:
labels = torch.randint(0, 3, (5, 6))
>>> tensor([[0, 0, 2, 1, 0, 0],
[0, 2, 1, 1, 2, 2],
[1, 0, 2, 1, 2, 0],
[1, 2, 0, 0, 0, 2],
[0, 2, 1, 2, 1, 0]])
module
,并自动将每个module
的parameters
添加到网络之中的容器(list)。forward
功能,可以把添加到其中的模块和参数自动注册到网络上。parameters
,也就是它们的参数是共享的,无论之后怎么更新。for
循环所加载的模块就是按顺序输出。灵活的是,可以通过索引加载指定位置的模块。self.SA_modules = nn.ModuleList()
self.SA_modules.append(PointWebSAModule(npoint=1024, nsample=32, mlp=[c, 32, 32, 64], mlp2=[32, 32], use_xyz=use_xyz))
self.SA_modules.append(PointWebSAModule(npoint=256, nsample=32, mlp=[64, 64, 64, 128], mlp2=[32, 32], use_xyz=use_xyz))
self.SA_modules.append(PointWebSAModule(npoint=64, nsample=32, mlp=[128, 128, 128, 256], mlp2=[32, 32], use_xyz=use_xyz))
self.SA_modules.append(PointWebSAModule(npoint=16, nsample=32, mlp=[256, 256, 256, 512], mlp2=[32, 32], use_xyz=use_xyz))
def forward(self, pointcloud: torch.cuda.FloatTensor):
xyz, features = self._break_up_pc(pointcloud)
l_xyz, l_features = [xyz], [features]
for i in range(len(self.SA_modules)):
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
其中,self.SA_modules[i]
中的索引i
就是指定容器中对应位置的模块,利用for
循环就可以依次加载这些模块。
OrderedDict([ ])
可以指定每个 module
的名字,而不是采用默认的命名方式 (按序号 0,1,2,3…) 。nn.Sequential
就是一个 nn.Module
的子类,也就是 nn.Module
所有的方法它都有。nn.Sequential
就不用写 forward
函数,因为这函数内部已经帮你写好了。nn.ModuleList()
具有灵活性,可以在forward
中自由排列模块顺序;nn.Sequential()
比较固定,其中模块顺序摆好后就不能改变。参考:PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景
Module.add_module(name: str, module: Module)
,为Module
添加一个子module
,对应名字为name
。Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。