当前位置:   article > 正文

torch.norm()函数的用法

.norm()

目录

一、函数定义

二、代码示例

三、整体代码


一、函数定义

公式:

                                                                     ||x||p=x1p+x2p++xNpp

意思就是inputs的一共N维的话对这N个数据p范数,当然这个还是太抽象了,接下来还是看具体的代码~

p指的是求p范数的p值,函数默认p=2,那么就是求2范数

  1. def norm(self, input, p=2): # real signature unknown; restored from __doc__
  2. """
  3. .. function:: norm(input, p=2) -> Tensor
  4. Returns the p-norm of the :attr:`input` tensor.
  5. .. math::
  6. ||x||_{p} = \sqrt[p]{x_{1}^{p} + x_{2}^{p} + \ldots + x_{N}^{p}}
  7. Args:
  8. input (Tensor): the input tensor
  9. p (float, optional): the exponent value in the norm formulation
  10. Example::
  11. >>> a = torch.randn(1, 3)
  12. >>> a
  13. tensor([[-0.5192, -1.0782, -1.0448]])
  14. >>> torch.norm(a, 3)
  15. tensor(1.3633)
  16. .. function:: norm(input, p, dim, keepdim=False, out=None) -> Tensor
  17. Returns the p-norm of each row of the :attr:`input` tensor in the given
  18. dimension :attr:`dim`.
  19. If :attr:`keepdim` is ``True``, the output tensor is of the same size as
  20. :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
  21. Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting
  22. in the output tensor having 1 fewer dimension than :attr:`input`.
  23. Args:
  24. input (Tensor): the input tensor
  25. p (float): the exponent value in the norm formulation
  26. dim (int): the dimension to reduce
  27. keepdim (bool): whether the output tensor has :attr:`dim` retained or not
  28. out (Tensor, optional): the output tensor
  29. Example::
  30. >>> a = torch.randn(4, 2)
  31. >>> a
  32. tensor([[ 2.1983, 0.4141],
  33. [ 0.8734, 1.9710],
  34. [-0.7778, 0.7938],
  35. [-0.1342, 0.7347]])
  36. >>> torch.norm(a, 2, 1)
  37. tensor([ 2.2369, 2.1558, 1.1113, 0.7469])
  38. >>> torch.norm(a, 0, 1, True)
  39. tensor([[ 2.],
  40. [ 2.],
  41. [ 2.],
  42. [ 2.]])
  43. """
  44. pass

二、代码示例

输入代码

  1. import torch
  2. rectangle_height = 3
  3. rectangle_width = 4
  4. inputs = torch.randn(rectangle_height, rectangle_width)
  5. for i in range(rectangle_height):
  6. for j in range(rectangle_width):
  7. inputs[i][j] = (i + 1) * (j + 1)
  8. print(inputs)

得到一个3×4矩阵,如下

  1. tensor([[ 1., 2., 3., 4.],
  2. [ 2., 4., 6., 8.],
  3. [ 3., 6., 9., 12.]])

接着我们分别对其分别求2范数

  1. inputs1 = torch.norm(inputs, p=2, dim=1, keepdim=True)
  2. print(inputs1)
  3. inputs2 = torch.norm(inputs, p=2, dim=0, keepdim=True)
  4. print(inputs2)

结果分别为

  1. tensor([[ 5.4772],
  2. [10.9545],
  3. [16.4317]])
  4. tensor([[ 3.7417, 7.4833, 11.2250, 14.9666]])

怎么来的?

inputs1:(p = 2,dim = 1)每行每一列数据进行2范数运算

5.4772=12+22+32+422

10.9545=22+42+62+822

15.4317=32+62+92+1222

inputs2:(p = 2,dim = 0)每列每一行数据进行2范数运算

3.7417=12+22+322

7.4833=22+42+622

11.2250=32+62+922

14.9666=42+82+1222


关注keepdim = False这个参数

  1. inputs3 = inputs.norm(p=2, dim=1, keepdim=False)
  2. print(inputs3)

inputs3

tensor([ 5.4772, 10.9545, 16.4317])

输出inputs1inputs3shape

  1. print(inputs1.shape)
  2. print(inputs3.shape)
  1. torch.Size([3, 1])
  2. torch.Size([3])

可以看到inputs3少了一维,其实就是dim=1(求范数)那一维(列)少了,因为从4列变成1列,就是3行中求每一行的2范数,就剩1列了,不保持这一维不会对数据产生影响

或者也可以这么理解,就是数据每个数据有没有用[]扩起来。

keepdim = True[]扩起来;

keepdim = False不用[]括起来~;


不写keepdim,则默认不保留dim的那个维度

  1. inputs4 = torch.norm(inputs, p=2, dim=1)
  2. print(inputs4)
tensor([ 5.4772, 10.9545, 16.4317])

不写dim,则计算Tensor中所有元素的2范数

  1. inputs5 = torch.norm(inputs, p=2)
  2. print(inputs5)
tensor(20.4939)

等价于这句话

  1. inputs6 = inputs.pow(2).sum().sqrt()
  2. print(inputs6)
tensor(20.4939)

20.4939=12+22+32+42+22+42+62+82+32+62+92+1222


总之,norm操作后dim这一维变为1或者消失


三、整体代码

  1. """
  2. @author:nickhuang1996
  3. """
  4. import torch
  5. rectangle_height = 3
  6. rectangle_width = 4
  7. inputs = torch.randn(rectangle_height, rectangle_width)
  8. for i in range(rectangle_height):
  9. for j in range(rectangle_width):
  10. inputs[i][j] = (i + 1) * (j + 1)
  11. print(inputs)
  12. inputs1 = torch.norm(inputs, p=2, dim=1, keepdim=True)
  13. print(inputs1)
  14. inputs2 = torch.norm(inputs, p=2, dim=0, keepdim=True)
  15. print(inputs2)
  16. inputs3 = inputs.norm(p=2, dim=1, keepdim=False)
  17. print(inputs3)
  18. print(inputs1.shape)
  19. print(inputs3.shape)
  20. inputs4 = torch.norm(inputs, p=2, dim=1)
  21. print(inputs4)
  22. inputs5 = torch.norm(inputs, p=2)
  23. print(inputs5)
  24. inputs6 = inputs.pow(2).sum().sqrt()
  25. print(inputs6)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/659873
推荐阅读
相关标签
  

闽ICP备14008679号