赞
踩
reshape和transpose都是非常高效的算子,究其原因,是因为二者均没有在内存中重新排列数据,只是对数据的shape或strides等信息进行了改变。下面分别简介。
为了更好地理解reshape和transpose算子,需要对ndarray的shape, base, strides
三个属性有所了解,其中shape很容易理解,就不多说了,下面简单介绍一下base和strides。
base参考:https://numpy.org/doc/stable/reference/generated/numpy.ndarray.base.html
如果一个ndarray是通过其他ndarray经过某种操作创建出来的,那么其base就会指向最初的源头。
比如下面例子中,a
是b, c
的源头,所以b.base 和 c.base
都等于a
,而a
本身没有base,所以是None。
import numpy as np a = np.array([0, 1, 2, 3, 4, 5]) print(a) # ==> [0 1 2 3 4 5] print(a.base) # ==> None b = a.reshape([2, 3]) print(b) # ==> # [[0 1 2] # [3 4 5]] print(b.base) # ==> [0 1 2 3 4 5] c = b.transpose([1, 0]) print(c) # ==> # [[0 3] # [1 4] # [2 5]] print(c.base) # ==> [0 1 2 3 4 5]
strides参考:https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
ndarray的每一个维度(axis)都有一个strides,表示从数组在某个维度进行遍历的内存偏移量
。
比如在下面的例子中,数组a
三个维度的strides分别是(48, 16, 4)
,意思是:
a[0, 0, 0]
到a[0, 0, 1] = 1
的内存偏移量是4字节,1个int型数字是4字节a[0, 0, 0]
到a[0, 1, 0] = 4
的内存偏移量是16字节,因为需要偏移4个int型数字a[0, 0, 0]
到a[1, 0, 0] = 12
的内存偏移量是48字节,因为需要偏移12个int型数字import numpy as np
a = np.arange(24).reshape([2, 3, 4])
print(a.strides) # ==> (48, 16, 4)
print(a) # ==>
# [[[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]]
#
# [[12 13 14 15]
# [16 17 18 19]
# [20 21 22 23]]]
reshape仅仅只是改变了数组的shape属性,比如把shape从 ( 4 , ) (4,) (4,)改成 ( 2 , 2 ) (2,2) (2,2)。通过下面的测试代码,可以明白reshape的下列性质:
b
的第一个值,发现所有相关的变量的第一个值都发生了变化,所以就可以知道,经reshape后,变量用于保存数据的那块内存没有被碰过。import numpy as np a = np.arange(4) # a = torch.arange(4) print(a) # ==> [0 1 2 3] print(a.shape) # ==> (4,) b = a.reshape([2, 2]) # b = a.reshape([2, 2]) print(b) # ==> [[0 1], [2 3]] print(b.shape) # ==> (2, 2) c = b.reshape([-1]) # c = torch.reshape(b, [-1]) print(c) # ==> [0 1 2 3] b[0, 0] = 100 print(a) # ==> [100 1 2 3] print(b) # ==> [[100 1], [2 3]] print(c) # ==> [100 1 2 3]
transpose改变了数组的维度(axis)排列顺序。比如对于二维数组,如果我们把两个维度的顺序互换,那就是我们很熟悉的矩阵转置。而transpose可以在更多维度的情况下生效。transpose的入参是输出数组的维度排列顺序,序号从0开始计数。
下面例子中我们改变了transpose后的b
的第一个元素的值,发现a
也随之改变,说明transpose也没有去碰数组的内存。那么问题来了,既然数组没有在内存中重新排列,那么打印顺序是受什么影响而发生了改变呢?是strides。
import numpy as np a = np.arange(24).reshape([2, 3, 4]) print(a.base) # ==> # [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23] print(a.shape) # ==> (2, 3, 4) print(a.strides) # ==> (48, 16, 4) print(a) # ==> # [[[ 0 1 2 3] # [ 4 5 6 7] # [ 8 9 10 11]] # # [[12 13 14 15] # [16 17 18 19] # [20 21 22 23]]] b = a.transpose([1, 2, 0]) print(b.shape) # ==>(3, 4, 2) print(b.strides) # ==> (16, 4, 48) print(b) # ==> # [[[ 0 12] # [ 1 13] # [ 2 14] # [ 3 15]] # # [[ 4 16] # [ 5 17] # [ 6 18] # [ 7 19]] # # [[ 8 20] # [ 9 21] # [10 22] # [11 23]]] b[0, 0, 0] = 100 print(a) # ==> # [[[100 1 2 3] # [ 4 5 6 7] # ...]]] print(b) # ==> # [[[100 12] # [ 1 13] # ...]]]
下面图示一下strides的含义。
首先明确一个很重要的概念,strides都是相对于base数组而言进行遍历的,所以无论是a
还是b
,遍历时需要参考的源头都是a.base / b.base
,也就是最上面的一维数组。
数组a
的strides情况我们前面已经讲过了,接下来主要看看b
。
b[0, 0, 1]
是b[0, 0, 0]
在b.base
中偏移48字节后的数字,也就是12
。b[0, 1, 0]
是b[0, 0, 0]
在b.base
中偏移4字节后的数字,也就是1
。b[1, 0, 0]
是b[0, 0, 0]
在b.base
中偏移16字节后的数字,也就是4
。所以ranspose操作只是改变了strides的顺序,没有重新排列内存中的数据。
前面我们在解释reshape和transpose的机制时,分别从ndarray的shape和strides属性进行了侧重解释。实际上reshape既改变shape也改变strides,而transpose也可能会改变shape。
但这两个算子均不会在内存中重新排列数据。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。