
PyTorch Study Notes (9): Making Sense of Multiplication in torch


There are plenty of articles online about multiplication in torch, but they are rather scattered, so I have put together my own summary here.
The point of this post is not to work out how torch implements these operations, nor to walk through the source code or documentation; it only covers which method to call in which situation. I have only included methods I have actually used; if I come across new ones later, I will add them.

All of the computations in this post use the following two matrices as examples:

$$a = \begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix}, \quad b = \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix}$$
Now let's create these two matrices in torch:

import torch

a = torch.tensor([[1, 1], [2, 2]])
# tensor([[1, 1],
#         [2, 2]])
b = torch.tensor([[1, 2], [1, 2]])
# tensor([[1, 2],
#         [1, 2]])

1 Matrix multiplication

In terms of dimensions, matrix multiplication is $[x \times n] \cdot [n \times y] = [x \times y]$. The actual computation rule can be found in any linear algebra textbook or course, so I will not repeat it here. Let's first work out $a \times b$ by hand:

$$\begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix} \cdot \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 2 & 4 \\ 4 & 8 \end{bmatrix}$$

Notes:

  1. Matrix multiplication can be seen as a collection of vector dot products.
  2. The dot product of two vectors can also be written in terms of their norms and the angle between them, $a \cdot b = |a||b|\cos(\theta)$, which gives the cosine similarity formula $\cos(\theta) = (a \cdot b) / (|a||b|)$. So the dot product reflects, to some degree, how similar two vectors are. This shows up all the time in attention mechanisms, e.g. the self-attention score $\alpha = \boldsymbol{q} \cdot \boldsymbol{k}^{\rm T}$; see the small sketch below.
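As a quick illustration of note 2, here is a minimal sketch computing the dot product and the cosine similarity of two made-up vectors q and k (the values are arbitrary, chosen only for this example):

import torch

# two hypothetical 1-D vectors standing in for a query and a key
q = torch.tensor([1.0, 2.0, 3.0])
k = torch.tensor([2.0, 0.0, 1.0])

dot = torch.dot(q, k)               # |q||k|cos(theta)
cos = dot / (q.norm() * k.norm())   # cosine similarity from the formula above
# the built-in helper works on batched inputs, so add a batch dimension
cos_builtin = torch.cosine_similarity(q.unsqueeze(0), k.unsqueeze(0))

# dot          -> tensor(5.)
# cos          -> tensor(0.5976)
# cos_builtin  -> tensor([0.5976])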

PyTorch offers the following ways to do matrix multiplication:

1.1 Vector multiplication

Vector multiplication is just the dot product, which we can compute with torch's `dot`. Take the first row of `a`, $[1, 1]$, and the second column of `b`, $[2, 2]$, as an example; by hand the result is 4. In torch:

c = torch.tensor([1, 1])
d = torch.tensor([2, 2])
torch.dot(c, d)
# tensor(4)
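The same row and column can also be pulled directly out of `a` and `b` with ordinary indexing; this is a small sketch reproducing the value above (the slicing itself is standard tensor indexing, nothing specific to `dot`):

row = a[0]      # first row of a: tensor([1, 1])
col = b[:, 1]   # second column of b: tensor([2, 2])
torch.dot(row, col)
# tensor(4)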

Note: `dot` only works on 1-D vectors; if an input has more than one dimension it raises an error. For example, calling `torch.dot(a, b)` gives:

RuntimeError: 1D tensors expected, but got 2D and 2D tensors

1.2 Matrix multiplication

Matrix multiplication is done in torch with `mm`:

torch.mm(a, b)
# tensor([[2, 4],
#         [4, 8]])

This matches the result we computed by hand.
Note: `mm` only works on matrices; if an input is not 2-dimensional it raises an error:

RuntimeError: self must be a matrix

1.3 Tensor multiplication

torch has two kinds of tensor multiplication, `bmm` and `matmul`, which differ as follows:

1.3.1 Batched matrix multiplication

The `b` in `bmm` stands for batch, so this is batched matrix multiplication. The inputs must be 3-dimensional, with the first dimension being the batch dimension; in other words, each item in the batch goes through one matrix multiplication. In simple pseudocode:

for i in range(batch_size):
    out[i] = torch.mm(a[i], b[i])   # one matrix product per batch item

Let's take the batch size to be 1, add a batch dimension to both `a` and `b` with `unsqueeze`, and compute with `bmm`:

# shape: (1, 2, 2)
a = a.unsqueeze(0)
b = b.unsqueeze(0)

# shape: (1, 2, 2)
torch.bmm(a, b)
# tensor([[[2, 4],
#          [4, 8]]])

Note: `bmm` only works on 3-D tensors; if an input's number of dimensions is not 3, it raises an error:

# with 2-D inputs
RuntimeError: Expected 3-dimensional tensor, but got 2-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
# with 4-D inputs
RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
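To see `bmm` with an actual batch, here is a small sketch (the shapes below are made up) that checks the result against the loop written above:

# a batch of 3 matrix pairs: (3, 2, 4) x (3, 4, 5) -> (3, 2, 5)
x = torch.randn(3, 2, 4)
y = torch.randn(3, 4, 5)

out = torch.bmm(x, y)

# the same thing with an explicit loop over the batch dimension
expected = torch.stack([torch.mm(x[i], y[i]) for i in range(3)])

out.shape                      # torch.Size([3, 2, 5])
torch.allclose(out, expected)  # True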

1.3.2 All-purpose multiplication

`matmul` is arguably the most versatile multiplication in torch; it is best explained alongside the torch documentation:

  • If both tensors are 1-dimensional, the dot product (scalar) is returned.

  • If both arguments are 2-dimensional, the matrix-matrix product is returned.

  • If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.

  • If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.

  • If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiply and removed after. The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). For example, if input is a $(j \times 1 \times n \times n)$ tensor and other is a $(k \times n \times n)$ tensor, out will be a $(j \times k \times n \times n)$ tensor.

  • Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs are broadcastable, and not the matrix dimensions. For example, if input is a $(j \times 1 \times n \times m)$ tensor and other is a $(k \times m \times p)$ tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the matrix dimensions) are different. out will be a $(j \times k \times n \times p)$ tensor.

Overall, `matmul` still performs matrix multiplication; it simply fills in the missing dimension information automatically, and the actual product is computed over the last two dimensions. Let's run it for each case:

# vectors
# here a = [1, 1], b = [2, 2]
# shape: (), yes, this is a scalar, so it has no dimensions
torch.matmul(a, b)
# tensor(4)

# matrices
# shape: (2, 2)
torch.matmul(a, b)
# tensor([[2, 4],
#         [4, 8]])

# 3-D tensors
# shape: (1, 2, 2)
torch.matmul(a, b)
# tensor([[[2, 4],
#          [4, 8]]])

# 4-D tensors
# shape: (1, 1, 2, 2)
torch.matmul(a, b)
# tensor([[[[2, 4],
#           [4, 8]]]])

# 5-D tensors
# shape: (1, 1, 1, 2, 2)
torch.matmul(a, b)
# tensor([[[[[2, 4],
#            [4, 8]]]]])
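To see the broadcasting behavior described in the last two bullet points of the documentation, here is a small sketch with made-up shapes (j=2, k=3, n=4, m=5, p=6):

# batch dimensions broadcast: (2, 1, 4, 5) x (3, 5, 6) -> (2, 3, 4, 6)
x = torch.randn(2, 1, 4, 5)
y = torch.randn(3, 5, 6)

out = torch.matmul(x, y)
out.shape   # torch.Size([2, 3, 4, 6])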

2 Element-wise multiplication

The element-wise product multiplies the element in row $i$, column $j$ of one matrix directly by the element in the same position of the other. Using `a` and `b` as an example, by hand we get (here $\otimes$ denotes the element-wise product):

$$\begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix} \otimes \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 1 & 2 \\ 2 & 4 \end{bmatrix}$$

2.1 Direct multiplication

In torch, the element-wise product can be computed directly with `*`:

# 1-D
# here a = [1, 1], b = [2, 2]
# shape: (2)
a * b
# tensor([2, 2])

# 2-D
# shape: (2, 2)
a * b
# tensor([[1, 2],
#         [2, 4]])

# 3-D
# shape: (1, 2, 2)
a * b
# tensor([[[1, 2],
#          [2, 4]]])
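Like `matmul`, `*` also follows the usual broadcasting rules when the two shapes differ. Here is a small sketch (the row vector below is made up and not part of the running example):

# broadcasting: a (2, 2) matrix times a (2,) vector
a = torch.tensor([[1, 1], [2, 2]])
row = torch.tensor([10, 20])
a * row
# tensor([[10, 20],
#         [20, 40]])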

2.2 Using the library function

Of course, torch also provides a library function for the element-wise product, namely `mul`:

# 1-D
# here a = [1, 1], b = [2, 2]
# shape: (2)
torch.mul(a, b)
# tensor([2, 2])

# 2-D
# shape: (2, 2)
torch.mul(a, b)
# tensor([[1, 2],
#         [2, 4]])

# 3-D
# shape: (1, 2, 2)
torch.mul(a, b)
# tensor([[[1, 2],
#          [2, 4]]])
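As a quick sanity check on the 2-D matrices (a small sketch; `Tensor.mul` is the method form of the same function), `a * b`, `torch.mul(a, b)`, and `a.mul(b)` all give the same result:

a = torch.tensor([[1, 1], [2, 2]])
b = torch.tensor([[1, 2], [1, 2]])

torch.equal(a * b, torch.mul(a, b))   # True
torch.equal(a * b, a.mul(b))          # True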