当前位置:   article > 正文

Torch中的矩阵相乘分类_torch矩阵乘法

torch矩阵乘法

矩阵相乘在torch中的几种情况

  • 1、矩阵逐元素(Element-wise)乘法 torch.mul(mat1, other)

mat和other可以是标量也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast的即可,即该操作是支持broadcast操作的。只要mat1与other满足broadcast条件,就可可以进行逐元素相乘 。

# 生成指定张量
c = torch.Tensor([[1, 2, 3], [4, 5 ,6]])
print(c.shape)  # 2*3
print(c)

# 生成随机张量
d = torch.randn(2,2,3) 
print(d)
print(d.shape)  # 2*2*3

mul = torch.mul(c, d) # c会自动broadcast和d进行匹配
print(mul.shape)      # 2*2*3
print(mul)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 2、 二维矩阵相乘 torch.mm(a, b)

该函数一般只能用来计算两个二维矩阵的矩阵乘法,而且不支持broadcast操作。

  • 3、三维带Batch矩阵乘法 torch.bmm()

t o r c h . b m m ( b m a t 1 , b m a t 2 ) torch.bmm(bmat1, bmat2) torch.bmm(bmat1,bmat2), 其中 b m a t 1 ( B × n × m ) bmat1(B\times n\times m) bmat1(B×n×m), b m a t 2 ( B × m × d ) bmat2(B\times m \times d) bmat2(B×m×d)输出的维度是 o u t out out的尺度是 B × n × d B\times n \times d B×n×d,该函数两个输入必须三维矩阵中的第一维要要相同,不支持broadCast操作。

  • 4、多维数据矩阵相乘torch.matmul(a, b)

矩阵乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度。这个可用范围更广

假设两个输入的维度分别是 i n p u t = ( 1000 × 500 × 99 ×   11 ) input=(1000\times 500 \times 99 \times\ 11) input=(1000×500×99× 11), o t h e r = ( 500 × 11 ×   99 ) other=(500 \times 11 \times\ 99) other=(500×11× 99),那么我们可以认为 乘法 t o r c h . m a t m u l ( i n p u t , o t h e r ) torch.matmul(input, other) torch.matmul(input,other)首先是进行后两位矩阵乘法得到 99 × 99 99\times 99 99×99 ,可以广播成为 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DHzguPmS-1642437999480)(https://www.zhihu.com/equation?tex=%281000+%5Ctimes+500%29)], 因此最终输出的维度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-U2kJGZUO-1642437999483)(https://www.zhihu.com/equation?tex=%EF%BC%881000+%5Ctimes+500+%5Ctimes+99+%5Ctimes+99%29)]。

a = torch.randn(4,5)
b = torch.randn(5, 4)
print(torch.matmul(a,b))

a = torch.randn(2,4,5)
b = torch.randn(5, 4)
print(torch.matmul(a,b))

a = torch.randn(2, 3,4,5)
b = torch.randn(5, 4)
print(torch.matmul(a,b))

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 扩展稀疏张量矩阵

    Torch支持COO(rdinate)格式的稀疏张量,可以有效地存储和处理大多数元素为零的张量。

    sparse tensor 可以表示为一对 dense tensors:一个张量的value和一个二维的张量 indices一个稀疏张量可以通过提供这两个张量, 以及稀疏张量的大小来构造(从这些张量是无法推导出来的!)假设我们要定义一个稀疏张量, 其中 3在(0,2)处, 4在(1,0)处, 5在(1,2)处, 然后我们可以这样写:

i = torch.LongTensor([[0, 1, 1],
                          [2, 0, 2]])
v = torch.FloatTensor([3, 4, 5])
torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense()

结果:
 0  0  3
 4  0  5
[torch.FloatTensor of size 2x3]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
import torch
 
if __name__ == '__main__':
    # 对应的索引坐标(0,0)是2, (1,0)为3
    indices = torch.tensor([[0,1],
                            [0,0]])
    # 也可以写成数组,但是要转置
    ind = torch.tensor([[0,0],[1,0]])
    #稀疏矩阵对应的值
    values = torch.tensor([2,3])
    # 稀疏矩阵的大小
    shape = torch.Size((2,2))

    # 创建稀疏矩阵,传入三个参数
    s = torch.sparse.FloatTensor(indices,values,shape)
    s2 = torch.sparse.FloatTensor(ind.t(), values, shape)
 
    # 显示对应的稀疏矩阵
    print(s.to_dense())
    print(s2.to_dense())
    print(s)
    d = torch.tensor([[1,2],
                      [3,4]])
    e = torch.tensor([[2,3],[2,3]])
    # 矩阵相乘的时候只能稀疏矩阵在前
    print(torch.spmm(s,d
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

参考链接

torch离散矩阵

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/678970
推荐阅读
相关标签
  

闽ICP备14008679号