当前位置:   article > 正文

每天学点pytorch--torch.nn.ReLU(inplace=False)中inplace的作用

torch.nn.relu(inplace=false)

记录pytorch中遇到的一些问题,文章没有顺序关系

官方连接:

ReLU — PyTorch 1.10.0 documentation

inplace为True时,计算结果会对原来的结果进行覆盖。

还是看下pytorch中的具体操作:

  1. >>> import torch
  2. >>> import torch.nn as nn
  3. >>> conv1 = nn.Conv2d(3, 3, kernel_size=3)
  4. >>> rl1 = nn.ReLU(inplace=True)
  5. >>> rl2 = nn.ReLU()
  6. >>> input = torch.randn(1,3,5,5)
  7. >>> o1 = conv1(input)
  8. >>> id(o1)
  9. 139670453299872
  10. >>> o1
  11. tensor([[[[-0.1162, 0.5905, 1.0601],
  12. [-0.1423, 0.7013, 0.1079],
  13. [ 0.1096, -0.3253, -0.6799]],
  14. [[ 0.3407, 0.5013, -0.2121],
  15. [-0.6805, -0.8362, 0.3360],
  16. [ 1.1606, -0.2564, 0.2965]],
  17. [[ 0.4317, -0.2480, 0.2381],
  18. [-0.0314, -0.0850, 0.1920],
  19. [-0.2762, 0.0338, -0.2298]]]], grad_fn=<ThnnConv2DBackward>)
  20. >>> h1 = rl1(o1)
  21. >>> id(h1)
  22. 139670453299872 #和o1的id一样,说明h1和o1指向同一个地方
  23. >>> o1 # o1的值发生了变化,inplace操作起了作用
  24. tensor([[[[0.0000, 0.5905, 1.0601],
  25. [0.0000, 0.7013, 0.1079],
  26. [0.1096, 0.0000, 0.0000]],
  27. [[0.3407, 0.5013, 0.0000],
  28. [0.0000, 0.0000, 0.3360],
  29. [1.1606, 0.0000, 0.2965]],
  30. [[0.4317, 0.0000, 0.2381],
  31. [0.0000, 0.0000, 0.1920],
  32. [0.0000, 0.0338, 0.0000]]]], grad_fn=<ReluBackward1>)
  33. >>> h1
  34. tensor([[[[0.0000, 0.5905, 1.0601],
  35. [0.0000, 0.7013, 0.1079],
  36. [0.1096, 0.0000, 0.0000]],
  37. [[0.3407, 0.5013, 0.0000],
  38. [0.0000, 0.0000, 0.3360],
  39. [1.1606, 0.0000, 0.2965]],
  40. [[0.4317, 0.0000, 0.2381],
  41. [0.0000, 0.0000, 0.1920],
  42. [0.0000, 0.0338, 0.0000]]]], grad_fn=<ReluBackward1>)

从上面的操作可以看出,如果采用inplace操作,输入参数o1的结果被直接修改。

下面看下inplace为False时的结果:

  1. >>> o1 = conv1(input)
  2. >>> id(o1)
  3. 139670453299712
  4. >>> o1
  5. tensor([[[[-0.1162, 0.5905, 1.0601],
  6. [-0.1423, 0.7013, 0.1079],
  7. [ 0.1096, -0.3253, -0.6799]],
  8. [[ 0.3407, 0.5013, -0.2121],
  9. [-0.6805, -0.8362, 0.3360],
  10. [ 1.1606, -0.2564, 0.2965]],
  11. [[ 0.4317, -0.2480, 0.2381],
  12. [-0.0314, -0.0850, 0.1920],
  13. [-0.2762, 0.0338, -0.2298]]]], grad_fn=<ThnnConv2DBackward>)
  14. >>> h1 = rl2(o1) # relu的inplace为False
  15. >>> id(h1) #h1和o1的id不一样了
  16. 139670453330560
  17. >>> o1 #查看o1的值,发现没有改变
  18. tensor([[[[-0.1162, 0.5905, 1.0601],
  19. [-0.1423, 0.7013, 0.1079],
  20. [ 0.1096, -0.3253, -0.6799]],
  21. [[ 0.3407, 0.5013, -0.2121],
  22. [-0.6805, -0.8362, 0.3360],
  23. [ 1.1606, -0.2564, 0.2965]],
  24. [[ 0.4317, -0.2480, 0.2381],
  25. [-0.0314, -0.0850, 0.1920],
  26. [-0.2762, 0.0338, -0.2298]]]], grad_fn=<ThnnConv2DBackward>)
  27. >>> h1 #h1是经过了relu操作后的结果
  28. tensor([[[[0.0000, 0.5905, 1.0601],
  29. [0.0000, 0.7013, 0.1079],
  30. [0.1096, 0.0000, 0.0000]],
  31. [[0.3407, 0.5013, 0.0000],
  32. [0.0000, 0.0000, 0.3360],
  33. [1.1606, 0.0000, 0.2965]],
  34. [[0.4317, 0.0000, 0.2381],
  35. [0.0000, 0.0000, 0.1920],
  36. [0.0000, 0.0338, 0.0000]]]], grad_fn=<ReluBackward0>)

inplace为False时,不修改输入的值,而是生成一个新的对象,符合预期。

采用原地操作可以节省内存,但是在多分支(Multi-branch)的网络中,使用时需要注意,比如:

  1. conv1 = nn.Conv2d(3, 3, kernel_size=3)
  2. conv2 = nn.Conv2d(3, 3, kernel_size=3)
  3. rl1 = nn.ReLU(inplace=True)
  4. ...
  5. x = conv1(x)
  6. h1 = rl1(x)
  7. h2 = conv2(x) # 此时x的值可能已经变化了

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/963768
推荐阅读
相关标签
  

闽ICP备14008679号