当前位置:   article > 正文

搞定秋招!3W字LLM高频面试题汇总:大模型手撕CUDA

cuda 手写kernel 面试

作者 | DefTruth  编辑 | 自动驾驶之心

原文链接:https://zhuanlan.zhihu.com/p/678903537

点击下方卡片,关注“自动驾驶之心”公众号

戳我-> 领取自动驾驶近15个方向学习路线

>>点击进入→自动驾驶之心大模型技术交流群

本文只做学术分享,如有侵权,联系删文

0x00 前言

dd8f04d9048b2c02bb69499ce19278fb.png

前段时间参加了一些大模型的面试,大部分都要手撕CUDA,因此也整体复习了一遍CUDA优化相关的内容,整理了一些高频题的基本写法,保存在这里也便于日后自己复习。当然,有些代码不一定是最优化解,比如GEMM,想要在面试短短的30分钟内写一个好的GEMM Kernel,是有些难度的。印象比较深刻的是,其中有一场面试2个多小时,一个小时问项目,剩下一个小时在写GEMM,说实话,如果不是事先有准备过一些,直接上手写优化版还是会有点慌。就自己的经验而言,命中率还挺高(目前还没有遇到这些题目之外的),考虑深度优化的,一般也不会让你在面试短短几十分钟手撸出来。要是遇到有面试官让你手撸一个FlashAttention,那说明,你们实在是没有缘分,还是尽早好聚好散的好,或者提前结束面试,把时间省下来,出去吃顿烧烤也不错...,附GitHub链接:

https://github.com/DefTruth/cuda-learn-note
github.com/DefTruth/cuda-learn-note

TIPS: 文章整理为方便自己复习,不喜欢的请自动跳过哈。

0x01 高频面试题汇总简介

相关kernel如下。也就是不到1000行代码,建议背下来,我个人是喜欢背记,背的过程中基本就慢慢理解所有细节。当然,每个人的学习方法都不一样哈,自己觉得舒服就行。

  • sgemm naive, sgemm + block-tile + k-tile + vec4

  • sgemv k32/k128/k16 kernel

  • warp/block reduce sum/max, block all reduce + vec4

  • dot product, dot product + vec4

  • elementwise, elementwise + vec4

  • histogram, histogram + vec4

  • softmax, softmax + vec4 (grid level memory fence)

  • sigmoid, sigmoid + vec4

  • relu, relu + vec4

  • layer_norm, layer_norm + vec4

  • rms_norm, rms_norm + vec4

  • ....

题内话,大模型相关的岗位,手撕CUDA的概率非常大,leetcode反而写的少,就前段时间个人的经验,基本是4:1的比例,还是建议好好复习下CUDA。当然,这些只是最简单的kernel实现,比如flash_attn,FMHA这些优化手段,就不在这篇文章里写了,面试中基本都会问到。FlashAttention系列原理详解,可以看我写的另一篇文章:

