
Optimizing GEMM with TVM


This article follows the official TVM tutorial ("Working with Operators Using Tensor Expression", Apache TVM Chinese site), but with a different problem size: here M=512, K=5120, N=512. The optimization order also differs slightly; strategies whose effect turned out to be negligible are explained and then dropped, and the concrete reasoning behind some parameter choices is added. With these parameters, the optimized GEMM's execution time drops by 99.1% compared with the default schedule.

The test machine is:

The complete code is shown below; it can also be found in the GitHub repo: https://github.com/Beichen-Wang/HPCTest/blob/master/TVM/src/TE_GEMM.py

import tvm
import tvm.testing
from tvm import te
import numpy as np
import timeit


class GEMM:
    def __init__(self, M, N, K, bs, targetI="llvm"):
        self.M = M
        self.N = N
        self.K = K
        self.bs = bs
        self.dtype = "float32"
        self.target = tvm.target.Target(target=targetI)
        self.dev = tvm.device(self.target.kind.name)
        self.a = tvm.nd.array(np.random.rand(M, K).astype(self.dtype), self.dev)
        self.b = tvm.nd.array(np.random.rand(K, N).astype(self.dtype), self.dev)
        self.c = tvm.nd.array(np.zeros((M, N), dtype=self.dtype), self.dev)
        self.log = []

    # utils
    def EvaluateOperation(self, func, baseC):
        # reset C, check correctness against the NumPy result, then time the kernel
        self.c = tvm.nd.array(np.zeros((self.M, self.N), dtype=self.dtype), self.dev)
        func(self.a, self.b, self.c)
        tvm.testing.assert_allclose(self.c.numpy(), baseC, rtol=1e-5)
        evaluator = func.time_evaluator(func.entry_name, self.dev, number=10)
        mean_time = evaluator(self.a, self.b, self.c).mean
        print("%s: %f" % (func.name, mean_time))
        self.log.append((func.name, mean_time))

    # numpy baseline
    def NumpyGEMM(self):
        npRepeatNum = 1
        npRunningTime = timeit.timeit(
            setup="import numpy\n",
            stmt="answer = numpy.dot(a_np, b_np)",
            number=npRepeatNum,
            globals={"a_np": self.a.numpy(), "b_np": self.b.numpy()},
        )
        print("Numpy running time: %f" % (npRunningTime / npRepeatNum))
        return np.dot(self.a.numpy(), self.b.numpy())

    # default schedule
    def TEDefaultGemm(self):
        k = te.reduce_axis((0, self.K), "k")
        A = te.placeholder((self.M, self.K), name="A")
        B = te.placeholder((self.K, self.N), name="B")
        C = te.compute((self.M, self.N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
        s = te.create_schedule(C.op)
        func = tvm.build(s, [A, B, C], target=self.target, name="default")
        print(tvm.lower(s, [A, B, C], simple_mode=True))
        return func

    # optimizer1 --- final: block (tile), vectorize, parallel
    def TEBlockVectoryParallelGemm(self):
        k = te.reduce_axis((0, self.K), "k")
        A = te.placeholder((self.M, self.K), name="A")
        B = te.placeholder((self.K, self.N), name="B")
        C = te.compute((self.M, self.N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
        s = te.create_schedule(C.op)
        xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], self.bs, self.bs)
        ko, ki = s[C].split(k, factor=16)
        s[C].reorder(xo, yo, ko, xi, ki, yi)
        # s[C].unroll(ki)
        s[C].vectorize(yi)
        s[C].parallel(xo)
        func = tvm.build(s, [A, B, C], target=self.target, name="blockVectoryParallel")
        print(tvm.lower(s, [A, B, C], simple_mode=True))
        return func

    # optimizer2.1 -- + cache_write
    def TECacheGemm(self):
        k = te.reduce_axis((0, self.K), "k")
        A = te.placeholder((self.M, self.K), name="A")
        B = te.placeholder((self.K, self.N), name="B")
        C = te.compute(
            (self.M, self.N),
            lambda x, y: te.sum(A[x, k] * B[k, y], axis=k),
            name="C",
        )
        s = te.create_schedule(C.op)
        CC = s.cache_write(C, "global")
        xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], self.bs, self.bs)
        s[CC].compute_at(s[C], yo)
        # new inner axes of the cache stage
        xc, yc = s[CC].op.axis
        (k,) = s[CC].op.reduce_axis
        ko, ki = s[CC].split(k, factor=16)
        s[CC].reorder(ko, xc, ki, yc)
        s[CC].unroll(ki)
        s[CC].vectorize(yc)
        # parallel
        s[C].parallel(xo)
        func = tvm.build(s, [A, B, C], target=self.target, name="CacheParallel")
        print(tvm.lower(s, [A, B, C], simple_mode=True))
        return func

    # optimizer2.2 -- + array packing
    def TEPackGemm(self):
        k = te.reduce_axis((0, self.K), "k")
        A = te.placeholder((self.M, self.K), name="A")
        B = te.placeholder((self.K, self.N), name="B")
        # integer division so the packed shape stays an int
        packedB = te.compute((self.N // self.bs, self.K, self.bs), lambda x, y, z: B[y, x * self.bs + z], name="packedB")
        C = te.compute(
            (self.M, self.N),
            lambda x, y: te.sum(A[x, k] * packedB[tvm.tir.indexdiv(y, self.bs), k, tvm.tir.indexmod(y, self.bs)], axis=k),
            name="C",
        )
        s = te.create_schedule(C.op)
        xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], self.bs, self.bs)
        (k,) = s[C].op.reduce_axis
        ko, ki = s[C].split(k, factor=16)
        s[C].reorder(xo, yo, ko, xi, ki, yi)
        s[C].vectorize(yi)
        x, y, z = s[packedB].op.axis
        s[packedB].vectorize(z)
        s[packedB].parallel(x)
        s[C].parallel(xo)
        func = tvm.build(s, [A, B, C], target=self.target, name="PackParallel")
        print(tvm.lower(s, [A, B, C], simple_mode=True))
        return func

    # optimizer3.1 -- + cache_write + array packing
    def TECachePackGemm(self):
        k = te.reduce_axis((0, self.K), "k")
        A = te.placeholder((self.M, self.K), name="A")
        B = te.placeholder((self.K, self.N), name="B")
        packedB = te.compute((self.N // self.bs, self.K, self.bs), lambda x, y, z: B[y, x * self.bs + z], name="packedB")
        C = te.compute(
            (self.M, self.N),
            lambda x, y: te.sum(A[x, k] * packedB[tvm.tir.indexdiv(y, self.bs), k, tvm.tir.indexmod(y, self.bs)], axis=k),
            name="C",
        )
        s = te.create_schedule(C.op)
        CC = s.cache_write(C, "global")
        xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], self.bs, self.bs)
        s[CC].compute_at(s[C], yo)
        # new inner axes of the cache stage
        xc, yc = s[CC].op.axis
        (k,) = s[CC].op.reduce_axis
        ko, ki = s[CC].split(k, factor=64)
        s[CC].reorder(ko, xc, ki, yc)
        s[CC].unroll(ki)
        # s[CC].pragma(ki, "unroll_explicit", 2)
        s[CC].vectorize(yc)
        # parallel
        s[C].parallel(xo)
        x, y, z = s[packedB].op.axis
        s[packedB].vectorize(z)
        s[packedB].parallel(x)
        func = tvm.build(s, [A, B, C], target=self.target, name="CachePackParallel")
        print(tvm.lower(s, [A, B, C], simple_mode=True))
        return func


if __name__ == "__main__":
    M = 512
    K = 5120
    N = 512
    bs = 64
    instance = GEMM(M, N, K, bs)
    baseC = instance.NumpyGEMM()
    funcDefault = instance.TEDefaultGemm()
    instance.EvaluateOperation(funcDefault, baseC)
    funcBlockPermuteVectory = instance.TEBlockVectoryParallelGemm()
    instance.EvaluateOperation(funcBlockPermuteVectory, baseC)
    # funcPack = instance.TEPackGemm()
    # instance.EvaluateOperation(funcPack, baseC)
    # funcCache = instance.TECachePackGemm()
    # instance.EvaluateOperation(funcCache, baseC)
    # funcCache = instance.TECacheGemm()
    # instance.EvaluateOperation(funcCache, baseC)

0.Init

def __init__(self, M, N, K, bs, targetI="llvm"):
    self.M = M
    self.N = N
    self.K = K
    self.bs = bs
    self.dtype = "float32"
    self.target = tvm.target.Target(target=targetI)
    self.dev = tvm.device(self.target.kind.name)
    self.a = tvm.nd.array(np.random.rand(M, K).astype(self.dtype), self.dev)
    self.b = tvm.nd.array(np.random.rand(K, N).astype(self.dtype), self.dev)
    self.c = tvm.nd.array(np.zeros((M, N), dtype=self.dtype), self.dev)
    self.log = []

The target specified here is llvm, and the dtype is float32.
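As a quick illustration (a minimal sketch, not part of the original script) of what the constructor resolves "llvm" to:

import tvm

target = tvm.target.Target("llvm")
dev = tvm.device(target.kind.name)  # "llvm" maps to the CPU device
print(target.kind.name, dev)        # -> llvm cpu(0)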

1.DefaultGEMM

def TEDefaultGemm(self):
    k = te.reduce_axis((0, self.K), "k")
    A = te.placeholder((self.M, self.K), name="A")
    B = te.placeholder((self.K, self.N), name="B")
    C = te.compute((self.M, self.N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
    s = te.create_schedule(C.op)
    func = tvm.build(s, [A, B, C], target=self.target, name="default")
    print(tvm.lower(s, [A, B, C], simple_mode=True))
    return func

The generated intermediate representation is:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((512, 5120), "float32"), B: T.Buffer((5120, 512), "float32"), C: T.Buffer((512, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for x, y in T.grid(512, 512):
            C_1 = T.Buffer((262144,), data=C.data)
            C_1[x * 512 + y] = T.float32(0)
            for k in range(5120):
                cse_var_1: T.int32 = x * 512 + y
                A_1 = T.Buffer((2621440,), data=A.data)
                B_1 = T.Buffer((2621440,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_1[x * 5120 + k] * B_1[k * 512 + y]

The printed IR shows that this is the most naive GEMM. Its first problem is that the accesses to B are not contiguous: in the innermost k loop, consecutive iterations jump a full row of B, so cache misses are severe. That is why the first step is to tile; a quick back-of-the-envelope check follows below.
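To make the cache-miss argument concrete, here is a small check (a sketch; the 64 B cache line is an assumption about the test machine):

# In the innermost k loop, consecutive iterations read B[k, y] and B[k+1, y],
# which are a full row of B apart in memory.
N = 512
stride_bytes = N * 4            # 2048 B between consecutive accesses to B
cache_line_bytes = 64           # assumed cache-line size
print(stride_bytes // cache_line_bytes)  # 32: every k iteration lands on a new cache line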

2.Optimizer1-TileGEMM

xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], self.bs, self.bs)
ko, ki = s[C].split(k, factor=16)
s[C].reorder(xo, yo, ko, xi, ki, yi)

The bs chosen here is 64, mainly because 64 * 64 * sizeof(float) is smaller than the 32 KB L1 cache. The split factor for K is 16 because the cache-line size is 64 B, so 16 consecutive float32 values occupy exactly one cache line (a quick arithmetic check appears after the lowered IR below). These values are not necessarily optimal, just reasonably good ones. The lowered form is:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((512, 5120), "float32"), B: T.Buffer((5120, 512), "float32"), C: T.Buffer((512, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for x_outer, y_outer in T.grid(8, 8):
            C_1 = T.Buffer((262144,), data=C.data)
            for x_inner_init, y_inner_init in T.grid(64, 64):
                C_1[x_outer * 32768 + x_inner_init * 512 + y_outer * 64 + y_inner_init] = T.float32(0)
            for k_outer, x_inner, k_inner, y_inner in T.grid(320, 64, 16, 64):
                cse_var_2: T.int32 = y_outer * 64
                cse_var_1: T.int32 = x_outer * 32768 + x_inner * 512 + cse_var_2 + y_inner
                A_1 = T.Buffer((2621440,), data=A.data)
                B_1 = T.Buffer((2621440,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_1[x_outer * 327680 + x_inner * 5120 + k_outer * 16 + k_inner] * B_1[k_outer * 8192 + k_inner * 512 + cse_var_2 + y_inner]
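As a quick sanity check of the tile-size reasoning above (a sketch; the 32 KB L1 data cache and 64 B cache line are assumptions about the test machine):

bs, k_factor, sizeof_float = 64, 16, 4
tile_bytes = bs * bs * sizeof_float   # 16384 B: one 64x64 fp32 block fits in a 32 KB L1
line_bytes = k_factor * sizeof_float  # 64 B: 16 consecutive fp32 values fill one cache line
print(tile_bytes, line_bytes)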

In this lowered IR, B's accesses are now contiguous, but there is still room for optimization: the innermost dimension y_inner can be vectorized with SIMD. In TVM's lowered IR, vectorization shows up as slice notation, essentially X[index:index + 64].

3.Optimizer2-VectorizeGEMM

s[C].vectorize(yi)

The lowered form is:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((512, 5120), "float32"), B: T.Buffer((5120, 512), "float32"), C: T.Buffer((512, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for x_outer, y_outer in T.grid(8, 8):
            C_1 = T.Buffer((262144,), data=C.data)
            for x_inner_init in range(64):
                C_1[x_outer * 32768 + x_inner_init * 512 + y_outer * 64:x_outer * 32768 + x_inner_init * 512 + y_outer * 64 + 64] = T.Broadcast(T.float32(0), 64)
            for k_outer, x_inner, k_inner in T.grid(320, 64, 16):
                cse_var_2: T.int32 = y_outer * 64
                cse_var_1: T.int32 = x_outer * 32768 + x_inner * 512 + cse_var_2
                A_1 = T.Buffer((2621440,), data=A.data)
                B_1 = T.Buffer((2621440,), data=B.data)
                C_1[cse_var_1:cse_var_1 + 64] = C_1[cse_var_1:cse_var_1 + 64] + T.Broadcast(A_1[x_outer * 327680 + x_inner * 5120 + k_outer * 16 + k_inner], 64) * B_1[k_outer * 8192 + k_inner * 512 + cse_var_2:k_outer * 8192 + k_inner * 512 + cse_var_2 + 64]

After vectorization the runtime drops by another 38.9% relative to the tiled version, a clear win.

  • The official tutorial also introduces array packing, but its benefit here is limited. The idea is to block-transpose B so that accesses across B's blocks become contiguous; however, accesses within a block already fill whole cache lines, so making the inter-block accesses contiguous adds little, and the extra work of building the packed matrix actually made this test slower. It is therefore not discussed further (a small NumPy sketch of the packed layout follows this list).
  • For C, the official tutorial likewise introduces an intermediate buffer (cache_write) that holds each block so that C is written back contiguously; but with a tile size of 64 the cache is already used well when writing C, so the gain is limited and sometimes negative. This is not discussed further either.
  • Unrolling brings essentially no gain; my guess is that the compiler's own unrolling is already good, so an explicit unroll adds little. It is not expanded on here.
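For reference, here is a minimal NumPy sketch (not part of the TVM schedule) of the layout that the array-packing step produces, i.e. what packedB in TEPackGemm computes:

import numpy as np

# B of shape (K, N) is repacked into shape (N // bs, K, bs) so that the bs
# values consumed by one vectorized y_inner iteration are contiguous.
K, N, bs = 8, 8, 4
B = np.arange(K * N, dtype=np.float32).reshape(K, N)
packedB = np.empty((N // bs, K, bs), dtype=np.float32)
for x in range(N // bs):
    for y in range(K):
        for z in range(bs):
            packedB[x, y, z] = B[y, x * bs + z]
# C[x, y] then accumulates A[x, k] * packedB[y // bs, k, y % bs]
assert np.allclose(packedB[1, 3, :], B[3, bs:2 * bs])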

The last optimization is multi-core parallelism.

4.Optimizer3-parallel

s[C].parallel(xo)

The lowered form is:

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((512, 5120), "float32"), B: T.Buffer((5120, 512), "float32"), C: T.Buffer((512, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for x_outer in T.parallel(8):
            for y_outer in range(8):
                C_1 = T.Buffer((262144,), data=C.data)
                for x_inner_init in range(64):
                    C_1[x_outer * 32768 + x_inner_init * 512 + y_outer * 64:x_outer * 32768 + x_inner_init * 512 + y_outer * 64 + 64] = T.Broadcast(T.float32(0), 64)
                for k_outer, x_inner, k_inner in T.grid(320, 64, 16):
                    cse_var_2: T.int32 = y_outer * 64
                    cse_var_1: T.int32 = x_outer * 32768 + x_inner * 512 + cse_var_2
                    A_1 = T.Buffer((2621440,), data=A.data)
                    B_1 = T.Buffer((2621440,), data=B.data)
                    C_1[cse_var_1:cse_var_1 + 64] = C_1[cse_var_1:cse_var_1 + 64] + T.Broadcast(A_1[x_outer * 327680 + x_inner * 5120 + k_outer * 16 + k_inner], 64) * B_1[k_outer * 8192 + k_inner * 512 + cse_var_2:k_outer * 8192 + k_inner * 512 + cse_var_2 + 64]

Eight threads are used here for the parallel outer loop, which reduces the runtime by another 83.8%.
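A minimal sketch, assuming the TVM runtime thread pool reads the TVM_NUM_THREADS environment variable at start-up, for controlling how many worker threads the parallel schedule uses:

import os
os.environ["TVM_NUM_THREADS"] = "8"  # must be set before the first call into TVM

import tvm  # imported after the variable is set so the thread pool picks it up
# Building and timing the parallel schedule as above will now run the 8
# iterations of x_outer on 8 worker threads.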

5.Summary

5.1 Timing results

Optimization    Time (s)
default         2.930760
+tile           0.261980
+vectorize      0.16041
+parallel       0.025833
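The percentages quoted earlier follow directly from this table; as a quick check (a sketch recomputing them):

default, tile, vectorize, parallel = 2.930760, 0.261980, 0.16041, 0.025833
print(1 - vectorize / tile)      # ~0.388 -> the roughly 39% gain from vectorize
print(1 - parallel / vectorize)  # ~0.839 -> the roughly 84% gain from parallel
print(1 - parallel / default)    # ~0.991 -> the overall 99.1% improvement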

5.2 Thoughts on using TVM

Overall, TVM's TE is very convenient to use. The one small issue I ran into is unroll: when I wanted to unroll only a fixed number of iterations, I did not find a direct way to specify that count, and could only unroll an entire axis (a possible workaround is sketched below).
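One possible workaround (a sketch, under the assumption that the goal is unrolling a fixed number of iterations at a time) is to split the axis by the desired unroll factor and unroll only the inner part, reusing the same split/unroll primitives already used above:

import tvm
from tvm import te

# Toy example: unroll only 4 iterations of the reduction axis at a time.
M, N, K = 64, 64, 64
k = te.reduce_axis((0, K), "k")
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
s = te.create_schedule(C.op)
ko, ki = s[C].split(k, factor=4)  # ki now has a fixed trip count of 4
s[C].unroll(ki)                   # so only those 4 iterations get unrolled
print(tvm.lower(s, [A, B, C], simple_mode=True))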
