当前位置:   article > 正文

pytorch常用函数/接口,nn.ReLU和F.ReLU,torch.transpose,permute...___constants__ = ['inplace']

__constants__ = ['inplace']

1、nn.ReLU(inplace=True)中inplace的作用

这里面inplace的意思是将计算得到的值直接覆盖前面的值。
官网nn.ReLU介绍

类似于下面

# inplace False
y = x+1
x = x
#inplace True
x = x + 1
  • 1
  • 2
  • 3
  • 4
  • 5

2.nn.ReLU, F.ReLU区别与联系

其中nn.ReLU作为一个层结构,必须添加到nn.Module容器中才能使用,而F.ReLU则作为一个函数调用,看上去作为一个函数调用更方便更简洁。具体使用哪种方式,取决于编程风格。在PyTorch中,nn.X都有对应的函数版本F.X,但是并不是所有的F.X均可以用于forward或其它代码段中,因为当网络模型训练完毕时,在存储model时,在forward中的F.X函数中的参数是无法保存的。也就是说,在forward中,使用的F.X函数一般均没有状态参数,比如F.ReLU,F.avg_pool2d等,均没有参数,它们可以用在任何代码片段中。

nn.ReLu里面其实就是调用了F.ReLU

class ReLU(Module):
    r"""Applies the rectified linear unit function element-wise:

    :math:`\text{ReLU}(x)= \max(0, x)`

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)


      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input),m(-input)))
    """
    __constants__ = ['inplace']

    def __init__(self, inplace=False):
        super(ReLU, self).__init__()
        self.inplace = inplace

    @weak_script_method
    def forward(self, input):
      # F 来自于 import nn.functional as F
        return F.relu(input, inplace=self.inplace)

    def extra_repr(self):
        inplace_str = 'inplace' if self.inplace else ''
        return inplace_str
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

3.torch.transpose(tensor,dim1,dim2) tensor.permute(dim1,dim2,dim3)

这两个都可以对tensor做维度转换
区别:transpose只能在两个维度之间转换, permute可以一下转换好几个维度
多个transpose转换就相当于permute

4.torch.contiguous

该函数功能是把原来的tensor放到一个连续的空间,也就是一个在内存中连续存储的张量
放在transpose,permute之后,view之前。
测试过transpose之后直接view,报错
可能是view必须作用于内存连续的张量。
PS: reshape ≈ \approx tensor.contiguous().view

5.torch.triu

官网介绍

torch.triu(input, diagonal=0, out=None) → Tensor

返回一个矩阵的上三角部分。

a = torch.randn(3, 3)
tensor([[ 0.2309, 0.5207, 2.0049],
[ 0.2072, -1.0680, 0.6602],
[ 0.3480, -0.5211, -0.4573]])
torch.triu(a)
tensor([[ 0.2309, 0.5207, 2.0049],
[ 0.0000, -1.0680, 0.6602],
[ 0.0000, 0.0000, -0.4573]])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/306213
推荐阅读
相关标签
  

闽ICP备14008679号