This article follows the official TVM tutorial (使用张量表达式处理算子 | Apache TVM 中文站, "Working with Operators Using Tensor Expression"), but with a different problem size: here M=512, K=5120, N=512. The optimization order also differs slightly; strategies whose effect turned out to be insignificant are explained and then dropped, and the concrete reasoning behind some parameter choices is added. With these parameters, the optimized GEMM's runtime drops by 99.1% relative to the default schedule.
The test machine is:
The complete code is below; it can also be found in this GitHub repo: https://github.com/Beichen-Wang/HPCTest/blob/master/TVM/src/TE_GEMM.py
```python
import tvm
import tvm.testing
from tvm import te
import numpy as np
import timeit

class GEMM:
    def __init__(self, M, N, K, bs, targetI="llvm"):
        self.M = M
        self.N = N
        self.K = K
        self.bs = bs
        self.dtype = "float32"
        self.target = tvm.target.Target(target=targetI)
        self.dev = tvm.device(self.target.kind.name)
        self.a = tvm.nd.array(np.random.rand(M, K).astype(self.dtype), self.dev)
        self.b = tvm.nd.array(np.random.rand(K, N).astype(self.dtype), self.dev)
        self.c = tvm.nd.array(np.zeros((M, N), dtype=self.dtype), self.dev)
        self.log = []

    # utils: verify against the NumPy baseline, then time the kernel
    def EvaluateOperation(self, func, baseC):
        self.c = tvm.nd.array(np.zeros((self.M, self.N), dtype=self.dtype), self.dev)
        func(self.a, self.b, self.c)
        tvm.testing.assert_allclose(self.c.numpy(), baseC, rtol=1e-5)
        evaluator = func.time_evaluator(func.entry_name, self.dev, number=10)
        mean_time = evaluator(self.a, self.b, self.c).mean
        print("%s: %f" % (func.name, mean_time))
        self.log.append((func.name, mean_time))

    # numpy baseline
    def NumpyGEMM(self):
        npRepeatNum = 1
        npRunningTime = timeit.timeit(
            setup="import numpy\n",
            stmt="answer = numpy.dot(a_np, b_np)",
            number=npRepeatNum,
            globals={"a_np": self.a.numpy(), "b_np": self.b.numpy()},
        )
        print("Numpy running time: %f" % (npRunningTime / npRepeatNum))
        return np.dot(self.a.numpy(), self.b.numpy())

    # default (unoptimized) schedule
    def TEDefaultGemm(self):
        k = te.reduce_axis((0, self.K), "k")
        A = te.placeholder((self.M, self.K), name="A")
        B = te.placeholder((self.K, self.N), name="B")
        C = te.compute((self.M, self.N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
        s = te.create_schedule(C.op)
        func = tvm.build(s, [A, B, C], target=self.target, name="default")
        print(tvm.lower(s, [A, B, C], simple_mode=True))
        return func

    # optimization 1 (final version): tile + vectorize + parallel
    def TEBlockVectoryParallelGemm(self):
        k = te.reduce_axis((0, self.K), "k")
        A = te.placeholder((self.M, self.K), name="A")
        B = te.placeholder((self.K, self.N), name="B")
        C = te.compute((self.M, self.N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
        s = te.create_schedule(C.op)
        xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], self.bs, self.bs)
        ko, ki = s[C].split(k, factor=16)
        s[C].reorder(xo, yo, ko, xi, ki, yi)
        # s[C].unroll(ki)
        s[C].vectorize(yi)
        s[C].parallel(xo)
        func = tvm.build(s, [A, B, C], target=self.target, name="blockVectoryParallel")
        print(tvm.lower(s, [A, B, C], simple_mode=True))
        return func

    # optimization 2.1: + cache_write
    def TECacheGemm(self):
        k = te.reduce_axis((0, self.K), "k")
        A = te.placeholder((self.M, self.K), name="A")
        B = te.placeholder((self.K, self.N), name="B")
        C = te.compute(
            (self.M, self.N),
            lambda x, y: te.sum(A[x, k] * B[k, y], axis=k),
            name="C",
        )
        s = te.create_schedule(C.op)
        CC = s.cache_write(C, "global")
        xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], self.bs, self.bs)
        s[CC].compute_at(s[C], yo)
        # new inner axes of the cache stage
        xc, yc = s[CC].op.axis

        (k,) = s[CC].op.reduce_axis
        ko, ki = s[CC].split(k, factor=16)
        s[CC].reorder(ko, xc, ki, yc)
        s[CC].unroll(ki)
        s[CC].vectorize(yc)

        # parallel
        s[C].parallel(xo)

        func = tvm.build(s, [A, B, C], target=self.target, name="CacheParallel")
        print(tvm.lower(s, [A, B, C], simple_mode=True))
        return func

    # optimization 2.2: + pack
    def TEPackGemm(self):
        k = te.reduce_axis((0, self.K), "k")
        A = te.placeholder((self.M, self.K), name="A")
        B = te.placeholder((self.K, self.N), name="B")
        packedB = te.compute(
            (self.N // self.bs, self.K, self.bs),
            lambda x, y, z: B[y, x * self.bs + z],
            name="packedB",
        )
        C = te.compute(
            (self.M, self.N),
            lambda x, y: te.sum(
                A[x, k] * packedB[tvm.tir.indexdiv(y, self.bs), k, tvm.tir.indexmod(y, self.bs)],
                axis=k,
            ),
            name="C",
        )
        s = te.create_schedule(C.op)
        xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], self.bs, self.bs)
        (k,) = s[C].op.reduce_axis
        ko, ki = s[C].split(k, factor=16)
        s[C].reorder(xo, yo, ko, xi, ki, yi)
        s[C].vectorize(yi)

        x, y, z = s[packedB].op.axis
        s[packedB].vectorize(z)
        s[packedB].parallel(x)

        s[C].parallel(xo)
        func = tvm.build(s, [A, B, C], target=self.target, name="PackParallel")
        print(tvm.lower(s, [A, B, C], simple_mode=True))
        return func

    # optimization 3.1: + cache + pack
    def TECachePackGemm(self):
        k = te.reduce_axis((0, self.K), "k")
        A = te.placeholder((self.M, self.K), name="A")
        B = te.placeholder((self.K, self.N), name="B")
        packedB = te.compute(
            (self.N // self.bs, self.K, self.bs),
            lambda x, y, z: B[y, x * self.bs + z],
            name="packedB",
        )
        C = te.compute(
            (self.M, self.N),
            lambda x, y: te.sum(
                A[x, k] * packedB[tvm.tir.indexdiv(y, self.bs), k, tvm.tir.indexmod(y, self.bs)],
                axis=k,
            ),
            name="C",
        )
        s = te.create_schedule(C.op)
        CC = s.cache_write(C, "global")
        xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], self.bs, self.bs)
        s[CC].compute_at(s[C], yo)
        # new inner axes of the cache stage
        xc, yc = s[CC].op.axis

        (k,) = s[CC].op.reduce_axis
        ko, ki = s[CC].split(k, factor=64)
        s[CC].reorder(ko, xc, ki, yc)
        s[CC].unroll(ki)
        # s[CC].pragma(ki, "unroll_explicit", 2)
        s[CC].vectorize(yc)

        # parallel
        s[C].parallel(xo)
        x, y, z = s[packedB].op.axis
        s[packedB].vectorize(z)
        s[packedB].parallel(x)
        func = tvm.build(s, [A, B, C], target=self.target, name="CacheParallel")
        print(tvm.lower(s, [A, B, C], simple_mode=True))
        return func

if __name__ == "__main__":
    M = 512
    K = 5120
    N = 512
    bs = 64
    instance = GEMM(M, N, K, bs)
    baseC = instance.NumpyGEMM()
    funcDefault = instance.TEDefaultGemm()
    instance.EvaluateOperation(funcDefault, baseC)
    funcBlockPermuteVectory = instance.TEBlockVectoryParallelGemm()
    instance.EvaluateOperation(funcBlockPermuteVectory, baseC)
    # funcPack = instance.TEPackGemm()
    # instance.EvaluateOperation(funcPack, baseC)
    # funcCache = instance.TECachePackGemm()
    # instance.EvaluateOperation(funcCache, baseC)
    # funcCache = instance.TECacheGemm()
    # instance.EvaluateOperation(funcCache, baseC)
```
Walking through the code piece by piece, starting with the constructor:

```python
def __init__(self, M, N, K, bs, targetI="llvm"):
    self.M = M
    self.N = N
    self.K = K
    self.bs = bs
    self.dtype = "float32"
    self.target = tvm.target.Target(target=targetI)
    self.dev = tvm.device(self.target.kind.name)
    self.a = tvm.nd.array(np.random.rand(M, K).astype(self.dtype), self.dev)
    self.b = tvm.nd.array(np.random.rand(K, N).astype(self.dtype), self.dev)
    self.c = tvm.nd.array(np.zeros((M, N), dtype=self.dtype), self.dev)
    self.log = []
```
The target specified here is llvm and the dtype is float32.
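Note that the generic "llvm" target does not tell the code generator which SIMD extensions the CPU supports. A more specific target string usually vectorizes better; the snippet below is a sketch, and the `-mcpu` value is an assumption that should be matched to your own machine:

```python
# Hypothetical example: an AVX2-capable x86 CPU. Pick the -mcpu value
# that matches your hardware.
instance = GEMM(M, N, K, bs, targetI="llvm -mcpu=core-avx2")
```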
Next, the baseline GEMM with the default schedule:

```python
def TEDefaultGemm(self):
    k = te.reduce_axis((0, self.K), "k")
    A = te.placeholder((self.M, self.K), name="A")
    B = te.placeholder((self.K, self.N), name="B")
    C = te.compute((self.M, self.N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
    s = te.create_schedule(C.op)
    func = tvm.build(s, [A, B, C], target=self.target, name="default")
    print(tvm.lower(s, [A, B, C], simple_mode=True))
    return func
```
The generated intermediate representation is:
```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((512, 5120), "float32"), B: T.Buffer((5120, 512), "float32"), C: T.Buffer((512, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for x, y in T.grid(512, 512):
            C_1 = T.Buffer((262144,), data=C.data)
            C_1[x * 512 + y] = T.float32(0)
            for k in range(5120):
                cse_var_1: T.int32 = x * 512 + y
                A_1 = T.Buffer((2621440,), data=A.data)
                B_1 = T.Buffer((2621440,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_1[x * 5120 + k] * B_1[k * 512 + y]
```
From the printed IR we can see this is the most naive GEMM. The first problem is that the accesses to B are non-contiguous: in the innermost k loop, `B_1[k * 512 + y]` jumps by a stride of 512 floats on every iteration, so cache misses are severe. We therefore need to tile.
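To get a feel for what strided access costs, here is a small self-contained NumPy timing sketch (illustrative only; the array size is mine and the exact ratio depends on the machine):

```python
import timeit
import numpy as np

# Rows of a C-ordered array are contiguous in memory; a column is strided
# by 4096 * 4 B, so each element of a column touches a different cache line.
X = np.random.rand(4096, 4096).astype("float32")
row = timeit.timeit(lambda: X[0, :].sum(), number=10000)  # contiguous
col = timeit.timeit(lambda: X[:, 0].sum(), number=10000)  # strided
print("row: %.4f s, col: %.4f s" % (row, col))  # column access is typically several times slower
```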
The tiling and reduction-split schedule:

```python
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], self.bs, self.bs)
ko, ki = s[C].split(k, factor=16)
s[C].reorder(xo, yo, ko, xi, ki, yi)
```
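Before looking at the lowered IR, a quick back-of-envelope check on these tile sizes (a sketch; the 32 KB L1 data cache and 64 B cache line figures are assumptions about the test machine):

```python
bs = 64
line_bytes = 64                      # assumed cache line size
l1_bytes = 32 * 1024                 # assumed L1 data cache size
floats_per_line = line_bytes // 4    # 16 float32 values per cache line
tile_bytes = bs * bs * 4             # one 64x64 float32 tile: 16384 B = 16 KB
assert tile_bytes < l1_bytes         # the C tile fits comfortably in L1
print(floats_per_line, tile_bytes)   # -> 16 16384
```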
bs is 64 here, mainly because one 64x64 float32 tile is 64*64*sizeof(float) = 16 KB, which fits within the 32 KB L1 cache. The K split factor is 16 because the cache line size is 64 B, so 16 consecutive float32 values occupy exactly one cache line, as the arithmetic above checks. These choices are not optimal, just reasonably good. The lowered IR is:
```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((512, 5120), "float32"), B: T.Buffer((5120, 512), "float32"), C: T.Buffer((512, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for x_outer, y_outer in T.grid(8, 8):
            C_1 = T.Buffer((262144,), data=C.data)
            for x_inner_init, y_inner_init in T.grid(64, 64):
                C_1[x_outer * 32768 + x_inner_init * 512 + y_outer * 64 + y_inner_init] = T.float32(0)
            for k_outer, x_inner, k_inner, y_inner in T.grid(320, 64, 16, 64):
                cse_var_2: T.int32 = y_outer * 64
                cse_var_1: T.int32 = x_outer * 32768 + x_inner * 512 + cse_var_2 + y_inner
                A_1 = T.Buffer((2621440,), data=A.data)
                B_1 = T.Buffer((2621440,), data=B.data)
                C_1[cse_var_1] = C_1[cse_var_1] + A_1[x_outer * 327680 + x_inner * 5120 + k_outer * 16 + k_inner] * B_1[k_outer * 8192 + k_inner * 512 + cse_var_2 + y_inner]
```
Now B's accesses are contiguous, but there is still room to optimize: the innermost dimension y_inner can be vectorized with SIMD. In TVM's lowered IR, vectorization shows up as slice notation, e.g. X[index:index+64]:
```python
s[C].vectorize(yi)
```
The lowered IR is:
```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((512, 5120), "float32"), B: T.Buffer((5120, 512), "float32"), C: T.Buffer((512, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for x_outer, y_outer in T.grid(8, 8):
            C_1 = T.Buffer((262144,), data=C.data)
            for x_inner_init in range(64):
                C_1[x_outer * 32768 + x_inner_init * 512 + y_outer * 64:x_outer * 32768 + x_inner_init * 512 + y_outer * 64 + 64] = T.Broadcast(T.float32(0), 64)
            for k_outer, x_inner, k_inner in T.grid(320, 64, 16):
                cse_var_2: T.int32 = y_outer * 64
                cse_var_1: T.int32 = x_outer * 32768 + x_inner * 512 + cse_var_2
                A_1 = T.Buffer((2621440,), data=A.data)
                B_1 = T.Buffer((2621440,), data=B.data)
                C_1[cse_var_1:cse_var_1 + 64] = C_1[cse_var_1:cse_var_1 + 64] + T.Broadcast(A_1[x_outer * 327680 + x_inner * 5120 + k_outer * 16 + k_inner], 64) * B_1[k_outer * 8192 + k_inner * 512 + cse_var_2:k_outer * 8192 + k_inner * 512 + cse_var_2 + 64]
```
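If you want to confirm that the slice notation really turns into SIMD instructions, the built module's generated code can be dumped. This is a sketch reusing the schedule `s` and tensors from above; I believe `get_source("asm")` works for the llvm backend, with `"ll"` (LLVM IR) as a fallback:

```python
func = tvm.build(s, [A, B, C], target="llvm")
# On an x86 machine with AVX you should see packed instructions such as
# vmulps / vaddps in the assembly output.
print(func.get_source("asm"))  # or func.get_source("ll") for LLVM IR
```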
After vectorization, the runtime improves by 38.9%, a clear win.
The last optimization is multi-core parallelism:
```python
s[C].parallel(xo)
```
The lowered IR is:
```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((512, 5120), "float32"), B: T.Buffer((5120, 512), "float32"), C: T.Buffer((512, 512), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        for x_outer in T.parallel(8):
            for y_outer in range(8):
                C_1 = T.Buffer((262144,), data=C.data)
                for x_inner_init in range(64):
                    C_1[x_outer * 32768 + x_inner_init * 512 + y_outer * 64:x_outer * 32768 + x_inner_init * 512 + y_outer * 64 + 64] = T.Broadcast(T.float32(0), 64)
                for k_outer, x_inner, k_inner in T.grid(320, 64, 16):
                    cse_var_2: T.int32 = y_outer * 64
                    cse_var_1: T.int32 = x_outer * 32768 + x_inner * 512 + cse_var_2
                    A_1 = T.Buffer((2621440,), data=A.data)
                    B_1 = T.Buffer((2621440,), data=B.data)
                    C_1[cse_var_1:cse_var_1 + 64] = C_1[cse_var_1:cse_var_1 + 64] + T.Broadcast(A_1[x_outer * 327680 + x_inner * 5120 + k_outer * 16 + k_inner], 64) * B_1[k_outer * 8192 + k_inner * 512 + cse_var_2:k_outer * 8192 + k_inner * 512 + cse_var_2 + 64]
```
Here eight threads run in parallel, which improves performance by another 83.8%.
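The parallel loop is executed by TVM's runtime thread pool. If you want to experiment with different core counts, the pool size can be controlled via the TVM_NUM_THREADS environment variable (a sketch; the value must be set before TVM initializes its thread pool):

```python
import os
# Set at the very top of the script, or from the shell:
#   TVM_NUM_THREADS=4 python TE_GEMM.py
os.environ["TVM_NUM_THREADS"] = "4"
```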
The measured times:

| Optimization | Time (s) |
|---|---|
| default | 2.930760 |
| +tile | 0.261980 |
| +vectorize | 0.16041 |
| +parallel | 0.025833 |
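For reference, the cumulative speedups implied by the table (simple arithmetic on the numbers above):

```python
times = {
    "default":    2.930760,
    "+tile":      0.261980,
    "+vectorize": 0.16041,
    "+parallel":  0.025833,
}
base = times["default"]
for name, t in times.items():
    print("%-11s %.6f s  %6.1fx" % (name, t, base / t))
# end to end: 2.930760 / 0.025833 ≈ 113x, i.e. runtime reduced by ~99.1%,
# matching the figure in the introduction
```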
Overall, TVM's TE is very convenient to use. One small problem I ran into is with unroll: when I wanted to unroll only a specified number of iterations, I could not find a way to do it; I could only unroll an entire axis.
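One thing that may help here, though I have not verified it on this schedule: TVM's unroll pass accepts pragma hints, so instead of unrolling a whole axis you can bound how much gets unrolled:

```python
# Untested sketch: rather than s[CC].unroll(ki), which unrolls the whole
# ki axis, hint the UnrollLoop pass on an outer axis:
s[CC].pragma(ko, "auto_unroll_max_step", 16)  # unroll loops with <= 16 iterations
s[CC].pragma(ko, "unroll_explicit", True)     # emit explicitly unrolled bodies
```

Combined with the split factors, this should give finer control over how many iterations actually end up unrolled.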