0x02 sgemm naive, sgemm + block-tile + k-tile + vec4

  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <float.h>
  4. #include <vector>
  5. #include <algorithm>
  6. #include <cuda_runtime.h>
  7. #define WARP_SIZE 32
  8. #define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
  9. #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
  10. // SGEMM: Block Tile + K Tile, with smem
  11. // Block Tile (BM, BN) + K Tile (BK=32)
  12. // grid((N + BN - 1) / BN, (M + BM - 1) / BM), block(BN, BM)
  13. // a: MxK, b: KxN, c: MxN, compute: c = a * b, all row major  
  14. __global__ void sgemm(float* a, float* b, float* c, int M, int N, int K) {
  15.   // [1] Block Tile: 32x32的block处理c上一块32x32的元素计算
  16.   // [2]     K Tile: 使用共享内存,并将K分块为BK大小的块
  17.   constexpr int BM = 32;
  18.   constexpr int BN = 32;
  19.   constexpr int BK = 32;
  20.   __shared__ float s_a[BM][BK], s_b[BK][BN]; 
  21.   int bx = blockIdx.x;
  22.   int by = blockIdx.y;
  23.   int tx = threadIdx.x;
  24.   int ty = threadIdx.y;
  25.   int tid = threadIdx.y * blockDim.x + tx; // tid within the block
  26.   // load values to shared memory, 32x32 threads working together 
  27.   // to fetch data along the row direction of a and b both for s_a 
  28.   // and s_b 32x32x4x2=8KB, we use 32x32 threads within block to 
  29.   // load 32x32 elements from global memory to shared memory, namely, 
  30.   // each thread will load 1 element.
  31.   int load_smem_a_m = tid / 32// 0~31, tid / 32, tid / BM, threadIdx.y
  32.   int load_smem_a_k = tid % 32// 0~31, tid % 32, tid % BK, threadIdx.x
  33.   int load_smem_b_k = tid / 32// 0~31, tid / 32, tid / BK, threadIdx.y
  34.   int load_smem_b_n = tid % 32// 0~31, tid % 32, tid % BN, threadIdx.x
  35.   int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  36.   int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
  37.   // if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;
  38.   
  39.   float sum = 0.f;
  40.   for (int bk = 0; bk < (K + BK - 1) / BK; ++bk) {
  41.     int load_gmem_a_k = bk * BK + load_smem_a_k;
  42.     int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
  43.     s_a[load_smem_a_m][load_smem_a_k] = a[load_gmem_a_addr];
  44.     int load_gmem_b_k = bk * BK + load_smem_b_k;
  45.     int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
  46.     s_b[load_smem_b_k][load_smem_b_n] = b[load_gmem_b_addr];
  47.     __syncthreads();
  48.     #pragma unroll
  49.     for (int k = 0; k < BK; ++k) {
  50.       int comp_smem_a_m = load_smem_a_m;
  51.       int comp_smem_b_n = load_smem_b_n;
  52.       sum += s_a[comp_smem_a_m][k] * s_b[k][comp_smem_b_n];
  53.     }
  54.     __syncthreads();
  55.   }
  56.   int store_gmem_c_m = load_gmem_a_m;
  57.   int store_gmem_c_n = load_gmem_b_n;
  58.   int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
  59.   c[store_gmem_c_addr] = sum;
  60. }
  61. // SGEMM: Block Tile + Thread Tile + K Tile + Vec4, with smem
  62. // BK:TILE_K=8 BM=BN=128
  63. // TM=TN=8 增加计算密度 BM/TM=16 BN/TN=16
  64. // dim3 blockDim(BN/TN, BM/TM);
  65. // dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM)
  66. __global__ void sgemm_thread_tile_vec4(
  67.   float* a, float* b, float* c, int M, int N, int K) {
  68.   // [1]  Block Tile: 一个16x16的block处理C上大小为128X128的一个目标块
  69.   // [2] Thread Tile: 每个thread负责计算TM*TN(8*8)个元素,增加计算密度
  70.   // [3]      K Tile: 将K分块,每块BK大小,迭代(K+BK-1/BK)次,
  71.   //                  每次计算TM*TN个元素各自的部分乘累加
  72.   // [4]   Vectorize: 减少load和store指令,使用float4
  73.   constexpr int BM = 128;
  74.   constexpr int BN = 128;
  75.   constexpr int BK = 8
  76.   constexpr int TM = 8;
  77.   constexpr int TN = 8;
  78.   int bx = blockIdx.x;
  79.   int by = blockIdx.y;
  80.   int tx = threadIdx.x;
  81.   int ty = threadIdx.y;
  82.   int tid = threadIdx.y * blockDim.x + tx; // tid within the block
  83.   __shared__ float s_a[BM][BK], s_b[BK][BN]; // 2*128*8*4=8KB
  84.   
  85.   // 0. 先计算shared memory中的索引
  86.   // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=8 按行读取 A行主序
  87.   // 对于s_a每行8个数据,每个线程读取4个,需要2个线程;总共128行,需要128x2刚好256线程
  88.   int load_smem_a_m = tid / 2// tid/2 (128/8)*(128/8)=256 threads per block, tid/2->[0,128), BM=128 0~127
  89.   int load_smem_a_k = (tid % 2 == 0) ? 0 : 4;  // (tid%2 == 0) ? 0 : 4, col of s_a 0,4
  90.   // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=8 BN=128 按行读取 B行主序
  91.   // 对于s_b每行128个数据,每个线程读4个数据,需要32个线程;总共8行,需要32x8=256个线程
  92.   int load_smem_b_k = tid / 32// tid/32, row of s_b 256/32=8 行 0~7
  93.   int load_smem_b_n = (tid % 32) * 4;  // (tid % 32) * 4, col of s_b 0,4,...,124
  94.   // 1. 再计算全局内存中的索引
  95.   // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块
  96.   int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  97.   int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
  98.   
  99.   float r_c[TM][TN] = {0.0}; // 8x8
  100.   // 2. 先对K进行分块,每块BK大小
  101.   for (int bk = 0; bk < (K + BK - 1) / BK; ++bk) {
  102.     // 加载数据到共享内存smem s_a BM*BK 128*8 vectorize float4
  103.     int load_gmem_a_k = bk * BK + load_smem_a_k; // global col of a
  104.     int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
  105.     FLOAT4(s_a[load_smem_a_m][load_smem_a_k]) = FLOAT4(a[load_gmem_a_addr]);
  106.     // 加载数据到共享内存smem s_b BK*BN 8*128 vectorize float4
  107.     int load_gmem_b_k = bk * BK + load_smem_b_k; // global row of b
  108.     int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; 
  109.     FLOAT4(s_b[load_smem_b_k][load_smem_b_n]) = FLOAT4(b[load_gmem_b_addr]); 
  110.     __syncthreads();
  111.     #pragma unroll
  112.     for (int k = 0; k < BK; k++) {
  113.       // 3. 每个线程负责计算BM*BN(12x128)中的TM*TN(8x8)个元素
  114.       #pragma unroll
  115.       for (int m = 0; m < TM; m++) {
  116.         #pragma unroll
  117.         for (int n = 0; n < TN; n++) {
  118.           // k from 0~7,0 ~ BK, ty and tx range from 0 to 15, 16x8=128
  119.           int comp_smem_a_m = ty * TM + m;  // 128*8 128/TM(8)=16 M方向 16线程
  120.           int comp_smem_b_n = tx * TN + n;  // 8*128 128/TN(8)=16 N方向 16线程
  121.           r_c[m][n] += s_a[comp_smem_a_m][k] * s_b[k][comp_smem_b_n];
  122.         }
  123.       }
  124.     }
  125.     __syncthreads();
  126.   }
  127.   #pragma unroll
  128.   for (int m = 0; m < TM; ++m) {
  129.     int store_gmem_c_m = by * BM + ty * TM + m;
  130.     #pragma unroll
  131.     for (int n = 0; n < TN; n += 4) {
  132.       int store_gmem_c_n = bx * BN + tx * TN + n;
  133.       int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
  134.       FLOAT4(c[store_gmem_c_addr]) = FLOAT4(r_c[m][n]);
  135.     }
  136.   }
  137. }

这里gemm的实现比较简单,只使用了CUDA Cores,并且只实现Block Tile + K Tile以及Block Tile + K Tile+Thread Tile+向量化的版本。主要在于如何加载gmem中的数据到smem,也就是把全局内存中的数据索引mapping到共享内存中的。核心思维:把一个block中的线程id按照线性来理解,然后把这个线性的id和全局内存索引以及共享内存索引进行匹配。比如Block Tile + K Tile的实现,block内一共32x32个Threads,需要加载到smem的数据也是32x32,那么,最简单的做法,只需要每个线程加载一个互不重复数据即可。NOTE,本文的gemm kernel修改自:紫气东来:CUDA(三):通用矩阵乘法:从入门到熟练

