当前位置:   article > 正文

【Spikingjelly】SNN框架教程的代码解读_5_ann转化为snn

ann转化为snn

时间驱动

ANN转换SNN

为何要进行转换的方法实现SNN?
转化 SNN (ANN-converted SNN) 是为了在已发展出的深度学习成果上,与硬件结合从而进一步利用事件驱动特性的低能耗优势,从 ANN 的视角出发的一种 SNN 实现方法。其作为间接监督性学习算法,基本理念是在使用 ReLU 函数的 ANN 网络中, 用 SNN 中频率编码下的平均脉冲发放率来近似 ANN 中的连续激活值。

转换方法实现SNN的基本步骤?
先对原始的 ANN 进行训练得到结果后, 再设计相同的拓扑结构将其转换为SNN. 这样做,转换 SNN 的训练实际上依赖的仍是在 ANN 的学习算法,即反向传播, 但是因为没有直接训练 SNN 时的一些困难. 所以就性能表现而言, 转换 SNN保持着与ANN很小的差距, 这一点在大的网络结构和数据集上的良好表现得到了印证。

a. 理论基础

ANN中的ReLU神经元非线性激活和SNN中IF神经元(采用减去阈值方式重置)的发放率有着极强的相关性,我们可以借助这个特性来进行转换。这里说的神经元更新方式,也就是Soft减去阈值的方式。

对IF神经元脉冲发放频率和输入的关系进行实验:我们给与恒定输入到IF神经元,观察其输出脉冲和脉冲发放频率。首先导入相关的模块,新建IF神经元层,确定输入并画出每个IF神经元的输入 x i x_i xi

import torch
from spikingjelly.clock_driven import neuron
from spikingjelly import visualizing
from matplotlib import pyplot as plt
import numpy as np

plt.rcParams['figure.dpi'] = 200
if_node = neuron.IFNode(v_reset=None)
T = 128
x = torch.arange(-0.2, 1.2, 0.04)
plt.scatter(torch.arange(x.shape[0]), x)
plt.title('Input $x_{i}$ to IF neurons')
plt.xlabel('Neuron index $i$')
plt.ylabel('Input $x_{i}$')
plt.grid(linestyle='-.')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

在这里插入图片描述
其中IF神经元的动态微分方程
d V ( t ) d t = R m I ( t ) \frac{\mathrm{d}V(t)}{\mathrm{d} t} = R_{m}I(t) dtdV(t)=RmI(t)
相应的差分方程:
V ( t ) − V ( t − 1 ) = X ( t ) V(t) - V(t-1) = X(t) V(t)V(t1)=X(t)
类实现如下:

class IFNode(BaseNode):
    def __init__(self, v_threshold=1.0, v_reset=0.0, surrogate_function=surrogate.Sigmoid(), detach_reset=False, monitor_state=False):
    '''
        Integrate-and-Fire 神经元模型,可以看作理想积分器,无输入时电压保持恒定,不会像LIF神经元那样衰减。
    '''
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, monitor_state)

    def neuronal_charge(self, dv: torch.Tensor):
        self.v += dv #这里的dv就是上一层的输出,公式中的X(t)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

接下来,将输入送入到IF神经元层,并运行T=128步,观察各个神经元发放的脉冲、脉冲发放频率:

s_list = []
for t in range(T):
    s_list.append(if_node(x).unsqueeze(0))

out_spikes = np.asarray(torch.cat(s_list))
visualizing.plot_1d_spikes(out_spikes, 'IF neurons\' spikes and firing rates', 't', 'Neuron index $i$')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

在这里插入图片描述
可以发现,脉冲发放的频率在一定范围内,与输入 x i x_i xi的大小成正比。

画出IF神经元脉冲发放频率和输入 x i x_i xi的曲线,并与RELU( x i x_i xi)对比:

    plt.subplot(1, 2, 1)
    firing_rate = np.mean(out_spikes, axis=0)
    plt.plot(x, firing_rate)
    plt.title('Input $x_{i}$ and firing rate')
    plt.xlabel('Input $x_{i}$')
    plt.ylabel('Firing rate')
    plt.grid(linestyle='-.')

    plt.subplot(1, 2, 2)
    plt.plot(x, x.relu())
    plt.title('Input $x_{i}$ and ReLU($x_{i}$)')
    plt.xlabel('Input $x_{i}$')
    plt.ylabel('ReLU($x_{i}$)')
    plt.grid(linestyle='-.')
    plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

在这里插入图片描述
可以发现,两者的曲线几乎一致。需要注意的是,脉冲频率不可能高于1,因此IF神经元无法拟合ANN中ReLU的输入大于1的情况。
用SNN频率编码下的平均脉冲发放率来近似ANN中的连续激活值,这是转换SNN最重要的思想。详细的数学证明可以参考教程中提到的论文。

