当前位置:   article > 正文

python常用函数(一)_nn.sequential的文件维度

nn.sequential的文件维度

repeat()

功能:指定维度上的元素重复n次。
例:

a = torch.rand(12,512,1,64)
b = a.repeat(1,1,32,1)  
  • 1
  • 2

表示第2维上的元素重复32次,其他维度为1表示重复1次,
也就是这维的元素不变动
这样b的维度就是(12,512,32,64)

如果现在这样:

b = a.repeat(1,2,32,2)  
  • 1

表示第0维不变,第1维所有元素重复2次,第2维元素重复32次,第3维元素重复2次
b的维度:(12,1024,32,128)
b对应维度元素个数=a上该维度元素个数×重复的次数。。。
即:b的维度(12×1,512×2,1×32,64×2) = (12,1024,32,128)

repeat()函数在深度学习网络之间的维度拼接中发挥很大的作用,当两个特征的维度不一致时,但需要将两个特征融合在一起时,就可以通过repeat()函数将两者的维度化为一致,再融合。
例:

a = torch.rand(12,512,32,64)
b = torch.rand(12,512,1,64)
  • 1
  • 2

a和b的第三维表示的是特征,现在需要将a和b第三维特征进行拼接,但由于第2维的维度不一致,则无法拼接。因此引入repeat()函数

c = torch.cat([a, b.repeat(1, 1, 32, 1)], dim=-1)
  • 1

b.repeat(1, 1, 32, 1) → \to b的维度(12,512,32,64)
dim = -1 表示对最后一维的元素进行拼接。
结果c的维度为:(12,512,32,128)

torch.randint()

torch.randint(low, high, size)Tensor
生成维度为size大小的张量,其中的数据的范围为[low, high]
例:

labels = torch.randint(0, 3, (5, 6))
>>> tensor([[0, 0, 2, 1, 0, 0],
        [0, 2, 1, 1, 2, 2],
        [1, 0, 2, 1, 2, 0],
        [1, 2, 0, 0, 0, 2],
        [0, 2, 1, 2, 1, 0]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

nn.ModuleList()

  • 这是一个储存不同module,并自动将每个moduleparameters添加到网络之中的容器(list)。
  • 它并没有定义一个网络,只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言,
  • 没有实现 forward 功能,可以把添加到其中的模块和参数自动注册到网络上。
  • 如果模块被调用多次,那么它们是使用同一组的 parameters ,也就是它们的参数是共享的,无论之后怎么更新。
  • 使用两次可以实现参数共享,比如可用在要研究新的需要共享参数的网络结构。
  • 如果将它们按顺序摆放,那么利用for循环所加载的模块就是按顺序输出。灵活的是,可以通过索引加载指定位置的模块。
    例:
self.SA_modules = nn.ModuleList()
        self.SA_modules.append(PointWebSAModule(npoint=1024, nsample=32, mlp=[c, 32, 32, 64], mlp2=[32, 32], use_xyz=use_xyz))
        self.SA_modules.append(PointWebSAModule(npoint=256, nsample=32, mlp=[64, 64, 64, 128], mlp2=[32, 32], use_xyz=use_xyz))
        self.SA_modules.append(PointWebSAModule(npoint=64, nsample=32, mlp=[128, 128, 128, 256], mlp2=[32, 32], use_xyz=use_xyz))
        self.SA_modules.append(PointWebSAModule(npoint=16, nsample=32, mlp=[256, 256, 256, 512], mlp2=[32, 32], use_xyz=use_xyz))

def forward(self, pointcloud: torch.cuda.FloatTensor):
        xyz, features = self._break_up_pc(pointcloud)    
        l_xyz, l_features = [xyz], [features]       
		for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

其中,self.SA_modules[i] 中的索引i就是指定容器中对应位置的模块,利用for循环就可以依次加载这些模块。

nn.Sequential()

  • 里面的模块必须是按照顺序进行排列的,必须确保前一个模块的输出大小和下一个模块的输入大小一致。
  • 使用OrderedDict([ ])可以指定每个 module 的名字,而不是采用默认的命名方式 (按序号 0,1,2,3…) 。
  • nn.Sequential就是一个 nn.Module 的子类,也就是 nn.Module 所有的方法它都有。
  • 使用 nn.Sequential 就不用写 forward 函数,因为这函数内部已经帮你写好了。
  • 相同点:这两个类都是用于网络模块的构建,使构建模块过程的代码更加简洁灵活便捷。
  • 区别nn.ModuleList()具有灵活性,可以在forward中自由排列模块顺序;nn.Sequential()比较固定,其中模块顺序摆好后就不能改变。

参考:PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景

add_module()

  • 该函数是Module类的成员函数,用于动态调整网络结构。
  • 输入参数为Module.add_module(name: str, module: Module),为Module添加一个子module,对应名字为name

参考:使用add_module替换部分模型

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/653125
推荐阅读
相关标签
  

闽ICP备14008679号