0x03 warp/block reduce sum/max

  1. // Warp Reduce Sum
  2. template<const int kWarpSize = WARP_SIZE>
  3. __device__ __forceinline__ float warp_reduce_sum(float val) {
  4.   #pragma unroll
  5.   for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
  6.     val += __shfl_xor_sync(0xffffffff, val, mask);
  7.   }
  8.   return val;
  9. }
  10. // Warp Reduce Max
  11. template<const int kWarpSize = WARP_SIZE>
  12. __device__ __forceinline__ float warp_reduce_max(float val) {
  13.   #pragma unroll
  14.   for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
  15.     val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask));
  16.   }
  17.   return val;
  18. }
  19. // Block reduce sum/max/min device helper for Layer/RMS Norm/Softmax etc.
  20. // grid 1D block 1D, grid(N/128), block(128)
  21. template<const int NUM_THREADS=128>
  22. __device__ __forceinline__ float block_reduce_sum(float val) {
  23.   // always <= 32 warps per block (limited by 1024 threads per block)
  24.   constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  25.   int warp = threadIdx.x / WARP_SIZE;
  26.   int lane = threadIdx.x % WARP_SIZE;
  27.   static __shared__ float shared[NUM_WARPS];
  28.   
  29.   val = warp_reduce_sum<WARP_SIZE>(val);
  30.   if (lane == 0) shared[warp] = val;
  31.   __syncthreads();
  32.   val = (lane < NUM_WARPS) ? shared[lane] : 0.0f;
  33.   val = warp_reduce_sum<NUM_WARPS>(val);
  34.   return val;
  35. }
  36. template<const int NUM_THREADS=128>
  37. __device__ __forceinline__ float block_reduce_max(float val) {
  38.   // always <= 32 warps per block (limited by 1024 threads per block)
  39.   constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  40.   int warp = threadIdx.x / WARP_SIZE;
  41.   int lane = threadIdx.x % WARP_SIZE;
  42.   static __shared__ float shared[NUM_WARPS];
  43.   
  44.   val = warp_reduce_max<WARP_SIZE>(val);
  45.   if (lane == 0) shared[warp] = val;
  46.   __syncthreads();
  47.   val = (lane < NUM_WARPS) ? shared[lane] : -FLT_MAX;
  48.   val = warp_reduce_max<NUM_WARPS>(val);
  49.   return val;
  50. }

warp reduce几乎已经成为大部分reduce kernel的标准写法了,比如vLLM中,就是这种经典的写法。所以,先搞懂warp reduce(也就是搞懂各种warp functions的用法),再去写其他kernel,思路就会容易很多。需要注意的是,warp函数处理的是寄存器上的数据,也就是说,此时,没必要先加载数据到smem,再进行reduce,直接加载到寄存器即可(以前犯过这个小错误...)。Warp Functions建议参考:jhang:CUDA编程入门之Warp-Level Primitives

0x04 block all reduce + vec4

  1. // Block All Reduce Sum
  2. // grid(N/128), block(128)
  3. // a: Nx1, y=sum(a)
  4. template<const int NUM_THREADS = 128>
  5. __global__ void block_all_reduce_sum(float* a, float* y, int N) {
  6.   int tid = threadIdx.x;
  7.   int idx = blockIdx.x * NUM_THREADS + tid;
  8.   constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  9.   __shared__ float reduce_smem[NUM_WARPS];
  10.   // keep the data in register is enougth for warp operaion.
  11.   float sum = (idx < N) ? a[idx] : 0.0f;
  12.   int warp = tid / WARP_SIZE;
  13.   int lane = tid % WARP_SIZE;
  14.   // perform warp sync reduce.
  15.   sum = warp_reduce_sum<WARP_SIZE>(sum);
  16.   // warp leaders store the data to shared memory.
  17.   if (lane == 0) reduce_smem[warp] = sum;
  18.   __syncthreads(); // make sure the data is in shared memory.
  19.   // the first warp compute the final sum.
  20.   sum = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  21.   if (warp == 0) sum = warp_reduce_sum<NUM_WARPS>(sum);
  22.   if (tid == 0) atomicAdd(y, sum);
  23. }
  24. // Block All Reduce Sum + float4
  25. // grid(N/128), block(128/4)
  26. // a: Nx1, y=sum(a)
  27. template<const int NUM_THREADS = 128/4>
  28. __global__ void block_all_reduce_sum_vec4(float* a, float* y, int N) {
  29.   int tid = threadIdx.x;
  30.   int idx = (blockIdx.x * NUM_THREADS + tid) * 4;
  31.   constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  32.   __shared__ float reduce_smem[NUM_WARPS];
  33.   float4 reg_a = FLOAT4(a[idx]);
  34.   // keep the data in register is enougth for warp operaion.
  35.   float sum = (idx < N) ? (reg_a.x + reg_a.y + reg_a.z + reg_a.w) : 0.0f;
  36.   int warp = tid / WARP_SIZE;
  37.   int lane = tid % WARP_SIZE;
  38.   // perform warp sync reduce.
  39.   sum = warp_reduce_sum<WARP_SIZE>(sum);
  40.   // warp leaders store the data to shared memory.
  41.   if (lane == 0) reduce_smem[warp] = sum;
  42.   __syncthreads(); // make sure the data is in shared memory.
  43.   // the first warp compute the final sum.
  44.   sum = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  45.   if (warp == 0) sum = warp_reduce_sum<NUM_WARPS>(sum);
  46.   if (tid == 0) atomicAdd(y, sum);
  47. }

block all reduce是在warp reduce的基础上进行的,reduce_smem这部分的共享内存申请无法避免,这是用来同步每个warp之间得到局部结果。注意,最后,还需要atomicAdd做一个block级别的原子操作,以得到全局的和。float4向量化优化访存,可以减缓WarpScheduler发送指令的压力。

