
GPU Implementation and Optimization of a 2D Transpose Operator (batched transpose with CUDA)

2D Transpose

A good approach is to use shared memory: read an entire warpSize x warpSize tile into shared memory with coalesced loads, perform the transpose by swapping indices when reading the tile back out of shared memory, and then write the result with coalesced stores as well. The tile can be handled either by a single warp or by a whole thread block. For small and medium matrices the latter works better: assigning one thread block per warpSize x warpSize tile creates more threads, improves hardware utilization, and gives better performance.
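
Both kernels below pad the innermost dimension of the shared-memory tile by one element. A quick worked example of why this removes bank conflicts (assuming 4-byte elements and 32 shared-memory banks, the common configuration on NVIDIA GPUs):

// Without padding: __shared__ float tile[32][32];
//   bank of tile[r][c] = (r * 32 + c) % 32 = c
//   -> a column read (fixed c, r = 0..31) hits the same bank 32 times: a 32-way conflict.
// With padding:     __shared__ float tile[32][33];
//   bank of tile[r][c] = (r * 33 + c) % 32 = (r + c) % 32
//   -> a column read touches 32 different banks: conflict-free.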

Reference CUDA implementation:

#include <stdio.h>
#include <iostream>
#include <algorithm>
#include <cuda_runtime.h>
#include "utils/cuda_mem_helper.h"
#include "utils/cuda_stream_helper.h"

using namespace std;

#define THREAD_PER_WARP 32
#define WARP_PER_BLOCK 8
#define BLOCK_RD_NUM (THREAD_PER_WARP / WARP_PER_BLOCK)

/*
 * Use one warp to transpose a warpSize x warpSize tile.
 */
template <typename T>
__global__ void transpose_2d_warp(const T* __restrict__ in, T* __restrict__ out,
                                  const int row, const int col,
                                  int warp_row, int warp_col, int total_warp) {
    const int tid = blockDim.x * blockIdx.x + threadIdx.x;  // global thread id
    const int warp_bid = threadIdx.x / THREAD_PER_WARP;     // warp id in thread block
    const int warp_gid = tid / THREAD_PER_WARP;             // warp id in grid
    const int lane = threadIdx.x % THREAD_PER_WARP;         // thread id in warp
    const int warp_id_c = warp_gid % warp_col;              // tile column index
    const int warp_id_r = warp_gid / warp_col;              // tile row index
    // pad the innermost dimension by one element to avoid bank conflicts
    __shared__ T block_data[WARP_PER_BLOCK][THREAD_PER_WARP][THREAD_PER_WARP + 1];
    const int row_bias = warp_id_r * THREAD_PER_WARP;
    const int col_bias = warp_id_c * THREAD_PER_WARP;
    // read the tile from the input, guarding tiles that extend past the matrix edge
    for (int i = 0; i < THREAD_PER_WARP; i++) {
        int src_r = row_bias + i;
        int src_c = col_bias + lane;
        if ((src_r < row) && (src_c < col)) {
            block_data[warp_bid][i][lane] = in[src_r * col + src_c];
        }
    }
    __syncthreads();  // each warp only touches its own tile, so __syncwarp() would also suffice
    // write the transposed tile to the output
    for (int i = 0; i < THREAD_PER_WARP; i++) {
        int tgt_c = col_bias + i;
        int tgt_r = row_bias + lane;
        if ((tgt_r < row) && (tgt_c < col)) {
            out[tgt_c * row + tgt_r] = block_data[warp_bid][lane][i];
        }
    }
}

/*
 * Use one thread block to transpose a warpSize x warpSize tile.
 * Assumes row and col are multiples of warpSize (no tail handling here).
 */
template <typename T>
__global__ void transpose_2d_block(const T* __restrict__ in, T* __restrict__ out,
                                   const int row, const int col) {
    int block_id = blockIdx.x;
    int block_num_col = col / warpSize;  // number of tiles per matrix row
    int block_id_row = block_id / block_num_col;
    int block_id_col = block_id % block_num_col;
    int row_offset = block_id_row * warpSize;
    int col_offset = block_id_col * warpSize;
    int row_id = threadIdx.x / warpSize;
    int col_id = threadIdx.x % warpSize;
    // pad the innermost dimension by one element to avoid bank conflicts
    __shared__ T block_data[THREAD_PER_WARP][THREAD_PER_WARP + 1];
#pragma unroll
    for (int i = 0; i < BLOCK_RD_NUM; i++) {
        int row_pos = i * WARP_PER_BLOCK + row_id;
        int in_addr = (row_offset + row_pos) * col + col_offset + col_id;
        block_data[row_pos][col_id] = in[in_addr];
    }
    __syncthreads();
#pragma unroll
    for (int i = 0; i < BLOCK_RD_NUM; i++) {
        int row_pos = i * WARP_PER_BLOCK + row_id;
        int out_addr = (col_offset + row_pos) * row + row_offset + col_id;
        out[out_addr] = block_data[col_id][row_pos];
    }
}

template <typename T>
void Transpose2DWarp(const T* in, T* out, const int row, const int col, cudaStream_t& stream) {
    const int warp_row = (row + THREAD_PER_WARP - 1) / THREAD_PER_WARP;
    const int warp_col = (col + THREAD_PER_WARP - 1) / THREAD_PER_WARP;
    const int total_warp = warp_row * warp_col;
    const int block_size = THREAD_PER_WARP * WARP_PER_BLOCK;
    const int grid_size = (total_warp + WARP_PER_BLOCK - 1) / WARP_PER_BLOCK;
    transpose_2d_warp<<<grid_size, block_size, 0, stream>>>(in, out, row, col, warp_row, warp_col, total_warp);
}

template <typename T>
void Transpose2DBlock(const T* in, T* out, const int row, const int col, cudaStream_t& stream) {
    const int block_row = (row + THREAD_PER_WARP - 1) / THREAD_PER_WARP;
    const int block_col = (col + THREAD_PER_WARP - 1) / THREAD_PER_WARP;
    const int total_block = block_row * block_col;
    const int block_size = THREAD_PER_WARP * WARP_PER_BLOCK;
    const int grid_size = total_block;
    transpose_2d_block<<<grid_size, block_size, 0, stream>>>(in, out, row, col);
}

int main(void) {
    int row = 256;
    int col = 256;
    CudaMemoryHelper<float> data_in({row, col});
    CudaMemoryHelper<float> data_out({row, col});
    data_in.StepInitHostMem(1.0f);
    data_in.CopyMemH2D();
    data_in.PrintElems(4, 512, col);

    CudaStreamHelper stream_helper;
    auto& stream = stream_helper.stream;
    int eval_num = 20;

    // informational only; the launch wrappers above compute their own configuration
    int thread_num = row * col;
    int threadsPerBlock = std::min(128, col);
    int blocksPerGrid = (thread_num + threadsPerBlock - 1) / threadsPerBlock;
    printf("CUDA kernel launch with %d blocks of %d threads\n", blocksPerGrid, threadsPerBlock);

    for (int i = 0; i < eval_num; i++) {
        Transpose2DWarp(data_in.d_mem, data_out.d_mem, row, col, stream);
    }
    stream_helper.Sync();

    for (int i = 0; i < eval_num; i++) {
        Transpose2DBlock(data_in.d_mem, data_out.d_mem, row, col, stream);
    }
    stream_helper.Sync();

    data_out.CopyMemD2H();
    printf("results0:\n");
    // verify results
    data_out.PrintElems(1, 1024, row);
    return 0;
}
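
The main() above only launches the kernels; to compare the two variants you also need timings. A sketch of a timing helper (not part of the original listing; EvalTranspose2DBlock is a hypothetical name, and it reuses the Transpose2DBlock wrapper and the d_mem device pointers from the code above) could look like this:

// Hypothetical helper: measures the mean latency of Transpose2DBlock over `iters`
// launches with CUDA events and reports the effective bandwidth (one read plus one
// write per element) for a row x col matrix of T.
template <typename T>
float EvalTranspose2DBlock(const T* d_in, T* d_out, int row, int col,
                           cudaStream_t& stream, int iters = 20) {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start, stream);
    for (int i = 0; i < iters; i++) {
        Transpose2DBlock(d_in, d_out, row, col, stream);
    }
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float total_ms = 0.0f;
    cudaEventElapsedTime(&total_ms, start, stop);
    float mean_ms = total_ms / iters;
    double bytes = 2.0 * row * col * sizeof(T);  // read + write per element
    printf("transpose_2d_block: %.4f ms, %.2f GB/s\n", mean_ms, bytes / (mean_ms * 1e6));
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return mean_ms;
}

Calling EvalTranspose2DBlock(data_in.d_mem, data_out.d_mem, row, col, stream) after the warm-up loops prints the mean latency and effective bandwidth, which is the natural metric here since a transpose is purely memory-bound.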

See also:

CUDA学习(二)矩阵转置及优化(合并访问、共享内存、bank conflict) - 知乎

What about edge (mobile) GPUs without shared memory?

One option is to let each thread transpose a 4x4 tile: read four consecutive rows, each with a vectorized pack4 load, transpose the tile in registers, and write it back with pack4 stores as well.
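
The OpenCL reference code below implements this directly. For comparison, the same register-tile idea written in CUDA with float4 vector loads/stores might look like the following sketch (transpose_2d_reg_tile is a hypothetical kernel; it assumes float data and that row and col are multiples of 4):

// Hypothetical CUDA analog: one thread transposes a 4x4 tile entirely in registers,
// using float4 vector loads and stores (assumes row % 4 == 0 and col % 4 == 0).
__global__ void transpose_2d_reg_tile(const float* __restrict__ in, float* __restrict__ out,
                                      int row, int col) {
    int tile_col_num = col / 4;  // 4x4 tiles per input row
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= (row / 4) * tile_col_num) return;
    int row_pos = (tid / tile_col_num) * 4;
    int col_pos = (tid % tile_col_num) * 4;
    // read four consecutive rows, four elements each, as vectorized loads
    float4 r[4];
    for (int i = 0; i < 4; i++) {
        r[i] = *reinterpret_cast<const float4*>(&in[(row_pos + i) * col + col_pos]);
    }
    // transpose the 4x4 tile in registers
    float4 c0 = make_float4(r[0].x, r[1].x, r[2].x, r[3].x);
    float4 c1 = make_float4(r[0].y, r[1].y, r[2].y, r[3].y);
    float4 c2 = make_float4(r[0].z, r[1].z, r[2].z, r[3].z);
    float4 c3 = make_float4(r[0].w, r[1].w, r[2].w, r[3].w);
    // write back four rows of the output, again as vectorized stores
    *reinterpret_cast<float4*>(&out[(col_pos + 0) * row + row_pos]) = c0;
    *reinterpret_cast<float4*>(&out[(col_pos + 1) * row + row_pos]) = c1;
    *reinterpret_cast<float4*>(&out[(col_pos + 2) * row + row_pos]) = c2;
    *reinterpret_cast<float4*>(&out[(col_pos + 3) * row + row_pos]) = c3;
}

It would be launched with one thread per 4x4 tile, i.e. (row / 4) * (col / 4) threads in total. On discrete GPUs the shared-memory versions above are usually still preferable, since they keep the global-memory writes of a warp fully coalesced.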

Reference OpenCL code:

#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <chrono>
#include <cstring>
#include <algorithm>
#include "mem_helper.h"

#define CL_HPP_TARGET_OPENCL_VERSION 300
#include <CL/opencl.hpp>

// the host-side half type is expected to be provided by mem_helper.h (or the toolchain)
using TEST_DTYPE = half;
using namespace std;

// must match the tile sizes used inside the kernel source
#define TH_ROW_SIZE 4
#define TH_COL_SIZE 4

// should be used only for small channel_size (<=768)
std::string kernel_source{R"(
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define TH_ROW_SIZE 4
#define TH_COL_SIZE 4
#define DTYPE half
#define DTYPE_PACK4 half4

// one work-item transposes a TH_ROW_SIZE x TH_COL_SIZE tile in registers;
// assumes height and width are multiples of the tile size
kernel void transpose_kernel(__global const DTYPE* restrict d_in, __global DTYPE* restrict d_out,
                             int height, int width) {
    int gid = get_global_id(0);
    int batch_id = get_global_id(1);
    int batch_addr = batch_id * height * width;
    int col_th_num = (width + TH_COL_SIZE - 1) / TH_COL_SIZE;
    int row_id = gid / col_th_num;
    int col_id = gid % col_th_num;
    int row_pos = row_id * TH_ROW_SIZE;
    int col_pos = col_id * TH_COL_SIZE;
    // read a 4 x 4 tile, one packed load per row
    DTYPE_PACK4 in_datas[TH_ROW_SIZE];
#pragma unroll
    for (int rcnt = 0; rcnt < TH_ROW_SIZE; rcnt++) {
        int in_addr = batch_addr + (row_pos + rcnt) * width + col_pos;
        const __global DTYPE_PACK4* d_in_p4 = (const __global DTYPE_PACK4*)&d_in[in_addr];
        in_datas[rcnt] = d_in_p4[0];
    }
    // transpose the tile in registers
    DTYPE_PACK4 out_datas[TH_COL_SIZE];
    out_datas[0] = (DTYPE_PACK4)(in_datas[0].s0, in_datas[1].s0, in_datas[2].s0, in_datas[3].s0);
    out_datas[1] = (DTYPE_PACK4)(in_datas[0].s1, in_datas[1].s1, in_datas[2].s1, in_datas[3].s1);
    out_datas[2] = (DTYPE_PACK4)(in_datas[0].s2, in_datas[1].s2, in_datas[2].s2, in_datas[3].s2);
    out_datas[3] = (DTYPE_PACK4)(in_datas[0].s3, in_datas[1].s3, in_datas[2].s3, in_datas[3].s3);
    // write back, one packed store per output row
#pragma unroll
    for (int ccnt = 0; ccnt < TH_COL_SIZE; ccnt++) {
        int out_addr = batch_addr + (col_pos + ccnt) * height + row_pos;
        __global DTYPE_PACK4* d_out_p4 = (__global DTYPE_PACK4*)&d_out[out_addr];
        d_out_p4[0] = out_datas[ccnt];
    }
}
)"};

int main() {
    std::vector<cl::Platform> platforms;
    cl::Platform::get(&platforms);
    std::cout << "get platform num:" << platforms.size() << std::endl;
    cl::Platform plat;
    for (auto& p : platforms) {
        std::string platver = p.getInfo<CL_PLATFORM_VERSION>();
        if (platver.find("OpenCL 2.") != std::string::npos || platver.find("OpenCL 3.") != std::string::npos) {
            // Note: an OpenCL 3.x platform may not support all required features!
            plat = p;
        }
    }
    if (plat() == 0) {
        std::cout << "No OpenCL 2.0 or newer platform found.\n";
        return -1;
    }
    std::cout << "platform name:" << plat.getInfo<CL_PLATFORM_NAME>() << std::endl;
    cl::Platform newP = cl::Platform::setDefault(plat);
    if (newP != plat) {
        std::cout << "Error setting default platform.\n";
        return -1;
    }

    // get the GPU devices of the default platform
    std::vector<cl::Device> all_devices;
    newP.getDevices(CL_DEVICE_TYPE_GPU, &all_devices);  // or CL_DEVICE_TYPE_ALL
    std::cout << "get all_devices num:" << all_devices.size() << std::endl;
    if (all_devices.size() == 0) {
        std::cout << " No devices found. Check OpenCL installation!\n";
        exit(1);
    }
    cl::Device default_device = all_devices[0];
    std::cout << "device name: " << default_device.getInfo<CL_DEVICE_NAME>() << std::endl;

    // a context is a "runtime link" to the device and platform
    cl::Context context({default_device});
    cl::CommandQueue queue(context, default_device);

    int batch = 4;
    int height = 1024;
    int width = 1024;
    vector<int> shape1 = {batch, height, width};
    MemoryHelper<TEST_DTYPE> mem_in(shape1);
    MemoryHelper<TEST_DTYPE> mem_out(shape1);
    mem_in.StepInit(1.0f);

    // CL_MEM_WRITE_ONLY / CL_MEM_READ_ONLY / CL_MEM_READ_WRITE
    cl::Buffer d_in = cl::Buffer(context, CL_MEM_READ_WRITE, mem_in.bytes);
    cl::Buffer d_out = cl::Buffer(context, CL_MEM_READ_WRITE, mem_out.bytes);
    memset(mem_out.Mem(), 0, mem_out.bytes);
    // push write commands to the queue
    queue.enqueueWriteBuffer(d_in, CL_TRUE, 0, mem_in.bytes, mem_in.Mem());

    std::vector<std::string> programStrings;
    programStrings.push_back(kernel_source);
    cl::Program program(context, programStrings);
    if (program.build({default_device}, "-cl-std=CL3.0") != CL_SUCCESS) {
        std::cout << "Error building: " << program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(default_device) << std::endl;
        exit(1);
    }
    auto cl_kernel = cl::KernelFunctor<cl::Buffer, cl::Buffer, int, int>(program, "transpose_kernel");

    int local_thread_num = 128 * 4;
    int row_th_num = (height + TH_ROW_SIZE - 1) / TH_ROW_SIZE;
    int col_th_num = (width + TH_COL_SIZE - 1) / TH_COL_SIZE;
    int total_thread_num = row_th_num * col_th_num;
    local_thread_num = std::min(local_thread_num, total_thread_num);
    cout << "total_thread_num: " << total_thread_num << endl;
    cout << "local_thread_num: " << local_thread_num << endl;
    // EnqueueArgs takes (queue, global), (queue, global, local) or (queue, offset, global, local);
    // the local range must have the same dimensionality as the global range
    cl::EnqueueArgs kernel_args(queue, cl::NDRange(total_thread_num, batch), cl::NDRange(local_thread_num, 1));

    int warmup_num = 50;
    int eval_num = 50;
    for (int i = 0; i < warmup_num; i++) {
        cl_kernel(kernel_args, d_in, d_out, height, width);
    }
    queue.finish();

    auto t1 = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < eval_num; i++) {
        cl_kernel(kernel_args, d_in, d_out, height, width);
    }
    queue.finish();
    auto t2 = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();
    float mean_time_ms = duration / 1000.0f / eval_num;
    printf("exec time us: %lld %d\n", (long long)duration, eval_num);
    printf("exec time: %f ms\n", mean_time_ms);
    printf("batch, height, width: %d %d %d\n", batch, height, width);

    queue.enqueueReadBuffer(d_out, CL_TRUE, 0, mem_out.bytes, mem_out.Mem());
    mem_in.PrintElems(1, 512, width);
    mem_out.PrintElems(1, 512, height);
    return 0;
}
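
The host code above prints a few elements but does not verify the result. A minimal host-side check (a sketch; CheckTranspose is a hypothetical helper, and it assumes mem_in.Mem() / mem_out.Mem() return host pointers to the full batch x height x width buffers, as they are used above) could compare one batch against a naive CPU transpose:

// Hypothetical verification helper: compares the device result for one batch against
// a naive CPU transpose; returns the number of mismatching elements.
template <typename T>
int CheckTranspose(const T* in, const T* out, int batch_id, int height, int width) {
    const T* in_b = in + (size_t)batch_id * height * width;
    const T* out_b = out + (size_t)batch_id * height * width;
    int errors = 0;
    for (int r = 0; r < height; r++) {
        for (int c = 0; c < width; c++) {
            // out is width x height: out[c][r] must equal in[r][c]
            if (out_b[c * height + r] != in_b[r * width + c]) errors++;
        }
    }
    return errors;
}

Calling CheckTranspose(mem_in.Mem(), mem_out.Mem(), 0, height, width) after the enqueueReadBuffer should return 0 when the kernel is correct.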
