当前位置:   article > 正文

一文整理5个Pytorch张量乘法函数_pytorch 张量乘法

pytorch 张量乘法

~~欢迎关注#公众号:AI算法小喵,会有更多不错的文章分享~~

本文首发于:一文整理5个Pytorch张量乘法函数

 

最近整理了Pytorch中5个常用的张量乘法函数和用法,建议收藏学习。

1. 张量的维度

在开始今天的学习之前,我们需要先学习一个知识点,即张量的维度

张量的维度包括两方面内容,其一是维度个数,其二是维度大小。维度个数可以通过张量.ndim属性查看,维度大小可以通过.shape.size()查看。

  1. >>> a=torch.arange(6).reshape(2,3)
  2. >>> a
  3. tensor([[012],
  4.         [345]])
  5. >>> a.ndim
  6. 2
  7. >>> a.shape
  8. torch.Size([23])
  9. >>> a.size()
  10. torch.Size([23])

比如上面的张量a:维度个数为2,代表a是一个二维张量;维度大小为[2,3],代表第0维的维度大小为2,第1维为3。

2. torch.matmul

我们先学习最复杂也最灵活的torch.matmul函数[1]。

2.1 概览

功能:matmul函数实现的是矩阵乘法,更确切地说,是“混合”矩阵乘法

参数

  • input(张量):第一个张量。

  • other(张量):第二个张量。

  • out(张量):结果张量,等同于torch.matmul函数的返回值。

返回值:张量。

2.2 示例代码

matmul函数的行为根据输入张量的不同大体可以分为5种情形(case),所以这里我们也通过5个case下的示例代码来学习这个函数。

(1) case1

若两个张量均为一维张量,则执行向量点积操作,等价于调用torch.dot函数。

比如,下面我们创建了两个一维张量a、b,维度大小均为2。

  1. >>> a=torch.randn(1)
  2. >>> b=torch.randn(1)
  3. >>> a.ndim
  4. 1
  5. >>> b.ndim
  6. 1
  7. >>> a.size()
  8. torch.Size([2])
  9. >>> b.size()
  10. torch.Size([2])
  11. >>> a
  12. tensor([0.8411])
  13. >>> b
  14. tensor([-1.1787])

然后分别对他们进行matmul操作,和dot操作。从结果比对来看,两个操作是等价的,最终生成的都是scalar标量

  1. >>> c1=torch.matmul(a,b)
  2. >>> c2=torch.dot(a,b)
  3. >>> c1.equal(c2)
  4. True
  5. >>> c1
  6. tensor(-0.9914)
  7. >>> c1.ndim
  8. 0
  9. >>> c1.size()
  10. torch.Size([])

(2) case2

若两个张量均为二维张量,则执行矩阵乘法,等价于调用torch.mm函数。

比如下面的例子,a、b均为2维张量,维度大小分别为[2,2]、[2,3]。即a.size()[1]=b.size()[0]满足矩阵乘法约束,通过matmul函数或mm函数,我们将获得2维张量,维度大小为[2,3]。

  1. >>> a=torch.randn(2,2)
  2. >>> b=torch.randn(2,3)
  3. >>> a.ndim
  4. 2
  5. >>> b.ndim
  6. 2
  7. >>> a.size()
  8. torch.Size([22])
  9. >>> b.size()
  10. torch.Size([23])
  11. >>> c1=torch.matmul(a,b)
  12. >>> c2=torch.mm(a,b)
  13. >>> c1.equal(c2)
  14. True
  15. >>> c1.size()
  16. torch.Size([23])
  17. >>> c1.ndim
  18. 2

(3) case3

若第一个张量为一维张量,假设维度为[k],第二个张量为二维张量,假设维度为[k,p]。第一个张量会在左边进行维度扩展,维度变为[1,k],然后再进行矩阵乘法,获得维度为[1,p]的张量,然后再去掉扩展的维度,最后结果张量维度为[p]。

比如,a是维度大小为[3]的一维矩阵,b是维度大小为[3,4]的二维矩阵,结果张量c1是一维张量,维度大小为[4]。

  1. >>> a=torch.arange(1,4)
  2. >>> b=torch.arange(2,14).reshape((3,4))
  3. >>> a.ndim
  4. 1
  5. >>> a.size()
  6. torch.Size([3])
  7. >>> b.ndim
  8. 2
  9. >>> b.size()
  10. torch.Size([34])
  11. >>>
  12. >>> c1=torch.matmul(a,b)
  13. >>> c1.ndim
  14. 1
  15. >>> c1.size()
  16. torch.Size([4])
  17. >>>
  18. >>> a
  19. tensor([123])
  20. >>> b
  21. tensor([[ 2,  3,  4,  5],
  22.         [ 6,  7,  8,  9],
  23.         [10111213]])
  24. >>> c1
  25. tensor([44505662])

更简单地记法,可以视为线性代数中的行向量乘矩阵,结果为第二个张量矩阵的行向量的线性组合,组合系数为第一个张量中相应的值。

  1. >>> c2=1*b[0]+2*b[1]+3*b[2]
  2. >>> c1.equal(c2)
  3. True
  4. >>> c2
  5. tensor([44505662])

还有一点需要注意,虽然matmul在进行维度扩展后,执行的是矩阵乘法,但在这种情形下,它与mm是不等价的,mm函数要求输入均为二维张量。

  1. >>> torch.mm(a,b)
  2. Traceback (most recent call last):
  3.   File "<stdin>"line 1in <module>
  4. RuntimeError: matrices expected, got 1D, 2D tensors at ../aten/src/TH/generic/THTensorMath.cpp:131
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/176509
推荐阅读
相关标签
  

闽ICP备14008679号