train_gpt2_fp32.cu

Source code

llm.c/train_gpt2_fp32.cu at master · karpathy/llm.c (github.com)
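Before the full listing, one note on a pattern that recurs throughout the file: the hand-written kernel launchers (everything except the cuBLAS/cuBLASLt matmuls) size a 1-D grid with the CEIL_DIV macro, launch their kernel, and immediately check for launch errors via cudaCheck(cudaGetLastError()). The following is a minimal, self-contained sketch of that convention, not taken from the repository; the elementwise_scale kernel and its parameters are invented purely for illustration.

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

// same grid-sizing helper as in train_gpt2_fp32.cu
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

// same error-checking convention as in the file (function renamed here to avoid the
// macro/function name overlap the original relies on)
void cuda_check(cudaError_t error, const char *file, int line) {
    if (error != cudaSuccess) {
        printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, cudaGetErrorString(error));
        exit(EXIT_FAILURE);
    }
}
#define cudaCheck(err) (cuda_check(err, __FILE__, __LINE__))

// hypothetical elementwise kernel, only to show the launch pattern
__global__ void elementwise_scale(float* out, const float* inp, float s, int N) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) { out[i] = s * inp[i]; }   // thread guard, as in the real kernels
}

int main() {
    const int N = 1000;
    float *d_in, *d_out;
    cudaCheck(cudaMalloc((void**)&d_in,  N * sizeof(float)));
    cudaCheck(cudaMalloc((void**)&d_out, N * sizeof(float)));
    cudaCheck(cudaMemset(d_in, 0, N * sizeof(float)));

    const int block_size = 256;                      // typical block size used in the file
    const int grid_size = CEIL_DIV(N, block_size);   // enough blocks to cover all N elements
    elementwise_scale<<<grid_size, block_size>>>(d_out, d_in, 2.0f, N);
    cudaCheck(cudaGetLastError());                   // catch launch/configuration errors
    cudaCheck(cudaDeviceSynchronize());

    cudaCheck(cudaFree(d_in));
    cudaCheck(cudaFree(d_out));
    return 0;
}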

  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <math.h>
  4. #include <time.h>
  5. #include <assert.h>
  6. #include <float.h>
  7. #include <string.h>
  8. #include <unistd.h>
  9. #include <cublas_v2.h>
  10. #include <cuda_runtime.h>
  11. #include <cublasLt.h>
  12. #include <cooperative_groups.h>
  13. #include <cooperative_groups/reduce.h>
  14. #include "utils.h"
  15. #include "tokenizer.h"
  16. #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
  17. void cudaCheck(cudaError_t error, const char *file, int line) {
  18. if (error != cudaSuccess) {
  19. printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line,
  20. cudaGetErrorString(error));
  21. exit(EXIT_FAILURE);
  22. }
  23. };
  24. #define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))
  25. void cublasCheck(cublasStatus_t status, const char *file, int line)
  26. {
  27. if (status != CUBLAS_STATUS_SUCCESS) {
  28. printf("[cuBLAS ERROR]: %d %s %d\n", status, file, line);
  29. exit(EXIT_FAILURE);
  30. }
  31. }
  32. #define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }
  33. static size_t cublaslt_workspace_size = 32 * 1024 * 1024;
  34. static void* cublaslt_workspace = NULL;
  35. static cublasComputeType_t cublas_compute_type;
  36. cublasHandle_t cublas_handle;
  37. cublasLtHandle_t cublaslt_handle;
  38. namespace cg = cooperative_groups;
  39. __device__ inline float4 add_float4(const float4& a, const float4& b) {
  40. return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
  41. }
  42. __global__ void encoder_forward_kernel3(float4* out,
  43. const int* inp, const float4* wte, const float4* wpe,
  44. int B, int T, int C) {
  45. int C4 = C / 4;
  46. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  47. int N = B * T * C4;
  48. if (idx < N) {
  49. int bt = idx / C4;
  50. int b = bt / T;
  51. int t = bt % T;
  52. int c4 = idx % C4;
  53. int ix = inp[b * T + t];
  54. out[b * T * C4 + t * C4 + c4] = add_float4(wte[ix * C4 + c4], wpe[t * C4 + c4]);
  55. }
  56. }
  57. __global__ void encoder_backward_kernel(float* dwte, float* dwpe,
  58. const float* dout, const int* inp,
  59. int B, int T, int C) {
  60. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  61. int N = B * T * C;
  62. if (idx < N) {
  63. int bt = idx / C;
  64. int b = bt / T;
  65. int t = bt % T;
  66. int c = idx % C;
  67. int ix = inp[b * T + t];
  68. const float* dout_btc = dout + b * T * C + t * C + c;
  69. float* dwte_ix = dwte + ix * C + c;
  70. float* dwpe_tc = dwpe + t * C + c;
  71. atomicAdd(dwte_ix, *dout_btc);
  72. atomicAdd(dwpe_tc, *dout_btc);
  73. }
  74. }
  75. __global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
  76. const float* __restrict__ inp, const float* __restrict__ weight,
  77. const float* __restrict__ bias, int N, int C) {
  78. cg::thread_block block = cg::this_thread_block();
  79. cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
  80. int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
  81. if(idx >= N) {
  82. return;
  83. }
  84. const float* x = inp + idx * C;
  85. float sum = 0.0f;
  86. for (int i = warp.thread_rank(); i < C; i += warp.size()) {
  87. sum += x[i];
  88. }
  89. sum = cg::reduce(warp, sum, cg::plus<float>{});
  90. float m = sum / C;
  91. if(warp.thread_rank() == 0 && mean != nullptr) {
  92. __stcs(mean + idx, m);
  93. }
  94. sum = 0.0f;
  95. for (int i = warp.thread_rank(); i < C; i += warp.size()) {
  96. float diff = x[i] - m;
  97. sum += diff * diff;
  98. }
  99. sum = cg::reduce(warp, sum, cg::plus<float>{});
  100. float s = rsqrtf(sum / C + 1e-5f);
  101. if(warp.thread_rank() == 0 && rstd != nullptr) {
  102. __stcs(rstd + idx, s);
  103. }
  104. float* o = out + idx * C;
  105. for (int c = warp.thread_rank(); c < C; c += warp.size()) {
  106. float n = s * (__ldcs(x+c) - m);
  107. __stcs(o+c, n * weight[c] + bias[c]);
  108. }
  109. }
  110. __global__ void permute_kernel(float* q, float* k, float* v,
  111. const float* inp,
  112. int B, int N, int NH, int d) {
  113. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  114. if (idx < B * NH * N * d) {
  115. int b = idx / (NH * N * d);
  116. int rest = idx % (NH * N * d);
  117. int nh_ = rest / (N * d);
  118. rest = rest % (N * d);
  119. int n = rest / d;
  120. int d_ = rest % d;
  121. int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;
  122. q[idx] = __ldcs(&inp[inp_idx]);
  123. k[idx] = __ldcs(&inp[inp_idx + NH * d]);
  124. v[idx] = __ldcs(&inp[inp_idx + 2 * (NH * d)]);
  125. }
  126. }
  127. __global__ void permute_kernel_backward(float* dinp,
  128. const float* dq, const float* dk, const float* dv,
  129. int B, int N, int NH, int d) {
  130. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  131. if (idx < B * NH * N * d) {
  132. int b = idx / (NH * N * d);
  133. int rest = idx % (NH * N * d);
  134. int nh_ = rest / (N * d);
  135. rest = rest % (N * d);
  136. int n = rest / d;
  137. int d_ = rest % d;
  138. int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;
  139. dinp[inp_idx] = dq[idx];
  140. dinp[inp_idx + NH * d] = dk[idx];
  141. dinp[inp_idx + 2 * (NH * d)] = dv[idx];
  142. }
  143. }
  144. __global__ void unpermute_kernel(float* inp, float *out, int B, int N, int NH, int d) {
  145. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  146. if (idx < B * NH * N * d) {
  147. int b = idx / (NH * N * d);
  148. int rest = idx % (NH * N * d);
  149. int nh_ = rest / (N * d);
  150. rest = rest % (N * d);
  151. int n = rest / d;
  152. int d_ = rest % d;
  153. int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
  154. out[other_idx] = __ldcs(&inp[idx]);
  155. }
  156. }
  157. __global__ void unpermute_kernel_backward(float* dinp, const float *dout, int B, int N, int NH, int d) {
  158. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  159. if (idx < B * NH * N * d) {
  160. int b = idx / (NH * N * d);
  161. int rest = idx % (NH * N * d);
  162. int nh_ = rest / (N * d);
  163. rest = rest % (N * d);
  164. int n = rest / d;
  165. int d_ = rest % d;
  166. int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
  167. dinp[idx] = dout[other_idx];
  168. }
  169. }
  170. __device__ float& vec_at(float4& vec, int index) {
  171. return reinterpret_cast<float*>(&vec)[index];
  172. }
  173. __device__ float vec_at(const float4& vec, int index) {
  174. return reinterpret_cast<const float*>(&vec)[index];
  175. }
  176. __global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {
  177. assert(T % 4 == 0);
  178. cg::thread_block block = cg::this_thread_block();
  179. cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
  180. int idx = (gridDim.x - blockIdx.x - 1) * warp.meta_group_size() + warp.meta_group_rank();
  181. if(idx >= N * T) {
  182. return;
  183. }
  184. int own_pos = idx % T;
  185. int pos_by_4 = own_pos / 4;
  186. const float* x = inp + idx * T;
  187. float maxval = -FLT_MAX;
  188. float sumval = 0.0f;
  189. const float4* x_vec = reinterpret_cast<const float4*>(x);
  190. for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {
  191. float4 v = x_vec[i];
  192. float old_maxval = maxval;
  193. for(int k = 0; k < 4; ++k) {
  194. maxval = fmaxf(maxval, vec_at(v, k));
  195. }
  196. sumval *= expf(inv_temperature * (old_maxval - maxval));
  197. for(int k = 0; k < 4; ++k) {
  198. sumval += expf(inv_temperature * (vec_at(v, k) - maxval));
  199. }
  200. }
  201. if(4*pos_by_4 + warp.thread_rank() <= own_pos) {
  202. float old_maxval = maxval;
  203. maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]);
  204. sumval *= expf(inv_temperature * (old_maxval - maxval));
  205. sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval));
  206. }
  207. float global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});
  208. sumval *= expf(inv_temperature * (maxval - global_maxval));
  209. float sum = cg::reduce(warp, sumval, cg::plus<float>{});
  210. float norm = 1.f / sum;
  211. for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {
  212. float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval));
  213. __stcs(out + idx * T + i, ev * norm);
  214. }
  215. }
  216. __global__ void residual_forward_kernel(float* out, float* inp1, float* inp2, int N) {
  217. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  218. if (idx < N) {
  219. out[idx] = __ldcs(&inp1[idx]) + __ldcs(&inp2[idx]);
  220. }
  221. }
  222. #define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)
  223. __global__ void gelu_forward_kernel(float* out, const float* inp, int N) {
  224. int i = blockIdx.x * blockDim.x + threadIdx.x;
  225. if (i < N) {
  226. float xi = inp[i];
  227. float cube = 0.044715f * xi * xi * xi;
  228. out[i] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)));
  229. }
  230. }
  231. __global__ void gelu_backward_kernel(float* dinp, const float* inp, const float* dout, const int N) {
  232. int i = blockIdx.x * blockDim.x + threadIdx.x;
  233. if (i < N) {
  234. float x = inp[i];
  235. float cube = 0.044715f * x * x * x;
  236. float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
  237. float tanh_out = tanhf(tanh_arg);
  238. float coshf_out = coshf(tanh_arg);
  239. float sech_out = 1.0f / (coshf_out * coshf_out);
  240. float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
  241. dinp[i] = local_grad * dout[i];
  242. }
  243. }
  244. __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
  245. extern __shared__ float smem[];
  246. const int warp_id = threadIdx.x / warpSize;
  247. const int lane_id = threadIdx.x % warpSize;
  248. const int tl = blockIdx.x * warpSize;
  249. const int vstep = blockDim.x / warpSize;
  250. const float* dout_col = dout + tl + lane_id;
  251. float dout_sum = 0.0f;
  252. for (int row = warp_id; row < B * T; row += vstep) {
  253. dout_sum += dout_col[row * OC];
  254. }
  255. smem[lane_id + warp_id * warpSize] = dout_sum;
  256. __syncthreads();
  257. dout_sum = 0.0f;
  258. if (warp_id == 0) {
  259. for (int j = 0; j < vstep; j++) {
  260. dout_sum += smem[lane_id + j * warpSize];
  261. }
  262. dbias[tl + lane_id] += dout_sum;
  263. }
  264. }
  265. __global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* dbias,
  266. const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
  267. int B, int T, int C) {
  268. extern __shared__ float shared[];
  269. namespace cg = cooperative_groups;
  270. cg::thread_block block = cg::this_thread_block();
  271. cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
  272. int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
  273. int N = B * T;
  274. if(idx >= N) { return; } // thread guards
  275. int b = idx / T;
  276. int t = idx % T;
  277. const float* dout_bt = dout + b * T * C + t * C;
  278. const float* inp_bt = inp + b * T * C + t * C;
  279. float* dinp_bt = dinp + b * T * C + t * C;
  280. const float mean_bt = mean[b * T + t];
  281. const float rstd_bt = rstd[b * T + t];
  282. float* dbias_shared = shared;
  283. float* dweight_shared = shared + C;
  284. #pragma unroll
  285. for(int i = threadIdx.x; i < C; i+= blockDim.x){
  286. dbias_shared[i] = 0.0f;
  287. dweight_shared[i] = 0.0f;
  288. }
  289. __syncthreads();
  290. float dnorm_mean = 0.0f;
  291. float dnorm_norm_mean = 0.0f;
  292. for (int i = warp.thread_rank(); i < C; i += warp.size()) {
  293. float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
  294. float dnorm_i = weight[i] * dout_bt[i];
  295. dnorm_mean += dnorm_i;
  296. dnorm_norm_mean += dnorm_i * norm_bti;
  297. }
  298. dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
  299. dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
  300. dnorm_mean = dnorm_mean / C;
  301. dnorm_norm_mean = dnorm_norm_mean / C;
  302. for (int i = warp.thread_rank(); i < C; i += warp.size()) {
  303. float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
  304. float dnorm_i = weight[i] * dout_bt[i];
  305. atomicAdd(&dbias_shared[i], dout_bt[i]);
  306. atomicAdd(&dweight_shared[i], norm_bti * dout_bt[i]);
  307. float dval = 0.0f;
  308. dval += dnorm_i;
  309. dval -= dnorm_mean;
  310. dval -= norm_bti * dnorm_norm_mean;
  311. dval *= rstd_bt;
  312. dinp_bt[i] += dval;
  313. }
  314. __syncthreads();
  315. for(int i = threadIdx.x; i < C; i+= blockDim.x){
  316. atomicAdd(&dbias[i], dbias_shared[i]);
  317. atomicAdd(&dweight[i], dweight_shared[i]);
  318. }
  319. }
  320. __global__ void softmax_autoregressive_backward_kernel(float* dpreatt, const float* datt, const float* att,
  321. int B, int T, int C, float scale) {
  322. constexpr const int BlockSize = 256;
  323. constexpr int T_per_block = 4;
  324. cg::thread_block block = cg::this_thread_block();
  325. cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
  326. __shared__ float block_acc[32];
  327. int idx = blockIdx.y;
  328. int t0 = T - 1 - T_per_block*blockIdx.x;
  329. att += idx * T * T;
  330. datt += idx * T * T;
  331. dpreatt += idx * T * T;
  332. if (warp.meta_group_rank() == 0) {
  333. block_acc[warp.thread_rank()] = 0;
  334. }
  335. for(int to = 0; to < T_per_block; ++to) {
  336. int t = t0 - to;
  337. if(t < 0) return;
  338. const float* att_bth = att + t * T;
  339. const float* datt_bth = datt + t * T;
  340. float* dpreatt_bth = dpreatt + t * T;
  341. float local_sum = 0;
  342. for (int t2 = block.thread_rank(); t2 <= t; t2 += BlockSize) {
  343. local_sum += att_bth[t2] * datt_bth[t2];
  344. }
  345. block_acc[warp.meta_group_rank()] = cg::reduce(warp, local_sum, cg::plus<float>{});
  346. block.sync();
  347. local_sum = cg::reduce(warp, block_acc[warp.thread_rank()], cg::plus<float>{});
  348. for (int t3 = block.thread_rank(); t3 <= t; t3 += BlockSize) {
  349. float acc = __ldcs(att_bth + t3) * (__ldcs(datt_bth + t3) - local_sum);
  350. __stcs(dpreatt_bth + t3, scale * acc);
  351. }
  352. }
  353. }
  354. __device__ inline float lerp(float start, float end, float weight) {
  355. return fma(weight, end, fma(-weight, start, start));
  356. }
  357. __global__ void adamw_kernel2(float* params_memory, float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
  358. float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
  359. int i = blockIdx.x * blockDim.x + threadIdx.x;
  360. if (i >= num_parameters) return;
  361. float grad = grads_memory[i];
  362. float m = m_memory[i];
  363. float v = v_memory[i];
  364. // update the first moment (momentum)
  365. m = lerp(grad, m, beta1);
  366. m_memory[i] = m;
  367. // update the second moment (RMSprop)
  368. v = lerp(grad * grad, v, beta2);
  369. v_memory[i] = v;
  370. m /= beta1_correction;
  371. v /= beta2_correction;
  372. params_memory[i] -= learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);
  373. }
  374. struct SoftmaxParams {
  375. float Scale;
  376. float Offset;
  377. };
  378. __device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_tile<32>& warp,
  379. int idx, const float* inp, int V, int P) {
  380. const float* x = inp + idx * P;
  381. float thread_maxval = -INFINITY;
  382. float thread_sumval = 0.0f;
  383. for (int i = V + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {
  384. float v = x[i];
  385. float old_maxval = thread_maxval;
  386. thread_maxval = fmaxf(thread_maxval, v);
  387. thread_sumval *= expf((old_maxval - thread_maxval));
  388. thread_sumval += expf(v - thread_maxval);
  389. }
  390. __shared__ float shared_maxval[32];
  391. __shared__ float shared_sumval[32];
  392. int num_warps = blockDim.x / 32;
  393. int warp_id = threadIdx.x / 32;
  394. int lane_id = threadIdx.x % 32;
  395. float warp_maxval = cg::reduce(warp, thread_maxval, cg::greater<float>{});
  396. if (lane_id == 0) { shared_maxval[warp_id] = warp_maxval; }
  397. __syncthreads();
  398. warp_maxval = (lane_id < num_warps) ? shared_maxval[lane_id] : -FLT_MAX;
  399. float block_maxval = cg::reduce(warp, warp_maxval, cg::greater<float>{});
  400. thread_sumval *= expf(thread_maxval - block_maxval);
  401. float warp_sumval = cg::reduce(warp, thread_sumval, cg::plus<float>{});
  402. if (lane_id == 0) { shared_sumval[warp_id] = warp_sumval; }
  403. __syncthreads();
  404. warp_sumval = (lane_id < num_warps) ? shared_sumval[lane_id] : 0.0f;
  405. float block_sumval = cg::reduce(warp, warp_sumval, cg::plus<float>{});
  406. return SoftmaxParams{1.f / block_sumval, block_maxval};
  407. }
  408. __global__ void fused_classifier_kernel3(float* logits, float* losses, float* probs,
  409. const float* dlosses, const int* targets,
  410. int B, int T, int V, int P) {
  411. namespace cg = cooperative_groups;
  412. cg::thread_block block = cg::this_thread_block();
  413. cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
  414. int idx = blockIdx.x;
  415. int ix = targets[idx];
  416. SoftmaxParams sp = prepare_softmax_blockwide_nofloat4(warp, idx, logits, V, P);
  417. if(threadIdx.x == 0) {
  418. float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale;
  419. losses[idx] = -logf(prob);
  420. }
  421. float dloss = dlosses != NULL ? dlosses[idx] : 1.0f / (B*T);
  422. const float* logits_vec = logits + idx * P;
  423. for (int i = threadIdx.x; i < V; i += blockDim.x) {
  424. // this is the 2nd read of logits after the one in prepare_softmax2
  425. // this data will never be needed again, so we reduce cache persistence
  426. float v = __ldcs(&logits_vec[i]);
  427. float prob = expf(v - sp.Offset) * sp.Scale;
  428. if (probs != NULL) {
  429. probs[idx * P + i] = prob;
  430. }
  431. float indicator = (i == ix) ? 1.0f : 0.0f;
  432. logits[idx * P + i] = (prob - indicator) * dloss;
  433. }
  434. }
  435. void encoder_forward(float* out,
  436. const int* inp, const float* wte, const float* wpe,
  437. int B, int T, int C) {
  438. assert(C % 4 == 0);
  439. const int block_size = 512;
  440. const int N = B * T * C;
  441. const int grid_size = CEIL_DIV(N / 4, block_size);
  442. encoder_forward_kernel3<<<grid_size, block_size>>>((float4*) out, inp, (float4*) wte, (float4*) wpe, B, T, C);
  443. cudaCheck(cudaGetLastError());
  444. }
  445. void encoder_backward(float* dwte, float* dwpe,
  446. const float* dout, const int* inp,
  447. int B, int T, int C) {
  448. const int N = B * T * C;
  449. const int block_size = 256;
  450. const int grid_size = CEIL_DIV(N, block_size);
  451. encoder_backward_kernel<<<grid_size, block_size>>>(dwte, dwpe, dout, inp, B, T, C);
  452. cudaCheck(cudaGetLastError());
  453. }
  454. void layernorm_forward(float* out, float* mean, float* rstd,
  455. float* inp, float* weight, float* bias,
  456. int B, int T, int C) {
  457. const int block_size = 512;
  458. const int N = B * T;
  459. const int grid_size = CEIL_DIV(N * 32, block_size);
  460. layernorm_forward_kernel3<<<grid_size, block_size>>>(out, mean, rstd, inp, weight, bias, N, C);
  461. cudaCheck(cudaGetLastError());
  462. }
  463. void matmul_forward_cublaslt(float* out,
  464. float* inp, float* weight, float* bias,
  465. int B, int T, int C, int OC) {
  466. int has_bias = (bias != NULL);
  467. if(((uintptr_t)bias % 16) != 0) {
  468. printf("Bias pointer is not aligned (cuBLASLt requirement)!\n");
  469. exit(EXIT_FAILURE);
  470. }
  471. int returnedResults = 0;
  472. cublasLtMatmulDesc_t operationDesc;
  473. cublasLtMatmulPreference_t preference;
  474. cublasLtMatrixLayout_t weightLayout;
  475. cublasLtMatrixLayout_t inputLayout;
  476. cublasLtMatrixLayout_t outputLayout;
  477. cublasLtMatrixLayout_t biasLayout;
  478. cublasLtMatmulHeuristicResult_t heuristic;
  479. cublasOperation_t opNoTranspose = CUBLAS_OP_N;
  480. cublasOperation_t opTranspose = CUBLAS_OP_T;
  481. cublasLtEpilogue_t epilogueBias = CUBLASLT_EPILOGUE_BIAS;
  482. cublasCheck(cublasLtMatmulDescCreate(&operationDesc, cublas_compute_type, CUDA_R_32F));
  483. cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTranspose, sizeof(opTranspose)));
  484. cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opNoTranspose, sizeof(opNoTranspose)));
  485. if(has_bias) {
  486. cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogueBias,
  487. sizeof(epilogueBias)));
  488. }
  489. cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
  490. cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUDA_R_32F, C, OC, C));
  491. cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUDA_R_32F, C, B*T, C));
  492. cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout, CUDA_R_32F, OC, B*T, OC));
  493. cublasCheck(cublasLtMatrixLayoutCreate(&biasLayout, CUDA_R_32F, OC, 1, OC));
  494. cublasCheck(cublasLtMatmulPreferenceCreate(&preference));
  495. cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference,
  496. CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
  497. &cublaslt_workspace_size, sizeof(cublaslt_workspace_size)));
  498. cublasCheck(cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc,
  499. weightLayout, inputLayout, outputLayout, outputLayout,
  500. preference, 1, &heuristic, &returnedResults));
  501. if (returnedResults == 0) {
  502. printf("No cuBLASLt algorithm: B: %d, T: %d, C: %d, OC: %d, bias: %d\n", B, T, C, OC, has_bias);
  503. exit(EXIT_FAILURE);
  504. }
  505. const float alpha = 1.0f, beta = 0.0f;
  506. cublasCheck(cublasLtMatmul(cublaslt_handle, operationDesc,
  507. &alpha, weight, weightLayout, inp, inputLayout, &beta,
  508. out, outputLayout, out, outputLayout, &heuristic.algo,
  509. cublaslt_workspace, cublaslt_workspace_size, 0));
  510. cublasCheck(cublasLtMatmulPreferenceDestroy(preference));
  511. cublasCheck(cublasLtMatmulDescDestroy(operationDesc));
  512. cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout));
  513. cublasCheck(cublasLtMatrixLayoutDestroy(inputLayout));
  514. cublasCheck(cublasLtMatrixLayoutDestroy(outputLayout));
  515. cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout));
  516. }
  517. void attention_forward(float* out, float* qkvr, float* att,
  518. float* inp,
  519. int B, int T, int C, int NH) {
  520. const int block_size = 256;
  521. const int softmax_block_size = 256;
  522. int HS = C / NH; // head size
  523. float *q, *k, *v;
  524. q = qkvr + 0 * B * T * C;
  525. k = qkvr + 1 * B * T * C;
  526. v = qkvr + 2 * B * T * C;
  527. int total_threads = B * NH * T * HS;
  528. int num_blocks = CEIL_DIV(total_threads, block_size);
  529. permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);
  530. cudaCheck(cudaGetLastError());
  531. const float alpha = 1.0f;
  532. const float beta = 0.0f;
  533. float* preatt = inp;
  534. cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, &alpha, k, HS, T * HS, q, HS, T * HS, &beta, preatt, T, T * T, B * NH));
  535. float scale = 1.0 / sqrtf(HS);
  536. int grid_size = CEIL_DIV(B * NH * T * 32, softmax_block_size);
  537. softmax_forward_kernel5<<<grid_size, softmax_block_size>>>(att, scale, preatt, B * NH, T);
  538. cudaCheck(cudaGetLastError());
  539. float* vaccum = inp;
  540. cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &alpha, v, HS, T * HS, att, T, T * T, &beta, vaccum, HS, T * HS, B * NH));
  541. num_blocks = CEIL_DIV(B * T * C, block_size);
  542. unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);
  543. cudaCheck(cudaGetLastError());
  544. }
  545. void residual_forward(float* out, float* inp1, float* inp2, int N) {
  546. const int block_size = 256;
  547. const int grid_size = CEIL_DIV(N, block_size);
  548. residual_forward_kernel<<<grid_size, block_size>>>(out, inp1, inp2, N);
  549. cudaCheck(cudaGetLastError());
  550. }
  551. void gelu_forward(float* out, const float* inp, int N) {
  552. const int block_size = 128;
  553. const int grid_size = CEIL_DIV(N, block_size);
  554. gelu_forward_kernel<<<grid_size, block_size>>>(out, inp, N);
  555. cudaCheck(cudaGetLastError());
  556. }
  557. void gelu_backward(float* dinp, const float* inp, const float* dout, const int N) {
  558. const int block_size = 128;
  559. const int grid_size = CEIL_DIV(N, block_size);
  560. gelu_backward_kernel<<<grid_size, block_size>>>(dinp, inp, dout, N);
  561. cudaCheck(cudaGetLastError());
  562. }
  563. void matmul_backward(float* dinp, float* dweight, float* dbias,
  564. float* dout, float* inp, float* weight,
  565. int B, int T, int C, int OC) {
  566. float one = 1.0f;
  567. float zero = 0.0f;
  568. cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, C, B*T, OC, &one, weight, C, dout, OC, &zero, dinp, C));
  569. cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &one, inp, C, dout, OC, &one, dweight, C));
  570. if (dbias != NULL) {
  571. const int block_size = 1024;
  572. const int grid_size = OC / 32;
  573. matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
  574. cudaCheck(cudaGetLastError());
  575. }
  576. }
  577. void layernorm_backward(float* dinp, float* dweight, float* dbias,
  578. const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
  579. int B, int T, int C) {
  580. const int block_size = 512;
  581. const int N = B * T;
  582. const int grid_size = CEIL_DIV(32*N, block_size);
  583. size_t shared_mem_size = 2 * C * sizeof(float);
  584. layernorm_backward_kernel2<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);
  585. cudaCheck(cudaGetLastError());
  586. }
  587. void attention_backward(float* dinp, float* dqkvr, float* dpreatt, float* datt, float* scratch,
  588. const float* dout,
  589. const float* qkvr, const float* att,
  590. int B, int T, int C, int NH) {
  591. const int block_size = 256;
  592. int HS = C / NH; // head size
  593. const float one = 1.0f;
  594. const float zero = 0.0f; // note beta = 1.0f so that we accumulate gradients (+=)
  595. const float *q, *k, *v;
  596. q = qkvr + 0 * B * T * C;
  597. k = qkvr + 1 * B * T * C;
  598. v = qkvr + 2 * B * T * C;
  599. float *dq, *dk, *dv;
  600. dq = dqkvr + 0 * B * T * C;
  601. dk = dqkvr + 1 * B * T * C;
  602. dv = dqkvr + 2 * B * T * C;
  603. int num_blocks = CEIL_DIV(B * T * C, block_size);
  604. unpermute_kernel_backward<<<num_blocks, block_size>>>(scratch, dout, B, T, NH, HS);
  605. cudaCheck(cudaGetLastError());
  606. cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, T, T, HS, &one, v, HS, T * HS, scratch, HS, T * HS, &zero, datt, T, T * T, B * NH));
  607. cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, &one, scratch, HS, T * HS, att, T, T * T, &zero, dv, HS, T * HS, B * NH));
  608. int hs = C / NH; // head size
  609. float scale = 1.0f / sqrtf(hs);
  610. softmax_autoregressive_backward_kernel<<<dim3(T / 4, B * NH), 256>>>(dpreatt, datt, att, B, T, C, scale);
  611. cudaCheck(cudaGetLastError());
  612. cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, HS, T, T, &one, k, HS, T * HS, dpreatt, T, T * T, &zero, dq, HS, T * HS, B * NH));
  613. cublasCheck(cublasSgemmStridedBatched(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, HS, T, T, &one, q, HS, T * HS, dpreatt, T, T * T, &zero, dk, HS, T * HS, B * NH));
  614. num_blocks = CEIL_DIV(B * NH * T * HS, block_size);
  615. permute_kernel_backward<<<num_blocks, block_size>>>(dinp, dq, dk, dv, B, T, NH, HS);
  616. cudaCheck(cudaGetLastError());
  617. }
  618. void fused_classifier3(float* logits, float* losses,
  619. const float* dlosses, const int* targets,
  620. int B, int T, int V, int P) {
  621. const int block_size = 1024;
  622. const int N = B * T;
  623. const int grid_size = N;
  624. fused_classifier_kernel3<<<grid_size, block_size>>>(logits, losses, NULL, dlosses, targets, B, T, V, P);
  625. cudaCheck(cudaGetLastError());
  626. }
  627. typedef struct {
  628. int max_seq_len;
  629. int vocab_size;
  630. int padded_vocab_size;
  631. int num_layers;
  632. int num_heads;
  633. int channels;
  634. } GPT2Config;
  635. #define NUM_PARAMETER_TENSORS 16
  636. typedef struct {
  637. float* wte;
  638. float* wpe;
  639. float* ln1w;
  640. float* ln1b;
  641. float* qkvw;
  642. float* qkvb;
  643. float* attprojw;
  644. float* attprojb;
  645. float* ln2w;
  646. float* ln2b;
  647. float* fcw;
  648. float* fcb;
  649. float* fcprojw;
  650. float* fcprojb;
  651. float* lnfw;
  652. float* lnfb;
  653. } ParameterTensors;
  654. void fill_in_parameter_sizes(size_t* param_sizes, GPT2Config config) {
  655. int Vp = config.padded_vocab_size;
  656. int C = config.channels;
  657. int maxT = config.max_seq_len;
  658. int L = config.num_layers;
  659. param_sizes[0] = Vp * C;
  660. param_sizes[1] = maxT * C;
  661. param_sizes[2] = L * C;
  662. param_sizes[3] = L * C;
  663. param_sizes[4] = L * (3 * C) * C;
  664. param_sizes[5] = L * (3 * C);
  665. param_sizes[6] = L * C * C;
  666. param_sizes[7] = L * C;
  667. param_sizes[8] = L * C;
  668. param_sizes[9] = L * C;
  669. param_sizes[10] = L * (4 * C) * C;
  670. param_sizes[11] = L * (4 * C);
  671. param_sizes[12] = L * C * (4 * C);
  672. param_sizes[13] = L * C;
  673. param_sizes[14] = C;
  674. param_sizes[15] = C;
  675. }
  676. float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes, int on_device) {
  677. size_t num_parameters = 0;
  678. for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
  679. num_parameters += param_sizes[i];
  680. }
  681. float* params_memory;
  682. if (on_device) {
  683. cudaCheck(cudaMalloc((void**)&params_memory, num_parameters * sizeof(float)));
  684. } else {
  685. params_memory = (float*)mallocCheck(num_parameters * sizeof(float));
  686. }
  687. float** ptrs[] = {
  688. &params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,
  689. &params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,
  690. &params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb
  691. };
  692. float* params_memory_iterator = params_memory;
  693. for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
  694. *(ptrs[i]) = params_memory_iterator;
  695. params_memory_iterator += param_sizes[i];
  696. }
  697. return params_memory;
  698. }
  699. #define NUM_ACTIVATION_TENSORS 21
  700. typedef struct {
  701. float* encoded;
  702. float* ln1;
  703. float* ln1_mean;
  704. float* ln1_rstd;
  705. float* atty;
  706. float* att;
  707. float* attproj;
  708. float* residual2;
  709. float* ln2;
  710. float* ln2_mean;
  711. float* ln2_rstd;
  712. float* fch;
  713. float* fch_gelu;
  714. float* fcproj;
  715. float* residual3;
  716. float* lnf;
  717. float* lnf_mean;
  718. float* lnf_rstd;
  719. float* losses;
  720. float* qkvr;
  721. float* output;
  722. } ActivationTensors;
  723. void fill_in_activation_sizes(size_t* act_sizes, int B, int T, GPT2Config config) {
  724. size_t Vp = config.padded_vocab_size;
  725. size_t L = config.num_layers;
  726. size_t NH = config.num_heads;
  727. size_t C = config.channels;
  728. act_sizes[0] = B * T * C;
  729. act_sizes[1] = L * B * T * C;
  730. act_sizes[2] = L * B * T;
  731. act_sizes[3] = L * B * T;
  732. act_sizes[4] = L * B * T * C;
  733. act_sizes[5] = L * B * NH * T * T;
  734. act_sizes[6] = L * B * T * C;
  735. act_sizes[7] = L * B * T * C;
  736. act_sizes[8] = L * B * T * C;
  737. act_sizes[9] = L * B * T;
  738. act_sizes[10] = L * B * T;
  739. act_sizes[11] = L * B * T * 4*C;
  740. act_sizes[12] = L * B * T * 4*C;
  741. act_sizes[13] = L * B * T * C;
  742. act_sizes[14] = L * B * T * C;
  743. act_sizes[15] = B * T * C;
  744. act_sizes[16] = B * T;
  745. act_sizes[17] = B * T;
  746. act_sizes[18] = B * T;
  747. act_sizes[19] = L * B * T * 3*C; // qkvr
  748. act_sizes[20] = B * T * max(3*C, max(NH*T, Vp)); // output / scratch
  749. }
  750. #define NUM_BACKWARD_TENSORS 3
  751. typedef struct {
  752. float* bt4c;
  753. float* preatt;
  754. float* residual3;
  755. } GradActTensors;
  756. void fill_in_grad_act_sizes(size_t* act_sizes, int B, int T, GPT2Config config) {
  757. size_t NH = config.num_heads;
  758. size_t C = config.channels;
  759. act_sizes[0] = B * T * 4 * C;
  760. act_sizes[1] = B * NH * T * T;
  761. act_sizes[2] = B * T * C;
  762. }
  763. float* malloc_and_point(float** targets[], const size_t* act_sizes, int n) {
  764. size_t num_activations = 0;
  765. for (size_t i = 0; i < n; i++) {
  766. num_activations += act_sizes[i];
  767. }
  768. float* acts_memory;
  769. cudaCheck(cudaMalloc((void**)&acts_memory, num_activations * sizeof(float)));
  770. float* acts_memory_iterator = acts_memory;
  771. for (size_t i = 0; i < n; i++) {
  772. *(targets[i]) = acts_memory_iterator;
  773. acts_memory_iterator += act_sizes[i];
  774. }
  775. return acts_memory;
  776. }
  777. float* malloc_and_point_activations(ActivationTensors* acts, const size_t* act_sizes) {
  778. float** ptrs[] = {
  779. &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->atty,
  780. &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean,
  781. &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf,
  782. &acts->lnf_mean, &acts->lnf_rstd, &acts->losses, &acts->qkvr, &acts->output
  783. };
  784. return malloc_and_point(ptrs, act_sizes, NUM_ACTIVATION_TENSORS);
  785. }
  786. float* malloc_and_point_backward(GradActTensors* acts, const size_t* act_sizes) {
  787. float** ptrs[] = {
  788. &acts->bt4c, &acts->preatt, &acts->residual3
  789. };
  790. return malloc_and_point(ptrs, act_sizes, NUM_BACKWARD_TENSORS);
  791. }
  792. typedef struct {
  793. GPT2Config config;
  794. ParameterTensors params;
  795. size_t param_sizes[NUM_PARAMETER_TENSORS];
  796. float* params_memory;
  797. size_t num_parameters;
  798. ParameterTensors grads;
  799. float* grads_memory;
  800. float* m_memory;
  801. float* v_memory;
  802. ActivationTensors acts;
  803. size_t act_sizes[NUM_ACTIVATION_TENSORS];
  804. float* acts_memory;
  805. size_t num_activations;
  806. GradActTensors grads_acts;
  807. size_t num_grad_acts;
  808. float* grads_acts_memory;
  809. int batch_size;
  810. int seq_len;
  811. int* inputs;
  812. int* targets;
  813. float mean_loss;
  814. float* cpu_losses;
  815. } GPT2;
  816. void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
  817. FILE *model_file = fopenCheck(checkpoint_path, "rb");
  818. int model_header[256];
  819. freadCheck(model_header, sizeof(int), 256, model_file);
  820. if (model_header[0] != 20240326) { fprintf(stderr, "Bad magic model file\n"); exit(EXIT_FAILURE); }
  821. if (model_header[1] != 3) {
  822. // was bumped from 1 -> 3 to incorporate the padded vocab size
  823. fprintf(stderr, "Bad version in model file\n");
  824. fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n");
  825. exit(EXIT_FAILURE);
  826. }
  827. model->config.max_seq_len = model_header[2];
  828. model->config.vocab_size = model_header[3];
  829. model->config.num_layers = model_header[4];
  830. model->config.num_heads = model_header[5];
  831. model->config.channels = model_header[6];
  832. model->config.padded_vocab_size = model_header[7];
  833. fill_in_parameter_sizes(model->param_sizes, model->config);
  834. size_t num_parameters = 0;
  835. for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
  836. num_parameters += model->param_sizes[i];
  837. }
  838. model->num_parameters = num_parameters;
  839. model->params_memory = malloc_and_point_parameters(&model->params, model->param_sizes, 1);
  840. float* params_memory_cpu = (float*)mallocCheck(num_parameters * sizeof(float));
  841. freadCheck(params_memory_cpu, sizeof(float), num_parameters, model_file);
  842. cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, num_parameters * sizeof(float), cudaMemcpyHostToDevice));
  843. free(params_memory_cpu);
  844. fcloseCheck(model_file);
  845. model->acts_memory = NULL;
  846. model->grads_memory = NULL;
  847. model->m_memory = NULL;
  848. model->v_memory = NULL;
  849. model->grads_acts_memory = NULL;
  850. model->inputs = NULL;
  851. model->targets = NULL;
  852. model->cpu_losses = NULL;
  853. model->batch_size = 0;
  854. model->seq_len = 0;
  855. model->mean_loss = -1.0f; // -1.0f will designate no loss
  856. }
  857. void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {
  858. if (model->params_memory == NULL) {
  859. printf("Error: model was not initialized properly.\n");
  860. exit(EXIT_FAILURE);
  861. }
  862. int V = model->config.vocab_size;
  863. int Vp = model->config.padded_vocab_size;
  864. int L = model->config.num_layers;
  865. int NH = model->config.num_heads;
  866. int C = model->config.channels;
  867. for(int i = 0; i < B * T; i++) {
  868. assert(0 <= inputs[i] && inputs[i] < V);
  869. if (targets != NULL) {
  870. assert(0 <= targets[i] && targets[i] < V);
  871. }
  872. }
  873. if(model->acts_memory == NULL) {
  874. model->batch_size = B;
  875. model->seq_len = T;
  876. fill_in_activation_sizes(model->act_sizes, B, T, model->config);
  877. size_t num_activations = 0;
  878. for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
  879. num_activations += model->act_sizes[i];
  880. }
  881. model->num_activations = num_activations;
  882. model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes);
  883. printf("allocated %zu MiB for activations\n", (num_activations * sizeof(float)) >> 20);
  884. cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int)));
  885. cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int)));
  886. cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float)));
  887. } else {
  888. if (B != model->batch_size || T != model->seq_len) {
  889. printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, B, T);
  890. exit(EXIT_FAILURE);
  891. }
  892. }
  893. cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice));
  894. if (targets != NULL) {
  895. cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));
  896. }
  897. ParameterTensors params = model->params;
  898. ActivationTensors acts = model->acts;
  899. float* residual;
  900. encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C);
  901. for (int l = 0; l < L; l++) {
  902. residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;
  903. float* l_ln1w = params.ln1w + l * C;
  904. float* l_ln1b = params.ln1b + l * C;
  905. float* l_qkvw = params.qkvw + l * 3*C * C;
  906. float* l_qkvb = params.qkvb + l * 3*C;
  907. float* l_attprojw = params.attprojw + l * C * C;
  908. float* l_attprojb = params.attprojb + l * C;
  909. float* l_ln2w = params.ln2w + l * C;
  910. float* l_ln2b = params.ln2b + l * C;
  911. float* l_fcw = params.fcw + l * 4*C * C;
  912. float* l_fcb = params.fcb + l * 4*C;
  913. float* l_fcprojw = params.fcprojw + l * C * 4*C;
  914. float* l_fcprojb = params.fcprojb + l * C;
  915. float* l_ln1 = acts.ln1 + l * B * T * C;
  916. float* l_ln1_mean = acts.ln1_mean + l * B * T;
  917. float* l_ln1_rstd = acts.ln1_rstd + l * B * T;
  918. float* l_qkvr = acts.qkvr + l * B * T * 3*C;
  919. float* l_atty = acts.atty + l * B * T * C;
  920. float* l_att = acts.att + l * B * NH * T * T;
  921. float* l_attproj = acts.attproj + l * B * T * C;
  922. float* l_residual2 = acts.residual2 + l * B * T * C;
  923. float* l_ln2 = acts.ln2 + l * B * T * C;
  924. float* l_ln2_mean = acts.ln2_mean + l * B * T;
  925. float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
  926. float* l_fch = acts.fch + l * B * T * 4*C;
  927. float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;
  928. float* l_fcproj = acts.fcproj + l * B * T * C;
  929. float* l_residual3 = acts.residual3 + l * B * T * C;
  930. float* scratch = acts.output;
  931. layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C);
  932. matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);
  933. attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH);
  934. matmul_forward_cublaslt(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C);
  935. residual_forward(l_residual2, residual, l_attproj, B*T*C);
  936. layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
  937. matmul_forward_cublaslt(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);
  938. gelu_forward(l_fch_gelu, l_fch, B*T*4*C);
  939. matmul_forward_cublaslt(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);
  940. residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C);
  941. }
  942. residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
  943. layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);
  944. matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp);
  945. if (targets != NULL) {
  946. fused_classifier3(acts.output, acts.losses, NULL, model->targets, B, T, V, Vp);
  947. cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost));
  948. float mean_loss = 0.0f;
  949. for (int i=0; i<B*T; i++) { mean_loss += model->cpu_losses[i]; }
  950. mean_loss /= B*T;
  951. model->mean_loss = mean_loss;
  952. } else {
  953. model->mean_loss = -1.0f;
  954. }
  955. }
  956. void gpt2_zero_grad(GPT2 *model) {
  957. if (model->grads_acts_memory != NULL) { cudaCheck(cudaMemset(model->grads_acts_memory, 0, model->num_grad_acts * sizeof(float))); }
  958. if (model->grads_memory != NULL) { cudaCheck(cudaMemset(model->grads_memory, 0, model->num_parameters * sizeof(float))); }
  959. }
  960. void gpt2_backward(GPT2 *model) {
  961. if (model->mean_loss == -1.0f) {
  962. printf("Error: must forward with targets before backward\n");
  963. exit(EXIT_FAILURE);
  964. }
  965. if (model->grads_memory == NULL) {
  966. model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_sizes, 1);
  967. printf("allocated %zu MiB for parameter gradients\n", (model->num_parameters * sizeof(float)) >> 20);
  968. size_t bw_act_sizes[NUM_ACTIVATION_TENSORS];
  969. GPT2Config cfg = model->config;
  970. cfg.num_layers = 1; // copy the configuration but override number of layers to 1
  971. fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, cfg);
  972. model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes);
  973. model->num_grad_acts = 0;
  974. for (int i = 0; i < NUM_BACKWARD_TENSORS; i++) {
  975. model->num_grad_acts += bw_act_sizes[i];
  976. }
  977. printf("allocated %zu MiB for activation gradients\n", (model->num_grad_acts * sizeof(float)) >> 20);
  978. gpt2_zero_grad(model);
  979. }
  980. int B = model->batch_size;
  981. int T = model->seq_len;
  982. int Vp = model->config.padded_vocab_size;
  983. int L = model->config.num_layers;
  984. int NH = model->config.num_heads;
  985. int C = model->config.channels;
  986. ParameterTensors params = model->params;
  987. ParameterTensors grads = model->grads;
  988. ActivationTensors acts = model->acts;
  989. GradActTensors grads_acts = model->grads_acts;
  990. matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, B, T, C, Vp);
  991. float* residual = acts.residual3 + (L-1) * B * T * C;
  992. float* dresidual = grads_acts.residual3;
  993. layernorm_backward(dresidual, grads.lnfw, grads.lnfb, grads_acts.bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);
  994. for (int l = L-1; l >= 0; l--) {
  995. residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;
  996. float* l_ln1w = params.ln1w + l * C;
  997. float* l_qkvw = params.qkvw + l * 3*C * C;
  998. float* l_attprojw = params.attprojw + l * C * C;
  999. float* l_ln2w = params.ln2w + l * C;
  1000. float* l_fcw = params.fcw + l * 4*C * C;
  1001. float* l_fcprojw = params.fcprojw + l * C * 4*C;
  1002. float* dl_ln1w = grads.ln1w + l * C;
  1003. float* dl_ln1b = grads.ln1b + l * C;
  1004. float* dl_qkvw = grads.qkvw + l * 3*C * C;
  1005. float* dl_qkvb = grads.qkvb + l * 3*C;
  1006. float* dl_attprojw = grads.attprojw + l * C * C;
  1007. float* dl_attprojb = grads.attprojb + l * C;
  1008. float* dl_ln2w = grads.ln2w + l * C;
  1009. float* dl_ln2b = grads.ln2b + l * C;
  1010. float* dl_fcw = grads.fcw + l * 4*C * C;
  1011. float* dl_fcb = grads.fcb + l * 4*C;
  1012. float* dl_fcprojw = grads.fcprojw + l * C * 4*C;
  1013. float* dl_fcprojb = grads.fcprojb + l * C;
  1014. float* l_ln1 = acts.ln1 + l * B * T * C;
  1015. float* l_ln1_mean = acts.ln1_mean + l * B * T;
  1016. float* l_ln1_rstd = acts.ln1_rstd + l * B * T;
  1017. float* l_qkvr = acts.qkvr + l * B * T * 3*C;
  1018. float* l_atty = acts.atty + l * B * T * C;
  1019. float* l_att = acts.att + l * B * NH * T * T;
  1020. float* l_residual2 = acts.residual2 + l * B * T * C;
  1021. float* l_ln2 = acts.ln2 + l * B * T * C;
  1022. float* l_ln2_mean = acts.ln2_mean + l * B * T;
  1023. float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
  1024. float* l_fch = acts.fch + l * B * T * 4*C;
  1025. float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;
  1026. float* dl_btc = acts.lnf;
  1027. float* dl_bt4c = grads_acts.bt4c;
  1028. float* dl_preatt = grads_acts.preatt;
  1029. float* scratch = acts.output;
  1030. matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, B, T, 4*C, C);
  1031. gelu_backward(dl_bt4c, l_fch, dl_bt4c, B*T*4*C);
  1032. matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, B, T, C, 4 * C);
  1033. layernorm_backward(dresidual, dl_ln2w, dl_ln2b, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C);
  1034. matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, B, T, C, C);
  1035. float* buffer_a = l_atty;
  1036. float* buffer_b = l_fch;
  1037. attention_backward(dl_bt4c, buffer_b, dl_preatt, scratch, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH);
  1038. matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, B, T, C, 3 * C);
  1039. layernorm_backward(dresidual, dl_ln1w, dl_ln1b, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);
  1040. }
  1041. encoder_backward(grads.wte, grads.wpe, dresidual, model->inputs, B, T, C);
  1042. }
  1043. void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {
  1044. if (model->m_memory == NULL) {
  1045. cudaCheck(cudaMalloc((void**)&model->m_memory, model->num_parameters * sizeof(float)));
  1046. cudaCheck(cudaMalloc((void**)&model->v_memory, model->num_parameters * sizeof(float)));
  1047. cudaCheck(cudaMemset(model->m_memory, 0, model->num_parameters * sizeof(float)));
  1048. cudaCheck(cudaMemset(model->v_memory, 0, model->num_parameters * sizeof(float)));
  1049. printf("allocated %zu MiB for AdamW optimizer state m\n", (model->num_parameters * sizeof(float)) >> 20);
  1050. printf("allocated %zu MiB for AdamW optimizer state v\n", (model->num_parameters * sizeof(float)) >> 20);
  1051. }
  1052. int block_size = 512;
  1053. int num_blocks = CEIL_DIV(model->num_parameters, block_size);
  1054. float beta1_correction = 1.0f - powf(beta1, t);
  1055. float beta2_correction = 1.0f - powf(beta2, t);
  1056. adamw_kernel2<<<num_blocks, block_size>>>(model->params_memory, model->grads_memory, model->m_memory, model->v_memory,
  1057. model->num_parameters,
  1058. learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
  1059. cudaCheck(cudaGetLastError());
  1060. }
  1061. void gpt2_free(GPT2 *model) {
  1062. cudaCheck(cudaFree(model->params_memory));
  1063. cudaCheck(cudaFree(model->grads_memory));
  1064. cudaCheck(cudaFree(model->m_memory));
  1065. cudaCheck(cudaFree(model->v_memory));
  1066. cudaCheck(cudaFree(model->acts_memory));
  1067. cudaCheck(cudaFree(model->grads_acts_memory));
  1068. cudaCheck(cudaFree(model->inputs));
  1069. cudaCheck(cudaFree(model->targets));
  1070. cudaFreeHost(model->cpu_losses);
  1071. }
  1072. #ifndef TESTING
  1073. typedef struct {
  1074. int B;
  1075. int T;
  1076. FILE* tokens_file;
  1077. long file_size;
  1078. long current_position;
  1079. // output memory
  1080. int* batch;
  1081. int* inputs;
  1082. int* targets;
  1083. long num_batches;
  1084. } DataLoader;
  1085. void dataloader_init(DataLoader *loader, const char* filename, int B, int T) {
  1086. loader->B = B;
  1087. loader->T = T;
  1088. loader->tokens_file = fopenCheck(filename, "rb");
  1089. fseekCheck(loader->tokens_file, 0, SEEK_END);
  1090. loader->file_size = ftell(loader->tokens_file);
  1091. fseekCheck(loader->tokens_file, 0, SEEK_SET);
  1092. if (loader->file_size < (B * T + 1) * sizeof(int)) {
  1093. printf("Error: file size is too small for the batch size and sequence length\n");
  1094. exit(EXIT_FAILURE);
  1095. }
  1096. loader->current_position = 0;
  1097. cudaMallocHost((void**)&loader->batch, (B * T + 1) * sizeof(int));
  1098. loader->inputs = loader->batch;
  1099. loader->targets = loader->batch + 1;
  1100. loader->num_batches = loader->file_size / (B * T * sizeof(int));
  1101. }
  1102. void dataloader_reset(DataLoader *loader) {
  1103. loader->current_position = 0;
  1104. }
  1105. void dataloader_next_batch(DataLoader *loader) {
  1106. int B = loader->B;
  1107. int T = loader->T;
  1108. if (loader->current_position + (B*T+1) * sizeof(int) > loader->file_size) {
  1109. loader->current_position = 0;
  1110. }
  1111. fseekCheck(loader->tokens_file, loader->current_position, SEEK_SET);
  1112. freadCheck(loader->batch, sizeof(int), B*T+1, loader->tokens_file);
  1113. loader->current_position += B*T * sizeof(int);
  1114. }
  1115. void dataloader_free(DataLoader *loader) {
  1116. fcloseCheck(loader->tokens_file);
  1117. cudaFreeHost(loader->batch);
  1118. }
  1119. #define GPT2_EOT 50256
  1120. unsigned int random_u32(unsigned long long *state) {
  1121. *state ^= *state >> 12;
  1122. *state ^= *state << 25;
  1123. *state ^= *state >> 27;
  1124. return (*state * 0x2545F4914F6CDD1Dull) >> 32;
  1125. }
  1126. float random_f32(unsigned long long *state) {
  1127. return (random_u32(state) >> 8) / 16777216.0f;
  1128. }
  1129. int sample_softmax(const float* logits, int n, float coin) {
  1130. double norm = 0;
  1131. for (int i = 0; i < n; i++) {
  1132. norm += expf(logits[i]);
  1133. }
  1134. coin *= norm;
  1135. float cdf = 0.0f;
  1136. for (int i = 0; i < n; i++) {
  1137. cdf += expf(logits[i]);
  1138. if (coin < cdf) {
  1139. return i;
  1140. }
  1141. }
  1142. return n - 1;
  1143. }
  1144. typedef struct {
  1145. FILE *logfile;
  1146. int flush_every; // every how many steps to flush the log
  1147. } Logger;
  1148. void logger_init(Logger *logger, const char *filename) {
  1149. logger->flush_every = 20;
  1150. logger->logfile = NULL;
  1151. if (filename != NULL) { logger->logfile = fopenCheck(filename, "w"); }
  1152. }
  1153. void logger_log_val(Logger *logger, int step, float val_loss) {
  1154. if (logger->logfile != NULL) {
  1155. fprintf(logger->logfile, "s:%d tel:%.4f\n", step, val_loss);
  1156. }
  1157. }
  1158. void logger_log_train(Logger *logger, int step, float train_loss) {
  1159. if (logger->logfile != NULL) {
  1160. fprintf(logger->logfile, "s:%d trl:%.4f\n", step, train_loss);
  1161. if (step % 10 == 0) { fflush(logger->logfile); }
  1162. }
  1163. }
  1164. void logger_free(Logger *logger) {
  1165. if (logger->logfile != NULL) { fclose(logger->logfile); }
  1166. }
  1167. void error_usage() {
  1168. fprintf(stderr, "Usage: ./train_gpt2fp32cu [options]\n");
  1169. fprintf(stderr, "Example: ./train_gpt2fp32cu -i data/TinyStories -v 100 -s 100 -g 144 -o stories.log\n");
  1170. fprintf(stderr, "Options:\n");
  1171. fprintf(stderr, " -i <string> input dataset prefix (default = data/tiny_shakespeare)\n");
  1172. fprintf(stderr, " -o <string> output log file (default = NULL)\n");
  1173. fprintf(stderr, " -b <int> batch size B (default = 4)\n");
  1174. fprintf(stderr, " -t <int> sequence length T (default = 1024)\n");
  1175. fprintf(stderr, " -l <float> learning rate (default = 3e-4f)\n");
    fprintf(stderr, "  -v <int>    val_loss_every, how often we evaluate val loss (default = 20)\n");
    fprintf(stderr, "  -m <int>    val_max_batches, up to how many val batches to estimate val loss? (default = 20)\n");
    fprintf(stderr, "  -s <int>    sample_every, how often we inference the model (default = 20)\n");
    fprintf(stderr, "  -g <int>    genT, how many steps of inference we do (default = 64)\n");
    exit(EXIT_FAILURE);
}

int main(int argc, char *argv[]) {

    // read in the (optional) command line arguments
    const char* input_dataset_prefix = "data/tiny_shakespeare";
    const char* output_log_file = NULL;
    int B = 4;
    int T = 1024;
    float learning_rate = 3e-4f;
    int val_loss_every = 20;
    int val_max_batches = 20;
    int sample_every = 20;
    int genT = 64;
    for (int i = 1; i < argc; i+=2) {
        if (i + 1 >= argc) { error_usage(); } // must have arg after flag
        if (argv[i][0] != '-') { error_usage(); } // must start with dash
        if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
        // read in the args
        if (argv[i][1] == 'i') { input_dataset_prefix = argv[i+1]; }
        else if (argv[i][1] == 'o') { output_log_file = argv[i+1]; }
        else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); }
        else if (argv[i][1] == 't') { T = atoi(argv[i+1]); }
        else if (argv[i][1] == 'l') { learning_rate = atof(argv[i+1]); }
        else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); }
        else if (argv[i][1] == 'm') { val_max_batches = atoi(argv[i+1]); }
        else if (argv[i][1] == 's') { sample_every = atoi(argv[i+1]); }
        else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); }
        else { error_usage(); }
    }
    printf("+-----------------------+----------------------------------------------------+\n");
    printf("| Parameter             | Value                                              |\n");
    printf("+-----------------------+----------------------------------------------------+\n");
    printf("| input dataset prefix  | %-50s |\n", input_dataset_prefix);
    printf("| output log file       | %-50s |\n", output_log_file == NULL ? "NULL" : output_log_file);
    printf("| batch size B          | %-50d |\n", B);
    printf("| sequence length T     | %-50d |\n", T);
    printf("| learning rate         | %-50f |\n", learning_rate);
    printf("| val_loss_every        | %-50d |\n", val_loss_every);
    printf("| val_max_batches       | %-50d |\n", val_max_batches);
    printf("| sample_every          | %-50d |\n", sample_every);
    printf("| genT                  | %-50d |\n", genT);
    printf("+-----------------------+----------------------------------------------------+\n");

    // set up the device and the cuBLAS / cuBLASLt handles
    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));
    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, deviceIdx);
    cublasCheck(cublasCreate(&cublas_handle));
    cublasCheck(cublasLtCreate(&cublaslt_handle));
    // enable TF32 matmuls on GPUs with compute capability >= 8.0
    int enable_tf32 = deviceProp.major >= 8 ? 1 : 0;
    cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;
    cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
    cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));
    cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size));
    printf("| device                | %-50s |\n", deviceProp.name);
    printf("| TF32                  | %-50s |\n", enable_tf32 ? "enabled" : "disabled");
    printf("+-----------------------+----------------------------------------------------+\n");

    // build the GPT-2 model from a checkpoint
    GPT2 model;
    gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");
    printf("| max_sequence_length T | %-50d |\n", model.config.max_seq_len);
    printf("| vocab_size V          | %-50d |\n", model.config.vocab_size);
    printf("| padded_vocab_size Vp  | %-50d |\n", model.config.padded_vocab_size);
    printf("| num_layers L          | %-50d |\n", model.config.num_layers);
    printf("| num_heads NH          | %-50d |\n", model.config.num_heads);
    printf("| channels C            | %-50d |\n", model.config.channels);
    printf("| num_parameters        | %-50zu |\n", model.num_parameters);
    printf("+-----------------------+----------------------------------------------------+\n");

    // build DataLoaders for the train and val token files
    char train_tokens_filename[128];
    char val_tokens_filename[128];
    assert(strlen(input_dataset_prefix) < 100); // being bit lazy here, make sure we don't overflow
    sprintf(train_tokens_filename, "%s_train.bin", input_dataset_prefix);
    sprintf(val_tokens_filename, "%s_val.bin", input_dataset_prefix);
    DataLoader train_loader;
    dataloader_init(&train_loader, train_tokens_filename, B, T);
    DataLoader val_loader;
    dataloader_init(&val_loader, val_tokens_filename, B, T);
    int train_num_batches = train_loader.num_batches; // let's do 1 epoch by default for now
    int val_num_batches = train_loader.num_batches < val_max_batches ? train_loader.num_batches : val_max_batches;
    printf("| train_num_batches     | %-50d |\n", train_num_batches);
    printf("| val_num_batches       | %-50d |\n", val_num_batches);
    printf("+-----------------------+----------------------------------------------------+\n");
    printf("allocated %d MiB for model parameters\n", (int)round(model.num_parameters * sizeof(float) / (1024 * 1024)));

    // set up the Logger and the Tokenizer
    Logger logger;
    logger_init(&logger, output_log_file);
    Tokenizer tokenizer;
    tokenizer_init(&tokenizer, "gpt2_tokenizer.bin");

    // some memory for generating samples from the model
    unsigned long long rng_state = 1337;
    int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));
    float* cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float));

    // the training loop
    struct timespec start, end;
    double total_sum_iteration_time_s = 0.0;
    for (int step = 0; step <= train_num_batches; step++) {
        int last_step = step == train_num_batches;

        // once in a while estimate the validation loss
        if (step % val_loss_every == 0 || last_step) {
            float val_loss = 0.0f;
            dataloader_reset(&val_loader);
            for (int i = 0; i < val_num_batches; i++) {
                dataloader_next_batch(&val_loader);
                gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T);
                val_loss += model.mean_loss;
            }
            val_loss /= val_num_batches;
            printf("val loss %f\n", val_loss);
            logger_log_val(&logger, step, val_loss);
        }

        // once in a while do model inference to print generated text
        if (step > 0 && step % sample_every == 0 || last_step) {
            // fill gen_tokens with the GPT2_EOT token, which kicks off the generation
            for(int i = 0; i < B * T; ++i) {
                gen_tokens[i] = GPT2_EOT;
            }
            printf("generating:\n---\n");
            for (int t = 1; t < genT; t++) {
                // note: this is wasteful, the forward pass runs over the full (B, T)
                // batch but only the logits at position t-1 are used for sampling
                gpt2_forward(&model, gen_tokens, NULL, B, T);
                float* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size;
                cudaCheck(cudaMemcpy(cpu_logits, logits, model.config.vocab_size * sizeof(float), cudaMemcpyDeviceToHost));
                float coin = random_f32(&rng_state);
                int next_token = sample_softmax(cpu_logits, model.config.vocab_size, coin);
                gen_tokens[t] = next_token;
                if (tokenizer.init_ok) {
                    const char* token_str = tokenizer_decode(&tokenizer, next_token);
                    safe_printf(token_str);
                } else {
                    printf("%d ", next_token);
                }
                fflush(stdout);
            }
            printf("\n---\n");
        }

        // the extra last_step iteration exists only to eval/sample one final time,
        // so break out before running another training step
        if (last_step) { break; }

        // do a training step
        clock_gettime(CLOCK_MONOTONIC, &start);
        dataloader_next_batch(&train_loader);
        gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T);
        gpt2_zero_grad(&model);
        gpt2_backward(&model);
        gpt2_update(&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
        cudaCheck(cudaDeviceSynchronize()); // wait for all GPU work so the timing is accurate
        clock_gettime(CLOCK_MONOTONIC, &end);
        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
        total_sum_iteration_time_s += time_elapsed_s;
        int tokens_per_second = (B * T) / time_elapsed_s;
        printf("step %4d/%d: train loss %f (%f ms, %d tok/s)\n", step + 1, train_num_batches, model.mean_loss, time_elapsed_s * 1000, tokens_per_second);
        logger_log_train(&logger, step, model.mean_loss);
    }
    printf("total average iteration time: %f ms\n", total_sum_iteration_time_s / train_num_batches * 1000);

    // free all allocated resources
    dataloader_free(&train_loader);
    dataloader_free(&val_loader);
    tokenizer_free(&tokenizer);
    gpt2_free(&model);
    free(cpu_logits);
    free(gen_tokens);
    cudaCheck(cudaFree(cublaslt_workspace));
    cublasCheck(cublasDestroy(cublas_handle));
    cublasCheck(cublasLtDestroy(cublaslt_handle));
    logger_free(&logger);
    return 0;
}
#endif

Commentary

This program is a complete C/CUDA example of training a GPT-2 model, using NVIDIA's CUDA runtime together with the cuBLAS and cuBLASLt libraries for GPU acceleration. GPT-2 is a Transformer-based language model commonly used for text generation. The file covers loading the pretrained weights from a checkpoint, data loading, the forward and backward passes, parameter updates, periodic validation, and text sampling. Its main parts are:

  1. Headers and libraries: the file includes the C standard library headers for I/O, math, timing, assertions, floating-point limits, string handling and UNIX system calls, plus the CUDA runtime API, cuBLAS/cuBLASLt, and the cooperative groups library, which is used for efficient warp-level parallel reductions inside the kernels.

  2. Macros: CEIL_DIV computes a ceiling division (used to size kernel grids); cudaCheck and cublasCheck wrap CUDA and cuBLAS calls, print the failing file and line, and exit on error.

  3. CUDA and cuBLAS setup: main() selects the device, creates the cublas_handle and cublaslt_handle, allocates the cuBLASLt workspace, and enables TF32 math on GPUs with compute capability 8.0 or higher.

  4. Device functions: small helpers such as add_float4 run on the GPU and perform vectorized arithmetic.

  5. CUDA kernels: kernels such as encoder_forward_kernel3, encoder_backward_kernel and layernorm_forward_kernel3 implement the forward and backward computations of the individual layers.

  6. Layer-level forward/backward functions: host wrappers such as encoder_forward, encoder_backward, layernorm_forward and matmul_forward_cublaslt launch the forward-pass work, while matmul_backward, layernorm_backward and related functions implement the backward pass.

  7. Attention: the forward and backward passes of multi-head attention, the core of the Transformer architecture, are implemented as well.

  8. Optimizer: the adamw_kernel2 kernel implements the AdamW update that gpt2_update applies to every parameter (a standalone sketch of the update rule is given after this list).

  9. Model construction: gpt2_build_from_checkpoint restores the model configuration and parameters from a checkpoint file, here gpt2_124M.bin.

  10. Forward pass: gpt2_forward runs the full model on a batch and, when targets are provided, computes the mean loss.

  11. Backward pass and parameter update: gpt2_backward computes the gradients; gpt2_update applies the AdamW step.

  12. Data loading: the DataLoader struct and its functions stream (B, T) batches of token ids from the train and val files (a simplified sketch follows this list).

  13. Logging: the Logger struct and its functions record the training and validation losses to the optional log file.

  14. main(): the entry point parses the command-line flags, prints the configuration, initializes the model, data loaders, tokenizer and logger, and then runs the training loop: load a batch, forward pass, zero the gradients, backward pass, parameter update, with periodic validation-loss estimates and generated text samples.

  15. Error handling and argument parsing: error_usage() prints the accepted flags, and the parsing loop in main() validates each -x flag/value pair so the user can customize the training parameters.

  16. Random number generation: random_f32 provides the uniform "coin" used for sampling during text generation (a sketch of this sampling pattern also follows the list).

  17. Tokenizer: the Tokenizer maps between text and the integer token ids the model operates on; during generation, tokenizer_decode turns each sampled id back into a printable string.

  18. Performance measurement: clock_gettime(CLOCK_MONOTONIC, ...) times each training step, and the program reports milliseconds per step and tokens per second.

  19. Resource cleanup: after training, the data loaders, tokenizer, model, host buffers, cuBLASLt workspace, cuBLAS/cuBLASLt handles and the logger are all freed.
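
To make the data-loading item concrete, here is a minimal, self-contained sketch of what a loader like this typically does, under the assumption that the token file is a flat array of int32 token ids with no header; the MiniLoader name and field layout are illustrative, not the exact DataLoader from the listing above. Inputs are B*T consecutive tokens and targets are the same window shifted right by one position.

#include <stdio.h>
#include <stdlib.h>

// minimal sketch of a token-stream loader (hypothetical, for illustration only)
typedef struct {
    FILE* file;
    long file_size;
    long current_position;
    int B, T;
    int* batch;     // holds B*T + 1 tokens
    int* inputs;    // points at batch[0]
    int* targets;   // points at batch[1], i.e. inputs shifted by one token
    int num_batches;
} MiniLoader;

void miniloader_init(MiniLoader* l, const char* path, int B, int T) {
    l->B = B; l->T = T;
    l->file = fopen(path, "rb");
    if (l->file == NULL) { fprintf(stderr, "could not open %s\n", path); exit(1); }
    fseek(l->file, 0, SEEK_END);
    l->file_size = ftell(l->file);
    fseek(l->file, 0, SEEK_SET);
    l->current_position = 0;
    l->num_batches = (int)(l->file_size / (B * T * sizeof(int)));
    l->batch = (int*)malloc((B * T + 1) * sizeof(int));
    l->inputs = l->batch;
    l->targets = l->batch + 1;
}

void miniloader_next_batch(MiniLoader* l) {
    int B = l->B, T = l->T;
    // wrap around if the next read would run past the end of the file
    if (l->current_position + (B * T + 1) * (long)sizeof(int) > l->file_size) {
        l->current_position = 0;
    }
    fseek(l->file, l->current_position, SEEK_SET);
    if (fread(l->batch, sizeof(int), B * T + 1, l->file) != (size_t)(B * T + 1)) {
        fprintf(stderr, "unexpected end of token file\n"); exit(1);
    }
    l->current_position += B * T * (long)sizeof(int);
}

A caller would run miniloader_init(&loader, "data/tiny_shakespeare_train.bin", B, T) once and then call miniloader_next_batch(&loader) before each forward pass, which mirrors how the training loop above uses dataloader_init and dataloader_next_batch.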
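
For the optimizer item: gpt2_update in the training loop passes beta1 = 0.9, beta2 = 0.999, eps = 1e-8, weight decay 0.0 and the 1-based step count step+1. Those numbers feed the standard AdamW rule with bias correction; the kernel below is a minimal standalone sketch of that rule (adamw_update_sketch is a hypothetical name, not the adamw_kernel2 kernel from earlier in the file).

// Minimal sketch of an AdamW update kernel. beta1_corr and beta2_corr are the
// bias-correction terms 1 - beta1^t and 1 - beta2^t, computed on the host for
// the current 1-based step t.
__global__ void adamw_update_sketch(float* params, const float* grads,
                                    float* m, float* v, long num_parameters,
                                    float lr, float beta1, float beta2,
                                    float beta1_corr, float beta2_corr,
                                    float eps, float weight_decay) {
    long i = (long)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= num_parameters) { return; }
    float g = grads[i];
    // exponential moving averages of the gradient and the squared gradient
    float m_new = beta1 * m[i] + (1.0f - beta1) * g;
    float v_new = beta2 * v[i] + (1.0f - beta2) * g * g;
    m[i] = m_new;
    v[i] = v_new;
    // bias-corrected moment estimates
    float m_hat = m_new / beta1_corr;
    float v_hat = v_new / beta2_corr;
    // Adam step plus decoupled weight decay
    params[i] -= lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * params[i]);
}

// Host-side launch sketch, mirroring the hyperparameters gpt2_update receives in main():
//   float beta1_corr = 1.0f - powf(0.9f,   step + 1);
//   float beta2_corr = 1.0f - powf(0.999f, step + 1);
//   int block_size = 512;
//   int num_blocks = CEIL_DIV(num_parameters, block_size);
//   adamw_update_sketch<<<num_blocks, block_size>>>(params, grads, m_memory, v_memory,
//       num_parameters, learning_rate, 0.9f, 0.999f, beta1_corr, beta2_corr, 1e-8f, 0.0f);

Because the weight decay passed here is 0.0f, the decay term drops out and the update reduces to plain Adam.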
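
For the random-number item: during generation the loop draws coin = random_f32(&rng_state) and passes it to sample_softmax together with the vocab_size logits copied back to the host. The sketch below shows the usual pattern behind such helpers: an xorshift-style generator producing floats in [0, 1), and inverse-CDF sampling over the softmax of the logits. The _sketch names and the max-subtraction detail are illustrative rather than a claim about the exact helpers in the file.

#include <math.h>

// xorshift*-style generator producing 32-bit values, then floats in [0, 1)
static unsigned int random_u32_sketch(unsigned long long* state) {
    *state ^= *state >> 12;
    *state ^= *state << 25;
    *state ^= *state >> 27;
    return (unsigned int)((*state * 0x2545F4914F6CDD1DULL) >> 32);
}
static float random_f32_sketch(unsigned long long* state) {
    return (random_u32_sketch(state) >> 8) / 16777216.0f;
}

// sample an index from softmax(logits[0..n)) using one uniform coin in [0, 1)
static int sample_softmax_sketch(const float* logits, int n, float coin) {
    // softmax normalizer with max subtraction for numerical stability
    double max_logit = logits[0];
    for (int i = 1; i < n; i++) { if (logits[i] > max_logit) max_logit = logits[i]; }
    double sum = 0.0;
    for (int i = 0; i < n; i++) { sum += exp(logits[i] - max_logit); }
    // walk the cumulative distribution until it passes coin * sum
    double threshold = coin * sum;
    double cdf = 0.0;
    for (int i = 0; i < n; i++) {
        cdf += exp(logits[i] - max_logit);
        if (cdf >= threshold) { return i; }
    }
    return n - 1; // guard against floating-point round-off
}

With coin uniform in [0, 1), the returned index is distributed according to softmax(logits), which is exactly what the generation loop needs to pick each next token.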

Taken together, this file is a complete training-loop implementation — data loading, model construction, training, validation and text generation — and a clear, well-structured example of how to train a Transformer language model on the GPU using CUDA and cuBLAS.