0x05 sgemv k32/k128/k16 kernel

  1. // SGEMV: Warp SGEMV K32
  2. // 假设K为32的倍数,每个warp负责一行
  3. // grid(M/4), block(32,4) blockDim.x=32=K, blockDim.y=4
  4. // a: MxK, x: Kx1, y: Mx1, compute: y = a * x
  5. __global__ void sgemv_k32(float* a, float* x, float* y, int M, int K) {
  6.   int tx = threadIdx.x;         // 0~31
  7.   int ty = threadIdx.y;         // 0~4
  8.   int bx = blockIdx.x;          // 0~M/4
  9.   int lane = tx % WARP_SIZE;    // 0~31
  10.   int m = bx * blockDim.y + ty; // (0~M/4) * 4 + (0~3)
  11.   if (m < M) {
  12.     float sum = 0.0f;
  13.     int NUM_WARPS = (K + WARP_SIZE - 1) / WARP_SIZE;
  14.     #pragma unroll
  15.     for (int w = 0; w < NUM_WARPS; ++w) {
  16.       // 若NUM_WARPS>=2,先将当前行的数据累加到第一个warp中
  17.       int k = w * WARP_SIZE + lane;
  18.       sum += a[m * K + k] * x[k];
  19.     }
  20.     sum = warp_reduce_sum<WARP_SIZE>(sum);
  21.     if (lane == 0) y[m] = sum;
  22.   }
  23. }
  24. // SGEMV: Warp SGEMV K128 + Vec4
  25. // 假设K为128的倍数 float4
  26. // grid(M/4), block(32,4) blockDim.x=32=K, blockDim.y=4
  27. // a: MxK, x: Kx1, y: Mx1, compute: y = a * x
  28. __global__ void sgemv_k128(float* a, float* x, float* y, int M, int K) {
  29.   // 每个线程负责4个元素,一个warp覆盖128个元素
  30.   int tx = threadIdx.x;         // 0~31
  31.   int ty = threadIdx.y;         // 0~3
  32.   int bx = blockIdx.x;          // 0~M/4
  33.   int lane = tx % WARP_SIZE;    // 0~31
  34.   int m = blockDim.y * bx + ty; // (0~M/4) * 4 + (0~3)
  35.   
  36.   if (m < M) {
  37.     float sum = 0.0f;
  38.     // process 4*WARP_SIZE elements per warp.
  39.     int NUM_WARPS = (((K + WARP_SIZE - 1) / WARP_SIZE) + 4 - 1) / 4;
  40.     #pragma unroll
  41.     for (int w = 0; w < NUM_WARPS; ++w) {
  42.       int k = (w * WARP_SIZE + lane) * 4;
  43.       float4 reg_x = FLOAT4(x[k]);
  44.       float4 reg_a = FLOAT4(a[m * K + k]);
  45.       sum += (reg_a.x * reg_x.x + reg_a.y * reg_x.y 
  46.             + reg_a.z * reg_x.z + reg_a.w * reg_x.w);
  47.     }
  48.     sum = warp_reduce_sum<WARP_SIZE>(sum);
  49.     if(lane == 0) y[m] = sum;
  50.   }
  51. }
  52. // SGEMV: Warp SGEMV K16
  53. // 假设K为16 < 32,每个warp负责2行,每行有16个元素
  54. // NUM_THREADS=128, NUM_WARPS=NUM_THREADS/WARP_SIZE;
  55. // NUM_ROWS=NUM_WARPS * ROW_PER_WARP, grid(M/NUM_ROWS), block(32,NUM_WARPS)
  56. // a: MxK, x: Kx1, y: Mx1, compute: y = a * x
  57. template<const int ROW_PER_WARP = 2
  58. __global__ void sgemv_k16(float* A, float* x, float* y, int M, int K) {
  59.   constexpr int K_WARP_SIZE = (WARP_SIZE + ROW_PER_WARP - 1) / ROW_PER_WARP;
  60.   int tx = threadIdx.x;       // 0~31
  61.   int ty = threadIdx.y;       // 0~NUM_WARPS
  62.   int bx = blockIdx.x;        // 0~M/NUM_ROWS (NUM_ROWS=NUM_WARPS * ROW_PER_WARP)
  63.   int lane = tx % WARP_SIZE;  // 0~31
  64.   int k = lane % K_WARP_SIZE; // 0~15
  65.   // gloabl row of a: MxK and y:Mx1, blockDim.y=NUM_WARPS
  66.   int m = (blockDim.y * bx + ty) * ROW_PER_WARP + lane / K_WARP_SIZE;
  67.   if (m < M) {
  68.     float sum = A[m * K + k] * x[k];
  69.     sum = warp_reduce_sum<K_WARP_SIZE>(sum);
  70.     // 注意是k == 0,而不是lane == 0
  71.     if(k == 0) y[m] = sum; 
  72.   }
  73. }

估计有些大佬倒立都能写sgemv的各种优化版了,核心思路其实也是基于warp reduce,考虑K的不同情况进行优化。本文的sgemv kernel修改自:有了琦琦的棍子:深入浅出GPU优化系列:gemv优化

