当前位置:   article > 正文

cublasCgetrfBatched cublasCgetriBatched sample

cublascgetrfbatched
  1. /* This example demonstrates how to use the cuBLAS library
  2. * to compute a batched LU factorization of single-precision
  3. * complex matrices on the device with cublasCgetrfBatched
  4. * and to copy the pivots, status codes and factors back
  5. * to the host. */
  6. /* Includes, system */
  7. #include <stdio.h>
  8. #include <stdlib.h>
  9. #include <string.h>
  10. /* Includes, cuda */
  11. #include <cublas_v2.h>
  12. #include <cuda_runtime.h>
  13. //#include <helper_cuda.h>
  14. #include <vector>
  15. #include <random>
  16. #include <iostream>
  17. using namespace std;
  18. /* Matrix size */
  19. #define N (4)
  20. void printComplexMatrix(float2* A, int m, int n, int lda);
  21. /* Main */
  22. int main(int argc, char **argv) {
  23. int n=N;
  24. int batchCount=1;
  25. const auto kSizeN = n;
  26. int Pivot[kSizeN * batchCount];
  27. int info[batchCount];
  28. int Pivot_cu[kSizeN * batchCount];
  29. int info_cu[batchCount];
  30. // Creates input matrices
  31. auto mat_a = std::vector<float2>(kSizeN * kSizeN);
  32. auto mat_a_cu = std::vector<float2>(kSizeN * kSizeN);
  33. // Create a random number generator
  34. const auto random_seed = 12;
  35. std::default_random_engine generator(
  36. static_cast<unsigned int>(random_seed));
  37. std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
  38. // Populates input data structures
  39. for (auto &item : mat_a) {
  40. item.x = distribution(generator);
  41. item.y = distribution(generator);
  42. }
  43. //LL:: printComplexMatrix(mat_a.data(), kSizeN, kSizeN, kSizeN);
  44. size_t len_a = kSizeN * kSizeN * sizeof(float2);
  45. float2 *devA_cu[batchCount];
  46. float2 **d_Aarray_cu;
  47. for (int i = 0; i < batchCount; i++) {
  48. cudaMalloc((void **)&devA_cu[i], len_a);
  49. }
  50. for (int i = 0; i < batchCount; i++) {
  51. cudaMemcpy(devA_cu[i], mat_a.data(), len_a, cudaMemcpyHostToDevice);
  52. }
  53. cudaMalloc((void **)&d_Aarray_cu, batchCount * sizeof(float2 *));
  54. cudaMemcpy(d_Aarray_cu, devA_cu, batchCount * sizeof(float2 *),
  55. cudaMemcpyHostToDevice);
  56. int *Pivot_cu_d;
  57. cudaMalloc((void **)&Pivot_cu_d, kSizeN * batchCount * sizeof(int));
  58. int *info_cu_d;
  59. cudaMalloc((void **)&info_cu_d, batchCount * sizeof(int));
  60. cublasHandle_t cublasHandle;
  61. cublasStatus_t cu_status;
  62. cu_status = cublasCreate(&cublasHandle);
  63. if(CUBLAS_STATUS_SUCCESS != cu_status)
  64. cout<<"ERROR!"<<endl;
  65. cu_status = cublasCgetrfBatched(cublasHandle, kSizeN, d_Aarray_cu, kSizeN,
  66. Pivot_cu_d, info_cu_d, batchCount);
  67. if(CUBLAS_STATUS_SUCCESS != cu_status)
  68. cout<<"ERROR!"<<endl;
  69. cudaError_t ret = cudaDeviceSynchronize();
  70. cu_status = cublasDestroy(cublasHandle);
  71. if(CUBLAS_STATUS_SUCCESS != cu_status)
  72. cout<<"ERROR!"<<endl;
  73. cudaMemcpy(Pivot_cu, Pivot_cu_d, kSizeN * sizeof(int) * batchCount,
  74. cudaMemcpyDeviceToHost);
  75. cudaMemcpy(info_cu, info_cu_d, sizeof(int) * batchCount,
  76. cudaMemcpyDeviceToHost);
  77. for (int i = 0; i < batchCount; i++) {
  78. for (int j = 0; j < kSizeN; j++) {
  79. //LL:: cout<<(Pivot_cu[i * kSizeN + j])<< endl;
  80. }
  81. cudaMemcpy(mat_a_cu.data(), devA_cu[i], len_a, cudaMemcpyDeviceToHost);
  82. float2 *cu_result = mat_a_cu.data();
  83. //LL:: printComplexMatrix(cu_result, kSizeN,kSizeN,kSizeN );
  84. cudaFree(devA_cu[i]);
  85. }
  86. cudaFree(Pivot_cu_d);
  87. cudaFree(info_cu_d);
  88. cudaFree(d_Aarray_cu);
  89. //------------------------------------------------------------------------
  90. cuComplex a,b;
  91. a.x = -0.214;
  92. a.y = 0.255;
  93. b.x = -1.0000;
  94. b.y = 0.1570;
  95. cuComplex c = cuCdivf(a, b);
  96. cout<<c.x<<" + "<<c.y<<"i"<<endl;
  97. //------------------------------------------------------------------------
  98. return 0;
  99. }
  100. void printComplexMatrix(float2* A, int m, int n, int lda){
  101. for(int i=0; i<m; i++){
  102. for(int j=0; j<n; j++){
  103. printf("(%8.5f,%8.5f*I) ", A[i+j*lda].x, A[i+j*lda].y);
  104. }
  105. printf("\n\n");
  106. }
  107. printf("\n");
  108. }

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/245529?site
推荐阅读
相关标签
  

闽ICP备14008679号