当前位置:   article > 正文

transpose算子优化的几种常见场景

transpose算子

很难写一个kernel就能同时在transpose的所有场景都最优,归纳transpose的几种常见场景可以针对性优化。这里只列出了transpose对轴变换的几种情况,没有考虑shape大小。因此在这几种场景上还应该考虑转置的轴 shape大小针对性优化。

场景1: batch 2D,perm:021

二维,三维,或者更高维的tensor,交换最内层的两个维度。这些都可以统一为batch 2D。对于2D矩阵的转置相当于batch=1,大于三维的tensor可以把两个最内层维度以外的所有维度合并看成一个维度。

在GPU的实现中,一个重要的优化点是实现输入和输出的合并内存访问。可以每个线程读取4x4的数据块,基于寄存器进行转置,或者一个warp的线程基于shared memory,协作读取更大的一个数据块,基于shared memory和寄存器转置,这样能同时实现线程间数据读取和写回的合并内存访问。

场景2:0213

其特点是内部的两个相邻的维度进行交互,不包含最内层的一个或多个维度。跟上面一样,相邻不交换的维度可以合并看成一个整体,最外层的维度不足可以补1。

下面这个场景perm=[2, 0, 3, 1, 4],看上去同时转置了多个axes,但是由于shape元素1的特殊性,可以squeeze掉, 因此可以转换为[784, 3, 4, 12]到[3, 4, 784, 12]的transpose,可以使用0213的方法来解决。

删除transpose shape为1的算法

  1. perm = [2, 0, 3, 1, 4]
  2. in_shape = [1, 784, 1, 4, 12]
  3. rm_axes = []
  4. for idx, elem in enumerate(in_shape):
  5. if elem == 1:
  6. rm_axes.append(idx)
  7. print("rm_axes:", rm_axes)
  8. def remove_axis(in_shape, perm, rm_axis):
  9. del in_shape[rm_axis]
  10. perm_rm_idx = -1
  11. for idx, elem in enumerate(perm):
  12. if elem == rm_axis:
  13. perm_rm_idx = idx
  14. if elem > rm_axis:
  15. perm[idx] = perm[idx]-1
  16. del perm[perm_rm_idx]
  17. for rm_axis in reversed(rm_axes):
  18. remove_axis(in_shape, perm, rm_axis)
  19. print("perm:", perm)
  20. print("in_shape:", in_shape)

场景3:交换两个相邻的axes,但是其中一个axis对应的shape是1

这个场景并不需要transpose,只需要reshape即可。

使用上面的删除transpose shape为1的算法后,这种transpose的perm会变成[0,1,2,3,...] 可以非常简单的判断这个transpose实际上不需要进行任何操作,直接删除即可。

场景4:交换多个axes,但是部分perm是相邻的

这里perm=[1, 2, 0], 看上去交换了3个axes,实际上1x64这两个是一起交换的,可以合并成一个维度,这个问题就变成了上面的场景1。因此解决方案可以是合并一起变换的相邻轴,从而把问题简化。

场景5:transpose用作depth2space或space2depth

进行针对性替换往往有更好的性能。

场景6:其他

当然还有少量场景无法使用上面的方法来解决,例如这里输入shape第一个维度不是1的情况。

合并transpose相邻的perm从而降低transpose的维度

如前面所述,transpose可以转换为reshape+transpose+reshape,转换后的transpose通过合并相邻的perm从而降低transpose的维度。例如原始perm=[2, 0, 1] shape=[2,2,2]的transpose可以转换为[1,0], shape=[4, 2]的transpose,参考代码如下:

  1. def get_shape_elem(shape):
  2. elem_num = 1
  3. for elem in shape:
  4. elem_num *= elem
  5. return elem_num
  6. perm = [0, 1, 3, 2, 4, 5]
  7. in_shape = [1, 22, 12, 16, 12, 192]
  8. perm_groups = []
  9. concec_perm = []
  10. last_perm = -10
  11. for elem in perm:
  12. if elem == last_perm + 1:
  13. concec_perm.append(elem)
  14. else:
  15. if concec_perm:
  16. perm_groups.append(concec_perm)
  17. concec_perm = [elem]
  18. last_perm = elem
  19. if concec_perm:
  20. perm_groups.append(concec_perm)
  21. min_perms = [concec_perm[0] for concec_perm in perm_groups]
  22. min_perms_sorted = sorted(min_perms)
  23. new_perms_map = dict(zip(min_perms_sorted, list(range(len(min_perms_sorted)))))
  24. new_perms = [new_perms_map[axis] for axis in min_perms]
  25. min_perm_2_perm_groups=dict(zip(min_perms, perm_groups))
  26. new_shape = []
  27. for min_perm in min_perms_sorted:
  28. concec_perm = min_perm_2_perm_groups[min_perm]
  29. perm_shape = [in_shape[axis] for axis in concec_perm]
  30. new_shape.append(get_shape_elem(perm_shape))
  31. print("old perm:", perm)
  32. print("old shape:", in_shape)
  33. print("new perm:", new_perms)
  34. print("new shape:", new_shape)

更加复杂的情况,也可以考虑把一个transpose拆分为多个reshape+transpose来降低每个transpose的维度。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/154205
推荐阅读
  

闽ICP备14008679号