0x06 dot product, dot product + vec4

  1. // Dot Product
  2. // grid(N/128), block(128)
  3. // a: Nx1, b: Nx1, y=sum(elementwise_mul(a,b))
  4. template<const int NUM_THREADS = 128>
  5. __global__ void dot(float* a, float* b, float* y, int N) {
  6.   int tid = threadIdx.x;
  7.   int idx = blockIdx.x * NUM_THREADS + tid;
  8.   constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  9.   __shared__ float reduce_smem[NUM_WARPS];
  10.   // keep the data in register is enougth for warp operaion.
  11.   float prod = (idx < N) ? a[idx] * b[idx] : 0.0f;
  12.   int warp = tid / WARP_SIZE;
  13.   int lane = tid % WARP_SIZE;
  14.   // perform warp sync reduce.
  15.   prod = warp_reduce_sum<WARP_SIZE>(prod);
  16.   // warp leaders store the data to shared memory.
  17.   if (lane == 0) reduce_smem[warp] = prod;
  18.   __syncthreads(); // make sure the data is in shared memory.
  19.   // the first warp compute the final sum.
  20.   prod = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  21.   if (warp == 0) prod = warp_reduce_sum<NUM_WARPS>(prod);
  22.   if (tid == 0) atomicAdd(y, prod);
  23. }
  24. // Dot Product + Vec4
  25. // grid(N/128), block(128/4)
  26. // a: Nx1, b: Nx1, y=sum(elementwise_mul(a,b))
  27. template<const int NUM_THREADS = 128/4>
  28. __global__ void dot_vec4(float* a, float* b, float* y, int N) {
  29.   int tid = threadIdx.x;
  30.   int idx = (blockIdx.x * NUM_THREADS + tid) * 4;
  31.   constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  32.   __shared__ float reduce_smem[NUM_WARPS];
  33.   float4 reg_a = FLOAT4(a[idx]);
  34.   float4 reg_b = FLOAT4(b[idx]);
  35.   float prod = (idx < N) ? (reg_a.x * reg_b.x + reg_a.y * reg_b.y 
  36.                           + reg_a.z * reg_b.z + reg_a.w * reg_b.w) : 0.0f;
  37.   int warp = tid / WARP_SIZE;
  38.   int lane = tid % WARP_SIZE;
  39.   // perform warp sync reduce.
  40.   prod = warp_reduce_sum<WARP_SIZE>(prod);
  41.   // warp leaders store the data to shared memory.
  42.   if (lane == 0) reduce_smem[warp] = prod;
  43.   __syncthreads(); // make sure the data is in shared memory.
  44.   // the first warp compute the final sum.
  45.   prod = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  46.   if (warp == 0) prod = warp_reduce_sum<NUM_WARPS>(prod);
  47.   if (tid == 0) atomicAdd(y, prod);
  48. }

dot product kernel的核心就是block reduce,不多说了。

0x07 elementwise, elementwise + vec4

  1. // ElementWise Add  
  2. // grid(N/128), block(128)
  3. // a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
  4. __global__ void elementwise_add(float* a, float* b, float* c, int N) {
  5.   int idx = blockIdx.x * blockDim.x + threadIdx.x;
  6.   if (idx < N) c[idx] = a[idx] + b[idx];
  7. }
  8. // ElementWise Add + Vec4
  9. // grid(N/128), block(128/4)
  10. // a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
  11. __global__ void elementwise_add_vec4(float* a, float* b, float* c, int N) {
  12.   int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  13.   if (idx < N) {
  14.     float4 reg_a = FLOAT4(a[idx]);
  15.     float4 reg_b = FLOAT4(b[idx]);
  16.     float4 reg_c;
  17.     reg_c.x = reg_a.x + reg_b.x;
  18.     reg_c.y = reg_a.y + reg_b.y;
  19.     reg_c.z = reg_a.z + reg_b.z;
  20.     reg_c.w = reg_a.w + reg_b.w;
  21.     FLOAT4(c[idx]) = reg_c;
  22.   }
  23. }

elementwise可以考虑加点向量化进行访存优化。

0x08 histogram, histogram + vec4

  1. // Histogram
  2. // grid(N/128), block(128)
  3. // a: Nx1, y: count histogram
  4. __global__ void histogram(int* a, int* y, int N) {
  5.   int idx = blockIdx.x * blockDim.x + threadIdx.x;
  6.   if (idx < N) atomicAdd(&(y[a[idx]]), 1);
  7. }
  8. // Histogram + Vec4
  9. // grid(N/128), block(128/4)
  10. // a: Nx1, y: count histogram
  11. __global__ void histogram_vec4(int* a, int* y, int N) {
  12.   int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  13.   if (idx < N) {
  14.     int4 reg_a = INT4(a[idx]);
  15.     atomicAdd(&(y[reg_a.x]), 1);
  16.     atomicAdd(&(y[reg_a.y]), 1);
  17.     atomicAdd(&(y[reg_a.z]), 1);
  18.     atomicAdd(&(y[reg_a.w]), 1);
  19.   }
  20. }

统计频数直方图,很简单,两行代码搞定。

