当前位置:   article > 正文

一个简洁的 cublasSgetrfBatched 和 cublasSgetriBatched 示例_cublas getrfbatched例子

cublas getrfbatched例子

先搭建一个能够调用 cuBLAS 函数的环境,然后在其中编译并运行以下代码:

  1. //#include "device_launch_parameters.h"
  2. #include <cublas_v2.h>
  3. #include <cuda_runtime.h>
  4. #include <stdio.h>
  5. #include <stdlib.h>
  6. #include <string.h>
  /**
   * Prints an m-by-n matrix to stdout, one matrix row per output line.
   *
   * Element (i, j) is read from a_matrix[i + j * lda], i.e. the matrix is
   * addressed in column-major order with leading dimension lda.
   *
   * @param matrixName Label printed before the matrix; a NULL label is printed
   *                   as "(null)". Taken as const so string literals can be
   *                   passed legally from C++.
   * @param m          Number of rows to print.
   * @param n          Number of columns to print.
   * @param a_matrix   Host pointer to the matrix data (only read).
   * @param lda        Leading dimension (stride between consecutive columns).
   * @return Always 0.
   */
  int printMatrix(const char* matrixName, int m, int n, const float* a_matrix, int lda) {
    printf("\n%s =\n", matrixName != NULL ? matrixName : "(null)");
    for (int i = 0; i < m; i++) {
      printf("\n");
      for (int j = 0; j < n; j++) {
        // Widen before multiplying so large matrices do not overflow int.
        printf(" %8.5f ", a_matrix[(size_t)i + (size_t)j * (size_t)lda]);
      }
    }
    printf("\n");
    return 0;
  }
  /**
   * Prints a label on its own line, then `size` integers from Vector on a
   * single line, each surrounded by spaces, followed by a newline.
   *
   * @param log    Label printed before the values.
   * @param Vector Host pointer to the integers to print.
   * @param size   Number of integers to print.
   */
  void printVector(const char* log, int* Vector, int size) {
    printf("\n%s\n", log);
    int idx = 0;
    while (idx < size) {
      printf(" %d ", Vector[idx]);
      ++idx;
    }
    printf("\n");
  }
  /**
   * Fills a dim-by-dim matrix with pseudo-random values.
   *
   * The generator is re-seeded with 2022 + seed on every call, so the same
   * seed always produces the same matrix. Each element is
   * ((rand() % 100) - 50) / 100 + 1, which lies between 0.50 and 1.49.
   * Progress messages are written to stdout before and after the fill.
   *
   * @param A    Host buffer with room for dim * dim floats.
   * @param dim  Matrix order.
   * @param seed Offset added to the base seed 2022.
   */
  void initMatrix(float* A, int dim, int seed) {
    srand(2022 + seed);
    printf("Starting Init Matrix ...\n");
    const int count = dim * dim;
    float* cursor = A;
    for (int filled = 0; filled < count; ++filled) {
      const int centered = (rand() % 100) - 50;
      // The scaling is done in double precision and narrowed on the store.
      *cursor++ = (float)centered / 100.0 + 1.0;
      // Debug aid: printf("A[%d]=%f ", filled, A[filled]);
    }
    printf("Init is OK.\n");
  }
  /**
   * Reference matrix product C = A * B computed on the host in single
   * precision.
   *
   * All three matrices are addressed in column-major order: element (r, c) of
   * a matrix X with leading dimension ldx is X[r + c * ldx]. Every C(i, j) is
   * first set to zero and then accumulated in place:
   *   C(i, j) = sum over k = 0 .. K-1 of A(i, k) * B(k, j)
   *
   * @param M   Rows of A and C.
   * @param N   Columns of B and C.
   * @param K   Columns of A / rows of B.
   * @param A   M-by-K input matrix.
   * @param lda Leading dimension of A.
   * @param B   K-by-N input matrix.
   * @param ldb Leading dimension of B.
   * @param C   M-by-N output matrix (overwritten).
   * @param ldc Leading dimension of C.
   */
  void do_dgemm(int M, int N, int K, float* A, int lda, float* B, int ldb, float* C, int ldc) {
    for (int row = 0; row < M; ++row) {
      for (int col = 0; col < N; ++col) {
        float* const cell = &C[row + col * ldc];
        *cell = 0.0;
        int k = 0;
        while (k < K) {
          *cell += A[row + k * lda] * B[k + col * ldb];
          ++k;
        }
      }
    }
  }
  /* Returns 1 when a CUDA runtime call succeeded; otherwise logs the failure to stderr and returns 0. */
  static int invCheckCuda(cudaError_t status, const char* what) {
    if (status == cudaSuccess) {
      return 1;
    }
    fprintf(stderr, "d_Get3DInv: %s failed: %s\n", what, cudaGetErrorString(status));
    return 0;
  }

  /* Returns 1 when a cuBLAS call succeeded; otherwise logs the status code to stderr and returns 0. */
  static int invCheckCublas(cublasStatus_t status, const char* what) {
    if (status == CUBLAS_STATUS_SUCCESS) {
      return 1;
    }
    fprintf(stderr, "d_Get3DInv: %s failed with cuBLAS status %d\n", what, (int)status);
    return 0;
  }

  /* Warns on stderr about every batch entry whose per-matrix info value is non-zero. */
  static void invReportInfo(const char* stage, const int* info, int batchedSize) {
    for (int i = 0; i < batchedSize; i++) {
      if (info[i] != 0) {
        fprintf(stderr,
                "d_Get3DInv: %s reported info = %d for matrix %d; its result is not valid\n",
                stage, info[i], i);
      }
    }
  }

  /**
   * Inverts a batch of n-by-n single-precision matrices on the GPU.
   *
   * The batch is copied to the device, factorized with cublasSgetrfBatched and
   * inverted with cublasSgetriBatched. The pivot vectors and the LU factors of
   * every matrix are printed to stdout between the two steps.
   *
   * Every CUDA and cuBLAS call is checked. The per-matrix info values written
   * by cuBLAS are inspected after each step; a non-zero value produces a
   * warning on stderr but does not abort the remaining matrices.
   *
   * @param A           Host pointer to batchedSize contiguous n*n matrices.
   * @param n           Order of each matrix (must be > 0).
   * @param batchedSize Number of matrices (must be > 0).
   * @return A malloc'd host buffer with batchedSize contiguous n*n inverses
   *         that the caller must free(), or NULL if the arguments are invalid
   *         or any allocation, CUDA call or cuBLAS call failed.
   */
  float* d_Get3DInv(float* A, int n, int batchedSize)
  {
    if (A == NULL || n <= 0 || batchedSize <= 0) {
      fprintf(stderr, "d_Get3DInv: invalid arguments\n");
      return NULL;
    }

    // Compute all sizes in size_t so large batches do not overflow int.
    const size_t matrixElems = (size_t)n * (size_t)n;
    const size_t size_data_A = (size_t)batchedSize * matrixElems * sizeof(float);
    const size_t size_ptrs = (size_t)batchedSize * sizeof(float*);
    const size_t size_pivots = (size_t)batchedSize * (size_t)n * sizeof(int);
    const size_t size_info = (size_t)batchedSize * sizeof(int);

    cublasHandle_t cu_cublasHandle = NULL;
    int handleCreated = 0;
    float** AarrayDev = NULL;
    float** CarrayDev = NULL;
    float* A_data_dev = NULL;
    float* C_data_dev = NULL;
    int* LUPivots_dev = NULL;
    int* LUInfo_dev = NULL;

    float** AarrayHost = (float**)malloc(size_ptrs);
    float** CarrayHost = (float**)malloc(size_ptrs);
    float* resGETRF = (float*)malloc(size_data_A);
    int* pivot = (int*)malloc(size_pivots);
    int* info = (int*)malloc(size_info);
    float* res = (float*)malloc(size_data_A);
    int ok = 0;

    // Single-pass block: any failure breaks out to the shared cleanup below.
    do {
      if (AarrayHost == NULL || CarrayHost == NULL || resGETRF == NULL ||
          pivot == NULL || info == NULL || res == NULL) {
        fprintf(stderr, "d_Get3DInv: host memory allocation failed\n");
        break;
      }

      if (!invCheckCublas(cublasCreate(&cu_cublasHandle), "cublasCreate")) break;
      handleCreated = 1;

      if (!invCheckCuda(cudaMalloc((void**)&AarrayDev, size_ptrs), "cudaMalloc(AarrayDev)")) break;
      if (!invCheckCuda(cudaMalloc((void**)&CarrayDev, size_ptrs), "cudaMalloc(CarrayDev)")) break;
      if (!invCheckCuda(cudaMalloc((void**)&A_data_dev, size_data_A), "cudaMalloc(A_data_dev)")) break;
      if (!invCheckCuda(cudaMalloc((void**)&C_data_dev, size_data_A), "cudaMalloc(C_data_dev)")) break;
      if (!invCheckCuda(cudaMalloc((void**)&LUPivots_dev, size_pivots), "cudaMalloc(LUPivots_dev)")) break;
      if (!invCheckCuda(cudaMalloc((void**)&LUInfo_dev, size_info), "cudaMalloc(LUInfo_dev)")) break;

      // The batched routines take a device array of device pointers, one per
      // matrix; build those pointer tables on the host first.
      for (int i = 0; i < batchedSize; i++) {
        AarrayHost[i] = A_data_dev + (size_t)i * matrixElems;
        CarrayHost[i] = C_data_dev + (size_t)i * matrixElems;
      }

      if (!invCheckCuda(cudaMemcpy(A_data_dev, A, size_data_A, cudaMemcpyHostToDevice),
                        "cudaMemcpy(A -> device)")) break;
      if (!invCheckCuda(cudaMemcpy(AarrayDev, AarrayHost, size_ptrs, cudaMemcpyHostToDevice),
                        "cudaMemcpy(AarrayDev)")) break;
      if (!invCheckCuda(cudaMemcpy(CarrayDev, CarrayHost, size_ptrs, cudaMemcpyHostToDevice),
                        "cudaMemcpy(CarrayDev)")) break;

      // Step 1: in-place LU factorization of every matrix in the batch.
      if (!invCheckCublas(cublasSgetrfBatched(cu_cublasHandle, n, AarrayDev, n, LUPivots_dev,
                                              LUInfo_dev, batchedSize),
                          "cublasSgetrfBatched")) break;
      if (!invCheckCuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize after getrf")) break;

      if (!invCheckCuda(cudaMemcpy(resGETRF, A_data_dev, size_data_A, cudaMemcpyDeviceToHost),
                        "cudaMemcpy(LU -> host)")) break;
      if (!invCheckCuda(cudaMemcpy(pivot, LUPivots_dev, size_pivots, cudaMemcpyDeviceToHost),
                        "cudaMemcpy(pivots -> host)")) break;
      if (!invCheckCuda(cudaMemcpy(info, LUInfo_dev, size_info, cudaMemcpyDeviceToHost),
                        "cudaMemcpy(getrf info -> host)")) break;
      invReportInfo("cublasSgetrfBatched", info, batchedSize);

      for (int i = 0; i < batchedSize; i++) {
        printVector("Pivot", pivot + (size_t)i * (size_t)n, n);
      }
      for (int i = 0; i < batchedSize; i++) {
        printMatrix("Matrix A_LU", n, n, resGETRF + (size_t)i * matrixElems, n);
      }

      // Step 2: out-of-place inversion from the LU factors and pivots.
      if (!invCheckCublas(cublasSgetriBatched(cu_cublasHandle, n, (const float**)AarrayDev, n,
                                              LUPivots_dev, CarrayDev, n, LUInfo_dev, batchedSize),
                          "cublasSgetriBatched")) break;
      if (!invCheckCuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize after getri")) break;

      if (!invCheckCuda(cudaMemcpy(info, LUInfo_dev, size_info, cudaMemcpyDeviceToHost),
                        "cudaMemcpy(getri info -> host)")) break;
      invReportInfo("cublasSgetriBatched", info, batchedSize);

      if (!invCheckCuda(cudaMemcpy(res, C_data_dev, size_data_A, cudaMemcpyDeviceToHost),
                        "cudaMemcpy(inverse -> host)")) break;

      ok = 1;
    } while (0);

    // Shared cleanup: cudaFree(NULL) and free(NULL) are no-ops, so this is
    // safe regardless of how far the block above got.
    cudaFree(LUInfo_dev);
    cudaFree(LUPivots_dev);
    cudaFree(C_data_dev);
    cudaFree(A_data_dev);
    cudaFree(CarrayDev);
    cudaFree(AarrayDev);
    if (handleCreated) {
      cublasDestroy(cu_cublasHandle);
    }
    free(info);
    free(pivot);
    free(resGETRF);
    free(CarrayHost);
    free(AarrayHost);

    if (!ok) {
      free(res);
      res = NULL;
    }
    return res;
  }
  /**
   * Demo driver: builds a small batch of random 5x5 matrices, inverts them
   * with d_Get3DInv, and prints each matrix, its inverse, and their product
   * (computed on the host with do_dgemm) so the result can be checked by eye.
   *
   * @return 0 on success, EXIT_FAILURE if an allocation or the inversion fails.
   */
  int main(int argc, char* argv[]) {
    (void)argc;
    (void)argv;

    const int n = 5;
    const int batchedSize = 2;
    const size_t matrixElems = (size_t)n * (size_t)n;

    float* A = (float*)malloc((size_t)batchedSize * matrixElems * sizeof(float));
    if (A == NULL) {
      fprintf(stderr, "main: failed to allocate the input batch\n");
      return EXIT_FAILURE;
    }

    for (int i = 0; i < batchedSize; i++) {
      initMatrix(A + i * n * n, n, i);
    }
    for (int i = 0; i < batchedSize; i++) {
      printMatrix("Matrix A", n, n, A + i * n * n, n);
    }

    float* inv = d_Get3DInv(A, n, batchedSize);
    // Check before first use: the result is dereferenced by the loops below.
    if (inv == NULL) {
      fprintf(stderr, "main: batched inversion failed\n");
      free(A);
      return EXIT_FAILURE;
    }
    for (int i = 0; i < batchedSize; i++) {
      printMatrix("Matrix A'", n, n, inv + i * n * n, n);
    }

    float* C_gemm = (float*)malloc(matrixElems * sizeof(float));
    if (C_gemm == NULL) {
      fprintf(stderr, "main: failed to allocate the product buffer\n");
      free(inv);
      free(A);
      return EXIT_FAILURE;
    }
    for (int i = 0; i < batchedSize; i++) {
      do_dgemm(n, n, n, A + i * n * n, n, inv + i * n * n, n, C_gemm, n);
      printMatrix("Matrix A*A'", n, n, C_gemm, n);
    }

    free(C_gemm);
    free(inv);
    free(A);
    return 0;
  }

 

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号