赞
踩
MXNet中,gluon.Block
类和gluon.HybridBlock
类,和Pytorch中的nn.Module
类一样,我们通过继承Block类和HybridBlock类可以很灵活的搭建我们自己的网络模型,这里总结一下HybridBlock
类使用过程中的一些注意点。
HybridBlock
类继承至Block
类,所以HybridBlock
类有Block
类的全部方法和属性。HybridBlock
同时支持符号式编程和命令式编程,HybridBlock
类可以调用hybridize()方法,从而可以从命令式变为符号式,从而将动态图转化为静态图,提高模型的计算性能和移植性。下面是两者的比较:
HybridBlock类 | Block类 | |
---|---|---|
重写方法 | __init__() 、hybrid_forward(self, F, x, *args, **kwargs) |
__init__() 、forwad(self,x,*args) |
是否支持符号式 | 是 | 否 |
支持输入参数 | 位置式参数、关键字参数 | 只支持位置式参数 |
是否支持导出符号模型 | 是 | 否 |
可以看出HybridBlock
类除了多支持符号式编程外,和Block
基本没什么区别,但是注意到支持输入参数那一栏,hybrid_forward
函数还支持输入关键字参数,这点也和Block
不一样,下面详细分析一下hybrid_forward
的调用过程。
当我们构建一个HybridBlock
类后,需要重写其|__init__()
、hybrid_forward()
方法,而我们在源码中可以看到,当一个HybridBlock
类进行forward
操作时,其流程如下:
__call__()
-------->forward()
-------->hybrid_forward()
可以看出HybridBlock
类是通过forward()
方法中来调用hybrid_forward()
。由于HybridBlock
类中的forward()
方法已经被重写过了,所以我们只需要重写hybrid_forward()
就可以了,其中forward()
函数如下:
def forward(self, x, *args):
"""Defines the forward computation. Arguments can be either
:py:class:`NDArray` or :py:class:`Symbol`."""
if isinstance(x, NDArray):
with x.context as ctx:
if self._active:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。