There are plenty of articles online about multiplication in torch, but they are rather scattered, so I have put together my own summary here.
The point of this article is not how torch implements these operations, nor the source code or the documentation; it only covers which method to call in which situation. Only the methods I have actually used are included; new ones will be added as I encounter them.
All computations in this article use the following two matrices as examples:
$$a = \begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix}, \quad b = \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix}$$
Let's first create these two matrices in torch:
import torch

a = torch.tensor([[1, 1], [2, 2]])
# tensor([[1, 1],
#         [2, 2]])
b = torch.tensor([[1, 2], [1, 2]])
# tensor([[1, 2],
#         [1, 2]])
In terms of dimensions, matrix multiplication is $[x \times n] \cdot [n \times y] = [x \times y]$. How the product is actually computed can be found in any linear algebra textbook or course, so I won't repeat it here. Let's first compute $a \times b$ by hand:
$$\begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix} \cdot \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 2 & 4 \\ 4 & 8 \end{bmatrix}$$
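To make the $[x \times n] \cdot [n \times y] = [x \times y]$ rule concrete, here is a minimal sketch, reusing the a and b defined above, that reproduces this result with explicit loops (the helper name manual_mm is just for illustration):

def manual_mm(a, b):
    # naive triple loop: out[i, j] = sum over k of a[i, k] * b[k, j]
    x, n = a.shape
    n2, y = b.shape
    assert n == n2, "inner dimensions must match"
    out = torch.zeros(x, y, dtype=a.dtype)
    for i in range(x):
        for j in range(y):
            for k in range(n):
                out[i, j] += a[i, k] * b[k, j]
    return out

manual_mm(a, b)
# tensor([[2, 4],
#         [4, 8]])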
Note: PyTorch implements matrix multiplication with the following methods (dot, mm, bmm, matmul), each covered below.
Multiplying two vectors gives their dot product, which we can compute with torch's dot. Take the first row of a, $[1, 1]$, and the second column of b, $[2, 2]$, as an example; by hand the result is 4. With torch:
c = torch.tensor([1, 1])
d = torch.tensor([2, 2])
torch.dot(c, d)
# tensor(4)
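Equivalently, the dot product is just the sum of the element-wise products; a quick check with the same two vectors:

(c * d).sum()
# tensor(4)
torch.dot(c, d) == (c * d).sum()
# tensor(True)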
Note: dot only works on 1-D vectors; inputs with more than one dimension raise an error. Taking torch.dot(a, b) as an example, the error is:
RuntimeError: 1D tensors expected, but got 2D and 2D tensors
Matrix multiplication is done in torch with mm:
torch.mm(a, b)
# tensor([[2, 4],
# [4, 8]])
This matches the result we computed by hand.
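mm also handles non-square shapes, following the $[x \times n] \cdot [n \times y] = [x \times y]$ rule; a small sketch with arbitrary shapes:

m = torch.ones(2, 3, dtype=torch.long)   # [2 x 3]
n = torch.ones(3, 4, dtype=torch.long)   # [3 x 4]
torch.mm(m, n)                           # result is [2 x 4]; every entry is 1*1 + 1*1 + 1*1 = 3
# tensor([[3, 3, 3, 3],
#         [3, 3, 3, 3]])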
Note: mm only works on matrices; if an input is not 2-dimensional, it raises an error:
RuntimeError: self must be a matrix
Tensor multiplication in torch comes in two forms, bmm and matmul, which differ as follows:
The b in bmm stands for batch, i.e. matrix multiplication with a batch dimension. The data must be 3-dimensional, with the first dimension as the batch dimension. Simply put, each item in the batch takes part in one matrix multiplication; in simple pseudocode:
for i in range(batch):
    out[i] = a[i] @ b[i]
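As a sanity check, here is a minimal sketch (the batch size of 3 and the shapes are arbitrary) that compares bmm against a plain loop of mm calls:

x = torch.randn(3, 2, 4)   # a batch of 3 matrices, each 2 x 4
y = torch.randn(3, 4, 5)   # a batch of 3 matrices, each 4 x 5
out = torch.bmm(x, y)      # shape: (3, 2, 5)
# the same thing, one mm per batch item
loop = torch.stack([torch.mm(x[i], y[i]) for i in range(3)])
torch.allclose(out, loop)
# True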
Let's assume a batch size of 1, add a batch dimension to both a and b, and compute with bmm:
a = a.unsqueeze(0)  # shape: (1, 2, 2)
b = b.unsqueeze(0)  # shape: (1, 2, 2)
torch.bmm(a, b)
# tensor([[[2, 4],
#          [4, 8]]])
Note: bmm only works on 3-dimensional tensors; if the data is not 3-dimensional, it raises an error:
# with 2-dimensional input
RuntimeError: Expected 3-dimensional tensor, but got 2-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
# with 4-dimensional input
RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
matmul is arguably the most versatile multiplication in torch; it is best explained alongside the torch documentation:
If both tensors are 1-dimensional, the dot product (scalar) is returned.
If both arguments are 2-dimensional, the matrix-matrix product is returned.
If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.
If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiply and removed after. The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). For example, if input is a $(j \times 1 \times n \times n)$ tensor and other is a $(k \times n \times n)$ tensor, out will be a $(j \times k \times n \times n)$ tensor.
Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs are broadcastable, and not the matrix dimensions. For example, if input is a $(j \times 1 \times n \times m)$ tensor and other is a $(k \times m \times p)$ tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the matrix dimensions) are different. out will be a $(j \times k \times n \times p)$ tensor.
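The broadcasting rule above is easy to verify; a small sketch using the documentation's shape example with arbitrary sizes j=2, k=3, n=4, m=5, p=6:

x = torch.randn(2, 1, 4, 5)   # (j, 1, n, m)
y = torch.randn(3, 5, 6)      # (k, m, p)
torch.matmul(x, y).shape      # batch dims broadcast, matrix dims multiply
# torch.Size([2, 3, 4, 6])    # (j, k, n, p)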
Overall, matmul still performs matrix multiplication; it just fills in the dimension information automatically, and it always operates on the last two dimensions. Let's run it in code:
# vectors
# here a = [1, 1], b = [2, 2]
# shape: (), yes, the result is a scalar, so it has no dimensions
torch.matmul(a, b)
# tensor(4)

# matrices
# shape: (2, 2)
torch.matmul(a, b)
# tensor([[2, 4],
#         [4, 8]])

# 3-D tensors
# shape: (1, 2, 2)
torch.matmul(a, b)
# tensor([[[2, 4],
#          [4, 8]]])

# 4-D tensors
# shape: (1, 1, 2, 2)
torch.matmul(a, b)
# tensor([[[[2, 4],
#           [4, 8]]]])

# 5-D tensors
# shape: (1, 1, 1, 2, 2)
torch.matmul(a, b)
# tensor([[[[[2, 4],
#            [4, 8]]]]])
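Note that the @ operator is shorthand that dispatches to matmul, so the calls above can also be written as a @ b. A quick check with the original 2-D matrices:

a = torch.tensor([[1, 1], [2, 2]])
b = torch.tensor([[1, 2], [1, 2]])
torch.equal(a @ b, torch.matmul(a, b))
# True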
The element-wise product multiplies the element at row $i$, column $j$ of one matrix directly by the element at the same position of the other. Taking a and b as an example, by hand we get (here $\otimes$ denotes element-wise multiplication):
$$\begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix} \otimes \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 1 & 2 \\ 2 & 4 \end{bmatrix}$$
In torch, element-wise multiplication can be done directly with *:
# 1-D
# here a = [1, 1], b = [2, 2]
# shape: (2)
a * b
# tensor([2, 2])

# 2-D
# shape: (2, 2)
a * b
# tensor([[1, 2],
#         [2, 4]])

# 3-D
# shape: (1, 2, 2)
a * b
# tensor([[[1, 2],
#          [2, 4]]])
Of course, torch also provides a function for element-wise multiplication: mul:
# 1-D
# here a = [1, 1], b = [2, 2]
# shape: (2)
torch.mul(a, b)
# tensor([2, 2])

# 2-D
# shape: (2, 2)
torch.mul(a, b)
# tensor([[1, 2],
#         [2, 4]])

# 3-D
# shape: (1, 2, 2)
torch.mul(a, b)
# tensor([[[1, 2],
#          [2, 4]]])
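As expected, * and mul give identical results, and both broadcast; a short sketch with the original 2-D matrices (the extra row vector here is just for illustration):

a = torch.tensor([[1, 1], [2, 2]])
b = torch.tensor([[1, 2], [1, 2]])
torch.equal(a * b, torch.mul(a, b))
# True
# element-wise multiplication also broadcasts, e.g. a (2, 2) matrix times a (2,) vector
a * torch.tensor([10, 100])
# tensor([[ 10, 100],
#         [ 20, 200]])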