0x09 softmax, softmax + vec4 (grid level memory fence)

  1. // Softmax x: N, y: N
  2. // grid(N/128), block(K=128)
  3. template<const int NUM_THREADS = 128>
  4. __global__ void softmax(float* x, float* y, float* total, int N) {
  5.   const int tid = threadIdx.x;
  6.   const int idx = blockIdx.x * blockDim.x + tid; 
  7.   constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  8.   __shared__ float reduce_smem[NUM_WARPS];
  9.   
  10.   float sum = (idx < N) ? expf(x[idx]) : 0.0f;
  11.   int warp = tid / WARP_SIZE;
  12.   int lane = tid % WARP_SIZE;
  13.   sum = warp_reduce_sum<WARP_SIZE>(sum);
  14.   if (lane == 0) reduce_smem[warp] = sum;
  15.   __syncthreads();
  16.   // compute the final sum in each warp
  17.   sum = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  18.   sum = warp_reduce_sum<NUM_WARPS>(sum); // sum(e^x_0,...,e^x_n-1)
  19.   // get the total sum of all blocks.
  20.   if (tid == 0) atomicAdd(total, sum);
  21.   __threadfence(); // grid level memory fence 注意这里需要网格级别的内存同步
  22.   // e^x_i/sum(e^x_0,...,e^x_n-1) 
  23.   if (idx < N) y[idx] = block_smem[tid] / (*total); 
  24. }
  25. // Softmax x: N, y: N
  26. // grid(N/128), block(K=128)
  27. template<const int NUM_THREADS = 128>
  28. __global__ void softmax_v2(float* x, float* y, float* total, int N) {
  29.   const int tid = threadIdx.x;
  30.   const int idx = blockIdx.x * blockDim.x + tid; 
  31.   
  32.   float exp_val = (idx < N) ? expf(x[idx]) : 0.0f;
  33.   float sum = block_reduce_sum<NUM_THREADS>(exp_val);
  34.   // get the total sum of all blocks.
  35.   if (tid == 0) atomicAdd(total, sum);
  36.   __threadfence(); // grid level memory fence  注意这里需要网格级别的内存同步
  37.   // e^x_i/sum(e^x_0,...,e^x_n-1) 
  38.   if (idx < N) y[idx] = exp_val / (*total); 
  39. }
  40. // Softmax Vec4 x: N, y: N
  41. // grid(N/128), block(128/4)
  42. template<const int NUM_THREADS = 128/4>
  43. __global__ void softmax_v2_vec4(float* x, float* y, float* total, int N) {
  44.   const int tid = threadIdx.x;
  45.   const int idx = (blockIdx.x * blockDim.x + tid) * 4
  46.   
  47.   float4 reg_x = FLOAT4(x[idx]);
  48.   float4 reg_exp;
  49.   reg_exp.x = (idx < N) ? expf(reg_x.x) : 0.0f;
  50.   reg_exp.y = (idx < N) ? expf(reg_x.y) : 0.0f;
  51.   reg_exp.z = (idx < N) ? expf(reg_x.z) : 0.0f;
  52.   reg_exp.w = (idx < N) ? expf(reg_x.w) : 0.0f;
  53.   float exp_val = (reg_exp.x + reg_exp.y + reg_exp.z + reg_exp.w);
  54.   float sum = block_reduce_sum<NUM_THREADS>(exp_val);
  55.   // get the total sum of all blocks.
  56.   if (tid == 0) atomicAdd(total, sum);
  57.   __threadfence(); // grid level memory fence  注意这里需要网格级别的内存同步
  58.   // e^x_i/sum(e^x_0,...,e^x_n-1) 
  59.   if (idx < N) {
  60.     float4 reg_y;
  61.     reg_y.x = reg_exp.x / (*total);
  62.     reg_y.y = reg_exp.y / (*total);
  63.     reg_y.z = reg_exp.z / (*total);
  64.     reg_y.w = reg_exp.w / (*total);
  65.     FLOAT4(y[idx]) = reg_y; 
  66.   }
  67. }

softmax稍微要注意的就是内存同步的问题,这里,你需要做一个网格级别的同步,而不能仅仅是block级别,否则拿不到全局的exp sum作为分母项。因此使用 __threadfence 这个网格及内存同步操作。不过效率我还没测过,实在要高效的话,可能得整成FA2那样的 1-pass + online softmax的实现。不过,如果是面试的话,就不要太为难自己了... ,但是FA1/FA2的论文很经典,强烈建议多读几遍。

0x0a sigmoid, sigmoid + vec4

  1. // Sigmoid x: N, y: N y=1/(1+exp(-x))
  2. // grid(N/128), block(K=128) 
  3. __global__ void sigmoid(float* x, float* y, int N) {
  4.   int idx = blockIdx.x * blockDim.x + threadIdx.x;
  5.   if (idx < N) y[idx] = 1.0f / (1.0f + expf(-x[idx]));
  6. }
  7. // Sigmoid x: N, y: N y=1/(1+exp(-x)) Vec4
  8. // grid(N/128), block(128/4)
  9. __global__ void sigmoid_vec4(float* x, float* y, int N) {
  10.   int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  11.   if (idx < N) {
  12.     float4 reg_x = FLOAT4(x[idx]);
  13.     float4 reg_y;
  14.     reg_y.x = 1.0f / (1.0f + expf(-reg_x.x));
  15.     reg_y.y = 1.0f / (1.0f + expf(-reg_x.y));
  16.     reg_y.z = 1.0f / (1.0f + expf(-reg_x.z));
  17.     reg_y.w = 1.0f / (1.0f + expf(-reg_x.w));
  18.     FLOAT4(y[idx]) = reg_y;
  19.   }
  20. }

0x0b relu, relu + vec4

  1. // Relu x: N, y: N y=max(0,x)
  2. // grid(N/128), block(K=128) 
  3. __global__ void relu(float* x, float* y, int N) {
  4.   int idx = blockIdx.x * blockDim.x + threadIdx.x;
  5.   if (idx < N) y[idx] = fmaxf(0.0f, x[idx]);
  6. }
  7. // Relu x: N, y: N y=max(0,x) Vec4
  8. // grid(N/128/4), block(128/4) 
  9. __global__ void relu_vec4(float* x, float* y, int N) {
  10.   int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  11.   if (idx < N) {
  12.     float4 reg_x = FLOAT4(x[idx]);
  13.     float4 reg_y;
  14.     reg_y.x = fmaxf(0.0f, reg_x.x);
  15.     reg_y.y = fmaxf(0.0f, reg_x.y);
  16.     reg_y.z = fmaxf(0.0f, reg_x.z);
  17.     reg_y.w = fmaxf(0.0f, reg_x.w);
  18.     FLOAT4(y[idx]) = reg_y;
  19.   }
  20. }