b. 转换和仿真

由于主要目的是笔者记录以便查看,所以ANN-to-SNN转换的具体方法不进行展开。在教程中提到的论文均有提及,感兴趣可以阅读,下面主要介绍转换代码。

ann-to-snn目前实现了两套实现:基于ONNX 和 基于PyTorch, 在框架中被称为 ONNX kernel 和 PyTorch kernel。 我们下面介绍PyTorch(因为看不懂ONNX)

转换需要先训练一个ann,此处按传统的方法写即可,不予介绍。
我们从ann = torch.load(os.path.join(log_dir, model_name + '.pkl')),获得训练好的ann开始。
调用parser方法,直接获得转换后的SNN.

    onnxparser = parser(name=model_name,
                        log_dir=log_dir + '/parser',
                        kernel='pytorch') # 调用parser,使用kernel为pytorch
    snn = onnxparser.parse(ann, norm_data.to(parser_device)) #获得转换的SNN
  • 1
  • 2
  • 3
  • 4

重点看一下parse方法,定义如下:

    def parse(self, model: nn.Module, data: torch.Tensor, **kargs) -> nn.Module:
        model_name = model.__class__.__name__
        model.eval()

        for m in model.modules():
            if hasattr(m,'weight'):
                assert(data.get_device() == m.weight.get_device())

        try:
            model = z_norm_integration(model=model, z_norm=self.config['z_norm'])
        except KeyError:
            pass
        layer_reduc = False
        for m in model.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d, nn.BatchNorm3d)):
                layer_reduc = True #有BN层就需要进行参数融合,这里叫层reduction
                break
        if self.kernel.lower() == 'onnx':
            try:
                import onnx
                import onnxruntime as ort
            except ImportError:
                print(Warning("Package onnx or onnxruntime not found: launch pytorch convert engine,"
                              " only support very simple arctitecture"))
                self.kernel = 'pytorch'
            else:
                pass

        if self.kernel.lower() == 'onnx':
            # use onnx engine

            data = data.cpu()
            model = model.cpu()

            import spikingjelly.clock_driven.ann2snn.kernels.onnx as onnx_kernel

            onnx_model = onnx_kernel.pytorch2onnx_model(model=model, data=data, log_dir=self.config['log_dir'])
            # onnx_kernel.print_onnx_model(onnx_model.graph)
            onnx.checker.check_model(onnx_model)
            if layer_reduc:
                onnx_model = onnx_kernel.layer_reduction(onnx_model)
            onnx.checker.check_model(onnx_model)
            onnx_model = onnx_kernel.rate_normalization(onnx_model, data.numpy(), **kargs) #**self.config['normalization']
            onnx_kernel.save_model(onnx_model,os.path.join(self.config['log_dir'],model_name+".onnx"))

            convert_methods = onnx2pytorch
            try:
                user_defined = kargs['user_methods']
                assert (user_defined is dict)
                for k in user_defined:
                    convert_methods.add_method(op_name=k, func=user_defined[k])
            except KeyError:
                print('no user-defined conversion method found, use default')
            except AssertionError:
                print('user-defined conversion method should be organized into a dict!')
            model = onnx_kernel.onnx2pytorch_model(onnx_model, convert_methods)
        else: #重点看这几行
            # use pytorch engine

            import spikingjelly.clock_driven.ann2snn.kernel.pytorch as pytorch_kernel

            if layer_reduc:
                model = pytorch_kernel.layer_reduction(model)
            model = pytorch_kernel.rate_normalization(model, data)#, **self.config['normalization']

        self.ann_filename = os.path.join(self.config['log_dir'], model_name + ".pth")
        torch.save(model, self.ann_filename)
        model = self.to_snn(model)
        return model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69

我们的self.kernel.lower() == 'pytorch',所以关注else后的代码。

model = pytorch_kernel.layer_reduction(model)
model = pytorch_kernel.rate_normalization(model, data)
这两行代码做两件事,分别是BN层(BatchNorm)的融合、最大值归一化。
首先是layer_reduction:

def layer_reduction(model: nn.Module) -> nn.Module:
    relu_linker = {}  # 字典类型,用于通过relu层在network中的序号确定relu前参数化模块的序号
    param_module_relu_linker = {}  # 字典类型,用于通过relu前在network中的参数化模块的序号确定relu层序号
    activation_range = defaultdict(float)  # 字典类型,保存在network中的序号对应层的激活最大值(或某分位点值)

    module_len = 0
    module_list = nn.ModuleList([])
    last_parammodule_idx = 0
    for n, m in model.named_modules():
        Name = m.__class__.__name__
        # 加载激活层
        if isinstance(m,nn.Softmax):
            Name = 'ReLU'
            print(UserWarning("Replacing Softmax by ReLU."))
        if isinstance(m,nn.ReLU) or Name == "ReLU":
            module_list.append(m)
            relu_linker[module_len] = last_parammodule_idx
            param_module_relu_linker[last_parammodule_idx] = module_len
            module_len += 1
            activation_range[module_len] = -1e5
        # 加载BatchNorm层
        if isinstance(m,(nn.BatchNorm2d,nn.BatchNorm1d)):
            if isinstance(module_list[last_parammodule_idx], (nn.Conv2d,nn.Linear)): #这一层是BN,上一层是Conv2d,Linear,进行absorb
                absorb(module_list[last_parammodule_idx], m)
            else:
                module_list.append(copy.deepcopy(m))
        # 加载有参数的层
        if isinstance(m,(nn.Conv2d,nn.Linear)):
            module_list.append(m)
            last_parammodule_idx = module_len
            module_len += 1
        # 加载无参数层
        if isinstance(m,nn.MaxPool2d):
            module_list.append(m)
            module_len += 1
        if isinstance(m,nn.AvgPool2d):
            module_list.append(nn.AvgPool2d(kernel_size=m.kernel_size, stride=m.stride, padding=m.padding))
            module_len += 1
        # if isinstance(m,nn.Flatten):
        if m.__class__.__name__ == "Flatten":
            module_list.append(m)
            module_len += 1
    network = torch.nn.Sequential(*module_list)
    setattr(network,'param_module_relu_linker',param_module_relu_linker)
    setattr(network, 'activation_range', activation_range)
    return network
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46

截取教程原话,absorb按照以下公式对BN参数进行吸收
在这里插入图片描述

def absorb(param_module, bn_module):
    if_2d = len(param_module.weight.size()) == 4  # 判断是否为BatchNorm2d
    bn_std = torch.sqrt(bn_module.running_var.data + bn_module.eps)
    if not if_2d:
        if param_module.bias is not None:
            param_module.weight.data = param_module.weight.data * bn_module.weight.data.view(-1, 1) / bn_std.view(
                -1,
                1)
            param_module.bias.data = (param_module.bias.data - bn_module.running_mean.data.view(
                -1)) * bn_module.weight.data.view(-1) / bn_std.view(
                -1) + bn_module.bias.data.view(-1)
        else:
            param_module.weight.data = param_module.weight.data * bn_module.weight.data.view(-1, 1) / bn_std.view(
                -1,
                1)
            param_module.bias.data = (torch.zeros_like(
                bn_module.running_mean.data.view(-1)) - bn_module.running_mean.data.view(
                -1)) * bn_module.weight.data.view(-1) / bn_std.view(-1) + bn_module.bias.data.view(-1)
    else: #看这里
        if param_module.bias is not None: #前层有偏置,按照公式来
            param_module.weight.data = param_module.weight.data * bn_module.weight.data.view(-1, 1, 1,
                                                                                             1) / bn_std.view(-1, 1,
                                                                                                              1, 1)
            param_module.bias.data = (param_module.bias.data - bn_module.running_mean.data.view(
                -1)) * bn_module.weight.data.view(-1) / bn_std.view(
                -1) + bn_module.bias.data.view(-1)
        else:
            param_module.weight.data = param_module.weight.data * bn_module.weight.data.view(-1, 1, 1,
                                                                                             1) / bn_std.view(-1, 1,
                                                                                                              1, 1)
            param_module.bias.data = (torch.zeros_like(
                bn_module.running_mean.data.view(-1)) - bn_module.running_mean.data.view(
                -1)) * bn_module.weight.data.view(-1) / bn_std.view(-1) + bn_module.bias.data.view(-1)
    return param_module
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

然后是rate_normalization:
这个是最大归一化方法,在2015年Diehl提出,用于解决转换SNN中出现的激活值过小导致的脉冲发放率过低,从而导致精度的降低。2017年Rueckauer等人加入了0.99分位点的方法,采用99.9%的最大值进行归一化,进一步改善了脉冲发放率不足的问题。截取教程原话
在这里插入图片描述
函数实现如下:

