当前位置:   article > 正文

[Pytorch]:PyTorch中张量乘法大全

[Pytorch]:PyTorch中张量乘法大全

PyTorch 中,有多种方法可以执行张量之间的乘法。这里列出了一些常见的乘法操作:

总结:

  • 逐元素乘法:*ortorch.mul()
  • 矩阵乘法@ortorch.mm()ortorch.matmul()
  • 点积torch.Tensor.dot()
  • 批量矩阵乘法torch.bmm()torch.matmul()
  • 矩阵与向量相乘torch.mv(X, w0)
  1. 逐元素乘法(Element-wise multiplication):*ortorch.mul()()`对应位置的元素相乘,输入张量形状必须相同或可广播

    import torch
    
    A = torch.tensor([[1, 2], [3, 4]])
    B = torch.tensor([[2, 3], [4, 5]])
    
    result = A * B
    print(result)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出:

    tensor([[ 2,  6],
            [12, 20]])
    
    • 1
    • 2
  2. 矩阵乘法@ortorch.mm()ortorch.matmul()两个矩阵相乘,第一个矩阵的列数必须等于第二个矩阵的行数。

    import torch
    
    A = torch.tensor([[1, 2], [3, 4]])
    B = torch.tensor([[2, 3], [4, 5]])
    
    result = torch.matmul(A, B)
    print(result)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出:

    tensor([[10, 13],
            [22, 29]])
    
    • 1
    • 2

    或者使用 @ 运算符执行矩阵乘法

    result = A @ B
    print(result)
    
    • 1
    • 2
  3. 点积(Dot product):torch.Tensor.dot()两个一维张量的点积。

    import torch
    
    A = torch.tensor([1, 2, 3])
    B = torch.tensor([4, 5, 6])
    
    result = torch.dot(A, B)
    print(result)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出:

    tensor(32)
    
    • 1
  4. 批量矩阵乘法:对于具有更高维度的张量(点积),可以使用 torch.bmm()torch.matmul() 进行批量矩阵乘法。

    import torch
    
    A = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    B = torch.tensor([[[2, 3], [4, 5]], [[6, 7], [8, 9]]])
    
    result = torch.bmm(A, B)
    print(result)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出:

    tensor([[[ 10,  13],
             [ 22,  29]],
    
            [[ 76,  91],
             [112, 133]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5

​ 两个输入张量的 batch_size 必须相同。此外,第一个输入张量的 num_columns 必须与第二个输入张量的 num_rows 相同。换句话说,输入张量的形状应为 (batch_size, num_rows_A, num_columns_A)(batch_size, num_columns_A, num_columns_B)

  1. 矩阵与向量相乘torch.mv(X, w0)第一个参数是矩阵,第二个参数只能是一维向量,等价于X乘以w0的转置
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/360076
推荐阅读
相关标签
  

闽ICP备14008679号