赞
踩
最近,在博客园发现一篇利用预训练网络训练CNN分类模型的文章,在代码中发现了一个问题。
如果sample_count的数目不等于batch_size 的整数倍,那么遍历到数据末尾时,必然会导致feature_batch的长度小于batch_size,也即下面图中划线的部分中等号前后的len不一样。
所以,我在这里滋生了一个疑惑,当python的内置数据类型list,或者numpy中的ndarray类型数据,在赋值时也出现这种情况,会怎么输出?
- list1 = [1,2,3,4,5]
- list2 = [6,7]
-
- list1[1:5]=list2
-
- print(list1)
- # [1, 6, 7] 正常输出;规律是从最小索引开始赋值,能赋多少赋多少,后面没办法赋值的元素自动去除
- import numpy as np
-
- arr1 = np.array([[1,2,3,4,5]])
-
- arr2 = np.array([6,7])
-
- arr1[1:5]=arr2
-
- print(arr1) # ValueError: could not broadcast input array from shape (2,) into shape (4,)
- import numpy as np
-
- arr1 = np.array([[1,2,3,4,5],[6,7,8,9,10]])
-
- arr2 = np.array([[1,2,3,6,7]])
-
- arr1[0:2]=arr2
-
- print(arr1)
-
- # [[1 2 3 6 7]
- # [1 2 3 6 7]]
- # 正常输出,说明触发broadcast机制(本文的第三部分有讲到)
而明显,本文的第一部分中引用的代码,明显属于ndarray数组类型的数据。
能不能实现广播机制呢?其实是不能的,因为从后向前数,第0维上的shape不相等,并且features_batch的shape一旦不是(1,4,4,512),就无法触发广播机制,所以我觉得作者的代码存在问题。
(sorry,对这个作者说句抱歉!大概是因为generator这个函数只能输出batch_size形状的数据,所以“大概”是不存在这个问题的,但是问题在于文件夹内剩余的没有batch_size的数据呢?完了又有新的疑问......可能会遗弃剩余的吧,不然会出错呢!)
- list1 = [[1,2,3,4,5],[6,7,8,9,10],[11,12,22,323,121]]
- list2 = [[[1,2,3]]]
-
- list1[0:2]=list2
-
- print(list1) # [[[1, 2, 3]], [11, 12, 22, 323, 121]]
可以看出规律:list1[0:2]切片相当于取出了[1,2,3,4,5],[6,7,8,9,10]这个部分, 而list2在赋值表达式中相当于[[1,2,3]],所以把后者替换到前者,就可以得到结果[[[1, 2, 3]], [11, 12, 22, 323, 121]]了。
很明显,list不存在广播机制,不然这种赋值操作是不允许存在的。
因为广播机制的前提是输入输出的容器中各个维度上的shape一定要相等。
但是上面的list1被赋值之后,第1个维度下的两个列表shape分别是1*3和5;说明list1不存在这个机制。
也正是因为numpy.ndarray的广播机制存在,所以在将list转成array的时候,一般是不建议将各维度下不等shape的list转成数组的。
如下所示:
- np.array([[[1, 2, 3]], [11, 12, 22, 323, 121]])
-
- # 警告:C:\Users\Administrator\AppData\Local\Temp\ipykernel_10728\1244013573.py:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
- # 输出:np.array([[[1, 2, 3]], [11, 12, 22, 323, 121]])
因为等号左右两边shape相同,所以没必要广播。
- import numpy as np
-
- arr1 = np.array([[1,2,3,4,5],[6,7,8,9,10],[213,214,324,235,44]])
-
- arr2 = np.array([[1,2,3,6,7],[1231,2142,2412,221,214]])
-
- arr1[0:2]=arr2
-
- print(arr1[0:2].shape)
- print(arr2.shape)
- print(arr1)
-
- # (2, 5)
- # (2, 5)
- # [[ 1 2 3 6 7]
- # [1231 2142 2412 221 214]
- # [ 213 214 324 235 44]]
- A = np.zeros((2,5,3,4))
- B = np.ones((1))
-
- print((A+B).shape) # (2, 5, 3, 4)
- A = np.zeros((2,5,3,4))
- B = np.ones((1,5,1,4))
-
- print((A+B).shape) # (2, 5, 3, 4)
- A = np.zeros((2,5,3,4))
- B = np.ones((8,1,5,1,4))
-
- print((A+B).shape) # (8, 2, 5, 3, 4)
- A = np.zeros((2,5,3,4))
- B = np.ones((2,5,3,2))
-
- print((A+B).shape) # ValueError: operands could not be broadcast together with shapes (2,5,3,4) (2,5,3,2)
原因在于,从后向前数,前者的第一个数是4,后者是2,两者不同,并且其中没有一个是1,所以无法广播,因此出错。
- A = np.zeros((2,5,3,4))
- B = np.ones((2,5,3,1))
-
- print((A+B).shape) # (2, 5, 3, 4)
从后向前数,前者的第一个数是4,后者是1,两者不同,但是后者是1,所以可广播,因此没出事。
从3.3这个实验结果,最容易得到一个广泛适用的结论。
- 从后向前数,逐一对比shape上的各个数值
- 如果数值相等,或者即使不相等但是有一方为1,就可以继续向前对比
- 一直对比到有一方的数值被用完,才结束
- 如果各个数值的对比结果符合条件,就符合“广播兼容性”,可以广播!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。