当前位置:   article > 正文

如何实现比PyTorch快6倍的Permute/Transpose算子?_permute torch 性能 oneflow

permute torch 性能 oneflow

2068e3e8e57dbf53656f89993523cbfb.png

撰文 | 郑泽康、柳俊丞、姚迟、郭冉

无论是在统治NLP届的Transformer,还是最近视觉领域的新秀Vision Transformer,我们都能在模型中看到Transpose/Permute算子的身影,特别是在多头注意力机制(Multi-Head Attention)中,需要该算子来改变数据维度排布。

显然,作为一个被高频使用的算子,其CUDA实现会影响到实际网络的训练速度。本文会介绍OneFlow中优化Permute Kernel的技巧,并跟PyTorch的Permute,原生的Copy操作进行实验对比。结果表明,经过深度优化后的Permute操作在OneFlow上的速度和带宽利用率远超PyTorch,带宽利用率能够接近原生Copy操作。

1

朴素的Permute实现

Permute算子的作用是变换张量数据维度的顺序,举个例子:

  1. x = flow.randn(23)
  2. y = x.permute(10)
  3. y.shape 
  4. (32)

其实现原理也可以很容易理解,即输出Tensor的第i维,对应输入Tensor的dims[i]维,上述例子中 permute 实现对应的伪代码如下:

  1. for row in x.shape[0]: 
  2.   for col in x.shape[1]: 
  3.     y[row][col] = x[col][row]

但是实际情况与上面的伪代码有出入,张量的Shape是数学上的概念,在物理设备上并不真实存在。

在OneFlow中,张量的数据都是保存在一块连续的内存中,下图分别从上层视角和底层视角描述了形状为(2, 3)的张量的存储方式:

c979957e9e2a311b8074e874ee128ed6.png

OneFlow的Permute实现原理为:

  • 通过当前输出的一维偏移量(offset)计算对应的高维索引

  • 然后根据参数dims重新排列输出索引,进而得到输入索引。

  • 将输入索引转换成输入偏移量

  • 最后进行数据移动,整个过程的示意图如下:

07c444f5995f988987949063e645417c.png

完成Permute后,输出如下图所示:

89344b6e8532068323314ae475b24098.png

整个 permute 计算过程需要经过多次一维偏移量offset和高维索引之间的转换,为了避免一次次手工计算,OneFlow提供了一个工具类NdIndexOffsetHelper来方便做上述转换。

2

NdIndexOffsetHelper

NdIndexOffsetHelper的主体方法如下:

  • NdIndexToOffset方法把高维索引转为一维偏移量

  • OffsetToNdIndex方法把一维偏移量转为高维索引

有了这么一个工具类,那我们就可以很轻松的写出一版Naive Permute Kernel了,核函数如下:

  1. template<size_t num_dims, size_t movement_size, typename IndexType>
  2. __global__ void PermuteKernel(PermuteKernelParams<num_dims, IndexType> params) {
  3.   using T = typename std::aligned_storage<movement_size, movement_size>::type;
  4.   const T* src = reinterpret_cast<const T*>(params.src);
  5.   T* dst = reinterpret_cast<T*>(params.dst);
  6.   IndexType src_index[num_dims];
  7.   IndexType dst_index[num_dims];
  8.   CUDA_1D_KERNEL_LOOP_T(IndexType, i, params.count) {
  9.     params.dst_index_helper.OffsetToNdIndex(i, dst_index);
  10. #pragma unroll
  11.     for (size_t dim = 0; dim < num_dims; ++dim) {
  12.       src_index[params.permutation[dim]] = dst_index[dim];
  13.     }
  14.     IndexType src_offset = params.src_index_helper.NdIndexToOffset(src_index);
  15.     dst[i] = src[src_offset];
  16.   }
  17. }
  • PermuteKernelParams是一个结构体,里面有初始化好的NdIndexOffsetHelper(src和dst各一个),元素总数count还有变换后的维度顺序permutation

  • 首先我们取得当前处理输出元素的高维索引dst_index,然后赋给经过Permute后的输入索引src_index

  • 再将输入索引转换成一维偏移量src_offset,取到输入元素并赋给对应的输出

3

常规情况的优化

这种朴素Permute Kernel的计算代价来源于坐标换算,访存开销则来源于数据移动,针对这两个角度我们引入以下优化方案。

1. IndexTy

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号