赞
踩
Taichi是一款高性能空间稀疏数据结构的计算引擎。其涉及到的计算密集型任务全部由C++写成,而前端则选择了易于上手且灵活性强的Python。乍一看重点应该是C++,然而一个好的前端设计同样很重要,因为它是用户认识Taichi的第一关。这里的前端不单指Python本身,也是Taichi在Python的基础上开发出的自己的一套使用规则。
故事的起源是从这个Issue开始的:https://github.com/taichi-dev/taichi/issues/548
Unifyti.kernel
andti.classkernel
先来说一下这两个decorator分别在干什么。
一般来说,Taichi用户需要用@ti.kernel
来修饰一个用于计算的Python function。举个例子:
- import taichi as ti
-
- x = ti.var(ti.i32, shape=(42,))
-
- @ti.kernel
- def compute():
- for i in x:
- x[i] += 1
Taichi同时还支持OOP。但为此,Taichi需要两个decorator:@ti.data_oriented
和@ti.classkernel
,使用方法如下:
- import taichi as ti
-
- # 下文会单独讲解@ti.data_oriented
- @ti.data_oriented
- class X(object):
- def __init__(self):
- self.x = ti.var(ti.i32, shape=(42,))
-
- @ti.classkernel
- def compute(self):
- for i in self.x:
- self.x[i] += 1
可以看到,为了正确使用Taichi,用户需要记住@ti.classkernel
和@ti.kernel
各自的使用场景。这在一定程度上增加了用户的心智负担,因此这个前端的设计仍有改进的空间。
改进目标很明确:只留下@ti.kernel
就好。
如何做这个改进?思路也比较清晰,判断一下被修饰的函数是否是一个class method即可:
inspect.ismethod(func)
思路1
一般来说,一个Python decorator大概长这样:
- def decorator(func):
- @functools.wraps(func)
- def wrapped(*args, **kwargs):
- func(*args, **kwargs)
- return wrapped
一种正常人的思路是,把这个决定放到wrapped
的执行期。由于在wrapped
中我们有了args
,我们可以尝试查看args[0]
的一些元数据来确定func
是否属于某个class。
先把我最开始想到的方法写出来:
- # 判断func是否是个某个class的函数
- def is_func_inside_class(x, func):
- try:
- # __wrapped__是因为func已经被@functools.wraps修饰过
- return type(x).__dict__[func.__name__].__wrapped__ == func
- except:
- return False
-
- def decorator(func):
- @functools.wrap(func)
- def wrapped(*args, **kwargs):
- is_classkernel = False
- try:
- is_classkernel = is_func_inside_class(args[0], func)
- except:
- pass
- # ...
-
- return wrapped
为了理解is_func_inside_class
在干什么,我们需要理解Python class中的function究竟是如何被绑定到一个instance上的,即——self
是从哪里来的?
还是通过例子来解释,例子来自 https://stackoverflow.com/a/18342905/12003165
- >>> X.compute
- <function X.compute at 0x7fc2a0016d40>
-
- >>> x = X()
- >>> x.compute
- <bound method X.compute of <__main__.X object at 0x7fd5f0031fd0>>
可以看到,通过class X
本身和通过instance x
来获取compute
,返回的结果是不一样的。前者仍然是一个function
,而后者变成了bound method
。这里发生的事情涉及到了Python descriptor的概念。
长话短说,在执行x.compute
时候,Python内部发生了这么一个过程:
- try:
- #1
- return x.__dict__['compute']
- except KeyError:
- value = type(x).__dict__['compute']
- try:
- #2
- return value.__get__(x)
- except AttributeError:
- #3
- return value
X.compute
作为class X
中的function,并没有被存在instance x
的__dict__
中,因此#1
抛出KeyError
X.compute
被存在了class X
的__dict__
中,因此value
就指向X.compute
本身。同时function
定义了__get__
,因此我们从#2
返回。那么,把X.compute
绑定到x
上就是发生在function.__get__
了。这是其概念上的实现:
- class function(object):
- # Built-in function class
- # ...
- def __get__(self, instance):
- return Boundmethod(self, instance)
- # ...
-
- class BoundMethod(object):
- def __init__(self, func, instance):
- self.__func__ = func
- self.__self__ = instance
-
- def __call__(self, *args, **kwargs):
- return self.__func__(self.__self__, *args, **kwargs)
可以看见,BoundMethod
不过是同时存了X.compute
的function pointer(无状态的plain function)以及instance x
。被调用时,它会将x
绑定到X.compute
的第一个参数上。
最后,由于执行期间任意一步都可能会挂掉(args
是空的、__dict__
中找不到等),这个判别式被放到了一个try
block中。一旦挂了立刻返回False
。
到这里,我们似乎是顺利解决了这个问题?
Taichi另一个特性是反向自动微分。所有被@ti.kernel
修饰过的函数都自动带有一个grad
的callable,调用它将计算这个kernel的导数。
- @ti.kernel
- def compute():
- # ...
-
- compute()
- # 自动生成的导数kernel
- compute.grad()
这个grad
同样也是在被@ti.kernel
修饰期间加上的。
但是这就造成了一个问题。由于上面这个方案需要在wrapped
执行期间才能判定function是否属于class,而grad
针对class function或plain function是由完全不同的两种方案实现的。这就导致了一个限制:想要使用compute.grad()
,用户必须至少运行一次compute
本身,使得wrapped
得以执行。
这个人为限制是由于实现方法本身并非最优导致的,有没有更给力的方法呢?
思路2
前面说完了正常人的思路。而下面这个思路,我第一次是跪着看完的。详情见 https://stackoverflow.com/questions/8793233/python-can-a-decorator-determine-if-a-function-is-being-defined-inside-a-class
PO主的问题和我们的如出一辙:python decorator可否判断所修饰的function是否在一个class中?
高赞回答的思路是...检查定义class时的stackframe!
具体来说,对于下面这个例子:
- def decor(func):
- import inspect
- frames = inspect.stack()
- # ...
-
-
- class X(object):
- @decor
- def compute(self):
- # ...
@decor
本身作用到compute
是在Python解释class X
的定义期间执行的。也就是说,当执行到frames = inspect.stack()
这一步时,我们还在定义class X
的过程中。此时的frames
大概长这样
- FrameInfo(..., code_context=[' frames = inspect.stack()n'], index=0)
- |- FrameInfo(..., code_context=[' frames = @decn'], index=1)
- |- FrameInfo(..., code_context=[' frames = class X(object):n'], index=2)
可以看到,在这种情况下,index=2
的stackframe正是class X
本身,因此我们通过检查frames[2].code_context[0].startswith('class ')
就可以完成这个判断[注1]。我们不再需要把这个判断推迟到wrapped
执行时进行。
grad
是如何被添加的
讲到这里,我们最开始想要解决的问题已经结束了。然而Taichi本身实现的grad
也非常巧妙,值得说道一番。
对于plain function kernel,grad
并没有什么特殊的,向返回的wrapped
对象上添加grad
kernel即可。
- def kernel(func):
- is_classkernel = check_inside_class_by_stackframe()
- primal = Kernel(func, is_grad=False, ...)
- adjoint = Kernel(func, is_grad=True, ...)
-
- if is_classkernel:
- @functools.wraps(func)
- def wrapped(*args, **kwargs):
- # TODO: 如何实现???
- else:
- @functools.wraps(func)
- def wrapped(*args, **kwargs):
- primal(*args, **kwargs)
- wrapped.grad = adjoint
- # ...
- return wrapped
但对于OOP kernel,这个问题变得有意思了很多。
先来设想一下我们如何调用OOP kernel grad,非常简单:
x.compute.grad()
然而这里有个问题,我们需要把grad
绑定到x
上,而x
和grad
之间隔着compute
。
如果我们把.
的作用域划分的更清楚一些,如下图所示:
- (x.compute).grad()
- |---------| |
- |---------------|
可以看到,如果令x.compute
返回某个包含了x
的proxy object(类比前面提到的BoundMethod
),那么这个proxy在调用grad()
时候可以自动把x
作为第一个参数传给grad()
。
Taichi实现OOP grad的原理正是如此。在accessx
的某个attribute时,如果能利用某种方法截获这个attribute,并且做一些检查,判断这个attribute是不是一个kernel。如果是,我们就把它变成一个proxy。否则的话我们退化到Python本身对attribute的搜索规则。
到这里,答案已经呼之欲出了。Python的__getattribute__
恰好可以满足我们的需求。进一步的,想要实现这个方案,我们需要@ti.kernel
和@ti.data_oriented
这两个decorator配合工作。前者会在返回的object上添加几个私有的标记,而后者则override了所修饰的class本身的__getattribute__
,来读取这些标记。
- def kernel(func):
- is_classkernel = check_inside_class_by_stackframe()
- primal = ...
- adjoint = ...
- # ...
- wrapped._is_wrapped_kernel = True
- wrapped._classkernel = is_classkernel
- wrapped._primal = primal
- wrapped._adjoint = adjoint
- return wrapped
-
- def data_oriented(cls):
- def getattr(self, item):
- x = super(cls, self).__getattribute__(item)
- if hasattr(x, '_is_wrapped_kernel'):
- wrapped = x.__func__
- if wrapped._classkernel:
- return BoundedDifferentiableMethod(self, wrapped)
- return x
-
- cls.__getattribute__ = getattr
- return cls
之前提到的proxy就是这个BoundedDifferentiableMethod
。其原理也和将无状态的function
变为bound method
的方法类似,以下为其实现:
- class BoundedDifferentiableMethod:
- def __init__(self, kernel_owner, wrapped_kernel_func):
- self._kernel_owner = kernel_owner
- self._primal = wrapped_kernel_func._primal
- self._adjoint = wrapped_kernel_func._adjoint
-
- def __call__(self, *args, **kwargs):
- return self._primal(self._kernel_owner, *args, **kwargs)
-
- def grad(self, *args, **kwargs):
- return self._adjoint(self._kernel_owner, *args, **kwargs)
我们还剩下最后一个小细节。之前在实现kernel
这个decorator时,我们并没有给出在is_classkernel == True
的情况下wrapped
的实现。
Take a guess of its implementation first :)
事实上,它的实现毫无影响。因为在这个情况下,使用这个kernel时会被BoundedDifferentiableMethod
接管,因此wrapped
的实现并不会调用。
为了确保这个invariant,Taichi在这里只不过是抛出异常而已:
- def kernel(func):
- is_classkernel = check_inside_class_by_stackframe()
- # ...
- if is_classkernel:
- @functools.wraps(func)
- def wrapped(*args, **kwargs):
- raise KernelDefError(...)
- # ...
- return wrapped
备注
code_context
是有一些区别的,需要分别处理。以我个人的品味来看,这并非一个很优雅的解决方案。但是写软件本身就是妥协的过程:没有其他方案的情况下,能用的就是最好的。Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。