0x0c layer_norm, layer_norm + vec4

  1. // Layer Norm: x: NxK(K=128<1024), y': NxK, y'=x-mean(x)/std(x) each row
  2. // mean(x) = sum(x)/K, 1/std(x) = rsqrtf( sum( (x-mean(x))^2 )/K ) each row
  3. // grid(N*K/K), block(K<1024) N=batch_size*seq_len, K=hidden_size
  4. // y=y'*g + b (g: scale, b: bias)
  5. template<const int NUM_THREADS=128>
  6. __global__ void layer_norm(float* x, float* y, float g, float b, int N, int K) {
  7.   int tid = threadIdx.x; // 0..K-1
  8.   int bid = blockIdx.x; // 0..N-1
  9.   int idx = bid * blockDim.x + threadIdx.x;
  10.   const float epsilon = 1e-5f;
  11.   __shared__ float s_mean; // shared within block
  12.   __shared__ float s_variance; // shared within block
  13.   float value = (idx < N * K) ? x[idx] : 0.0f; // load once only
  14.   float sum = block_reduce_sum<NUM_THREADS>(value);
  15.   if (tid == 0) s_mean = sum / (float) K;
  16.   // wait for s_mean in shared memory to be ready for all threads
  17.   __syncthreads();
  18.   float variance = (value - s_mean) * (value - s_mean);
  19.   variance = block_reduce_sum<NUM_THREADS>(variance);
  20.   if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);
  21.   // wait for s_variance in shared memory to be ready for all threads
  22.   __syncthreads();
  23.   if (idx < N * K) y[idx] = ((value - s_mean) * s_variance) * g + b;
  24. }
  25. // Layer Norm Vec4: x: NxK(K=128<1024), y': NxK, y'=x-mean(x)/std(x) each row
  26. // mean(x) = sum(x)/K, 1/std(x) = rsqrtf( sum( (x-mean(x))^2 )/K ) each row
  27. // grid(N*K/K), block(K/4<1024) N=batch_size*seq_len, K=hidden_size
  28. // y=y'*g + b (g: scale, b: bias)
  29. template<const int NUM_THREADS=128/4>
  30. __global__ void layer_norm_vec4(float* x, float* y, float g, float b, int N, int K) {
  31.   int tid = threadIdx.x; // 0..K-1
  32.   int bid = blockIdx.x; // 0..N-1
  33.   int idx = (bid * blockDim.x + threadIdx.x) * 4;
  34.   const float epsilon = 1e-5f;
  35.   __shared__ float s_mean; // shared within block
  36.   __shared__ float s_variance; // shared within block
  37.   float4 reg_x = FLOAT4(x[idx])
  38.   float value = (idx < N * K) ? (reg_x.x + reg_x.y 
  39.                                + reg_x.z + reg_x.w) : 0.0f;
  40.   float sum = block_reduce_sum<NUM_THREADS>(value);
  41.   if (tid == 0) s_mean = sum / (float) K;
  42.   // wait for s_mean in shared memory to be ready for all threads
  43.   __syncthreads();
  44.   float4 reg_x_hat;
  45.   reg_x_hat.x = reg_x.x - s_mean;
  46.   reg_x_hat.y = reg_x.y - s_mean;
  47.   reg_x_hat.z = reg_x.z - s_mean;
  48.   reg_x_hat.w = reg_x.w - s_mean;
  49.   float variance = reg_x_hat.x * reg_x_hat.x + reg_x_hat.y * reg_x_hat.y 
  50.                  + reg_x_hat.z * reg_x_hat.z + reg_x_hat.w * reg_x_hat.w;
  51.   variance = block_reduce_sum<NUM_THREADS>(variance);
  52.   if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);
  53.   // wait for s_variance in shared memory to be ready for all threads
  54.   __syncthreads();
  55.   float4 reg_y;
  56.   reg_y.x = reg_x_hat.x * s_variance * g + b;
  57.   reg_y.y = reg_x_hat.y * s_variance * g + b;
  58.   reg_y.z = reg_x_hat.z * s_variance * g + b;
  59.   reg_y.w = reg_x_hat.w * s_variance * g + b;
  60.   if (idx < N * K) FLOAT4(y[idx]) = reg_y;
  61. }

layer norm实现的核心同样也是block reduce和warp reduce,然后再整点向量化...

0x0d rms_norm, rms_norm + vec4

  1. // RMS Norm: x: NxK(K=128<1024), y': NxK, y'=x/rms(x) each row
  2. // 1/rms(x) = rsqrtf( sum(x^2)/K ) each row
  3. // grid(N*K/K), block(K<1024) N=batch_size*seq_len, K=hidden_size
  4. // y=y'*g (g: scale)
  5. template<const int NUM_THREADS=128>
  6. __global__ void rms_norm(float* x, float* y, float g, int N, int K) {
  7.   int tid = threadIdx.x; // 0..K-1
  8.   int bid = blockIdx.x; // 0..N-1
  9.   int idx = bid * blockDim.x + threadIdx.x;
  10.   const float epsilon = 1e-5f;
  11.   __shared__ float s_variance; // shared within block
  12.   float value = (idx < N * K) ? x[idx] : 0.0f; // load once only
  13.   float variance = value * value;
  14.   variance = block_reduce_sum<NUM_THREADS>(variance);
  15.   if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);
  16.   // wait for s_variance in shared memory to be ready for all threads
  17.   __syncthreads(); 
  18.   if (idx < N * K) y[idx] = (value * s_variance) * g;
  19. }
  20. // RMS Norm Vec4: x: NxK(K=128<1024), y': NxK, y'=x/rms(x) each row
  21. // 1/rms(x) = rsqrtf( sum(x^2)/K ) each row
  22. // grid(N*K/K), block(K/4<1024) N=batch_size*seq_len, K=hidden_size
  23. // y=y'*g (g: scale)
  24. template<const int NUM_THREADS=128/4>
  25. __global__ void rms_norm_vec4(float* x, float* y, float g, int N, int K) {
  26.   int tid = threadIdx.x; // 0..K-1
  27.   int bid = blockIdx.x; // 0..N-1
  28.   int idx = (bid * blockDim.x + threadIdx.x) * 4;
  29.   const float epsilon = 1e-5f;
  30.   __shared__ float s_variance; // shared within block
  31.   float4 reg_x = FLOAT4(x[idx]);
  32.   float variance = (idx < N * K) ? (reg_x.x * reg_x.x + reg_x.y * reg_x.y 
  33.                                   + reg_x.z * reg_x.z + reg_x.w * reg_x.w) : 0.0f;
  34.   variance = block_reduce_sum<NUM_THREADS>(variance);
  35.   if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);
  36.   // wait for s_variance in shared memory to be ready for all threads
  37.   __syncthreads(); 
  38.   float4 reg_y;
  39.   reg_y.x = reg_x.x * s_variance * g;
  40.   reg_y.y = reg_x.y * s_variance * g;
  41.   reg_y.z = reg_x.z * s_variance * g;
  42.   reg_y.w = reg_x.w * s_variance * g;
  43.   if (idx < N * K) FLOAT4(y[idx]) = reg_y;
  44. }