def rate_normalization(model: nn.Module, data: torch.Tensor, **kargs) -> nn.Module:
    if not hasattr(model,"activation_range") or not hasattr(model,"param_module_relu_linker"):
        raise(AttributeError("run layer_reduction first!"))
    try:
        robust_norm = kargs['robust']
    except KeyError:
        robust_norm = False
    x = data
    i = 0
    with torch.no_grad():
        for n, m in model.named_modules():
            Name = m.__class__.__name__
            if Name in ['Conv2d', 'ReLU', 'MaxPool2d', 'AvgPool2d', 'Flatten', 'Linear']:
                x = m.forward(x)
                a = x.cpu().numpy().reshape(-1)
                if robust_norm:
                    model.activation_range[i] = np.percentile(a[np.nonzero(a)], 99.9)
                else:
                    model.activation_range[i] = np.max(a)
                i += 1
    i = 0
    last_lambda = 1.0
    for n, m in model.named_modules():
        Name = m.__class__.__name__
        if Name in ['Conv2d', 'ReLU', 'MaxPool2d', 'AvgPool2d', 'Flatten', 'Linear']:
            if Name in ['Conv2d', 'Linear']:
                relu_output_layer = model.param_module_relu_linker[i]
                if hasattr(m, 'weight') and m.weight is not None:
                    m.weight.data = m.weight.data * last_lambda / model.activation_range[relu_output_layer]
                if hasattr(m, 'bias') and m.bias is not None:
                    m.bias.data = m.bias.data / model.activation_range[relu_output_layer]
                last_lambda = model.activation_range[relu_output_layer]
            i += 1
    return model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

经过参数融合和归一化之后,我们就获得了与ANN有相同的拓扑结构的SNN,但还需转换ANN的其他一些操作到SNN。
这里主要是RELU用IF神经元代替、MaxPooling用AvgPooling代替,实现如下:

    def to_snn(self, model: nn.Module, **kargs) -> nn.Module:
        for name, module in model._modules.items():
            if hasattr(module, "_modules"):
                model._modules[name] = self.to_snn(module, **kargs)
            if module.__class__.__name__ == "AvgPool2d":
                new_module = nn.Sequential(module, neuron.IFNode(v_reset=None))
                model._modules[name] = new_module
            if "BatchNorm" in module.__class__.__name__:
                try:
                    new_module = nn.Sequential(module, neuron.NSIFNode(v_threshold=(-1.0, 1.0), v_reset=None))
                except AttributeError:
                    new_module = module
                model._modules[name] = new_module
            if module.__class__.__name__ == "ReLU":
                new_module = neuron.IFNode(v_reset=None)
                model._modules[name] = new_module
            try:
                if module.__class__.__name__ == 'PReLU':
                    p = module.weight
                    assert (p.size(0) == 1 and p != 0)
                    if -1 / p.item() > 0:
                        model._modules[name] = neuron.NSIFNode(v_threshold=(1.0 / p.item(), 1.0),
                                                                     bipolar=(1.0, 1.0), v_reset=None)
                    else:
                        model._modules[name] = neuron.NSIFNode(v_threshold=(-1 / p.item(), 1.0),
                                                                     bipolar=(-1.0, 1.0), v_reset=None)
            except AttributeError:
                assert False, 'NSIFNode has been removed.'
            if module.__class__.__name__ == "MaxPool2d":
                new_module = nn.AvgPool2d(
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding)
                model._modules[name] = new_module
        return model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

c. snn模型评估

之前训练的ann可以达到98.5%的准确率。下面是构建snn的仿真器

    # 定义用于分类的SNN仿真器
    # define simulator for classification task
    sim = classify_simulator(snn,
                             log_dir=log_dir + '/simulator',
                             device=simulator_device,
                             canvas=fig
                             )
    # 仿真SNN
    # Simulate SNN
    sim.simulate(test_data_loader,
                T=T,
                online_drawer=True,
                ann_acc=ann_acc,
                fig_name=model_name,
                step_max=True
                )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

第一个Batch(100)上分类测试结果:
在这里插入图片描述
为啥转换后的精度比原始的大,不太清楚
可以看到提高仿真时间步长,有利于提高精度。

--------------------simulator summary--------------------
time elapsed: 96.4521272 (sec)
---------------------------------------------------------
  • 1
  • 2
  • 3

d. 结果、分析

转换SNN是追求高性能SNN的一种实现方式,但是之前也总结过诸多不足:

整体而言,转换的 SNN 存在一些局限性:显而易见的是在转换的过程中ANN的一些条件限制:例如激活函数的选择和偏置的置零,另外在深度的神经网络,脉冲神经网络若要使用平均脉冲发放率代替模拟的激活值,相比与 ANN 的前向推理,SNN通常要选取大的时间步长,进行上百步的时间模拟,这增加了额外的延时,反而与 SNN功耗低的目标不吻合。同时转换的 SNN 更多关注的是转换上的一些操作,训练算法依赖的仍然是 ANN 的反向传播,就训练方式来讲,还不够有很强的生物解释性。

感觉这篇写的不是清楚,深入了解需多看原教程和提到的论文

参考

原文教程:ANN转换SNN.

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/352414
推荐阅读
相关标签
  

闽ICP备14008679号