rms norm实现的核心同样也是block reduce和warp reduce...,然后再加点float4向量化什么的。

0x0e NMS

  1. struct Box {
  2.   float x1, y1, x2, y2, score;
  3.   float area() const {return (std::abs(x2 - x1 + 1)) * std::abs(y2 - y1 + 1); }
  4.   float iou_of(const Box& other) const{
  5.     float inner_x1 = x1 > other.x1 ? x1 : other.x1;
  6.     float inner_y1 = y1 > other.y1 ? y1 : other.y1;
  7.     float inner_x2 = x2 < other.x2 ? x2 : other.x2;
  8.     float inner_y2 = y2 < other.y2 ? y2 : other.y2;
  9.     float inner_h = inner_y2 - inner_y1 + 1.0f;
  10.     float inner_w = inner_x2 - inner_x1 + 1.0f;
  11.     float inner_area = inner_h * inner_w;
  12.     return (inner_area / (area() + tbox.area() - inner_area));
  13.   }
  14. }
  15. void hard_nms(std::vector<Box> &input, std::vector<Box> &output, float iou_threshold){
  16.   if (input.empty()) return;
  17.   std::sort(input.begin(), input.end(),[](Box& a, Box& b) { return a.score > b.score; });
  18.   int box_num = input.size();
  19.   std::vector<int> merged(box_num, 0);
  20.   for (int i = 0; i < box_num; ++i) {
  21.     if (merged[i]) continue;
  22.     merged[i] = 1;
  23.     for (int j = i + 1; j < box_num; ++j) {
  24.       if (merged[j]) continue;
  25.       float iou = input[i].iou_of(input[j]);
  26.       if (iou > iou_threshold) merged[j] = 1;
  27.     }
  28.     output.push_back(input[i]);
  29.   }
  30. }

CV相关的经常会要手撕NMS,也记录下。

0x0f 总结

可以发现,大部分kernel的基本写法都是依赖warp reduce和block reduce的,基本上只要熟练应用warp functions各种场景的写法,应该问题不大;softmax需要考虑网格级同步的问题,或者online softmax以及FlashAttention;sgemm的优化是个很大的课题,不是案例中写的这么简单,但是入门的话,基本就是tiling的思想以及如何做索引之间的mapping;sgemv的优化则主要考虑K不同的值(因为M为1了),比如K=16,64,128等情况下,如何按照warp来处理;relu、sigmoid等都是elementwise的操作,很好实现,可以再考虑加点向量化优化访存;layer norm和rms norm在数学上其实也是挺清晰简单的,落实到cuda kernel时,只要按照逐个token来处理,headdim没有超过1024的情况下(一个block最多可以放1024个threads),可以放到一个block处理,这样并行化就很好写。当然,核心还是warp reduce和block reduce;NMS是乱入的,没有CUDA版本,别问了...

最后,附GitHub repo:

https://github.com/DefTruth/cuda-learn-note/
github.com/DefTruth/cuda-learn-note/

投稿作者为『自动驾驶之心知识星球』特邀嘉宾,欢迎加入交流!

① 全网独家视频课程

BEV感知、BEV模型部署、BEV目标跟踪、毫米波雷达视觉融合多传感器标定多传感器融合多模态3D目标检测车道线检测轨迹预测在线高精地图世界模型点云3D目标检测目标跟踪Occupancy、cuda与TensorRT模型部署大模型与自动驾驶Nerf语义分割自动驾驶仿真、传感器部署、决策规划、轨迹预测等多个方向学习视频(扫码即可学习

cfd0f1aa6bacd69d95600ec39a1596b5.png

网页端官网:www.zdjszx.com

② 国内首个自动驾驶学习社区

国内最大最专业,近3000人的交流社区,已得到大多数自动驾驶公司的认可!涉及30+自动驾驶技术栈学习路线,从0到一带你入门自动驾驶感知2D/3D检测、语义分割、车道线、BEV感知、Occupancy、多传感器融合、多传感器标定、目标跟踪)、自动驾驶定位建图SLAM、高精地图、局部在线地图)、自动驾驶规划控制/轨迹预测等领域技术方案大模型、端到端等,更有行业动态和岗位发布!欢迎扫描下方二维码,加入自动驾驶之心知识星球,这是一个真正有干货的地方,与领域大佬交流入门、学习、工作、跳槽上的各类难题,日常分享论文+代码+视频

d1fa20ccba4747611cf84fc42a09d851.png

③【自动驾驶之心】技术交流群

自动驾驶之心是首个自动驾驶开发者社区,聚焦感知、定位、融合、规控、标定、端到端、仿真、产品经理、自动驾驶开发、自动标注与数据闭环多个方向,目前近60+技术交流群,欢迎加入!扫码添加汽车人助理微信邀请入群,备注:学校/公司+方向+昵称(快速入群方式)

872e2814eaff4826eab7670bda2d098b.jpeg

④【自动驾驶之心】全平台矩阵

fd0b6f60b7c2485dd2ee77c2b2869ad1.png

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/喵喵爱编程/article/detail/856527
推荐阅读
相关标签
  

闽ICP备14008679号