当前位置:   article > 正文

[NCNN学习笔记]-1

[NCNN学习笔记]-1

1、前言

本次继续学习NCNN,希望能够坚持,往期学习NCNN的链接如下。

[NCNN学习笔记]-0

2、学习内容

2.1、batchnorm_arm.cpp

这个章节学习NCNN中batchnorm在NEON上的实现。batchnorm的学习可参考链接:https://zhuanlan.zhihu.com/p/93643523

在NCNN中batchnorm的预处理为ncnn-master\src\layer\batchnorm.cpp

int BatchNorm::load_model(const ModelBin& mb)
{
    slope_data = mb.load(channels, 1);
    if (slope_data.empty())
        return -100;

    mean_data = mb.load(channels, 1);
    if (mean_data.empty())
        return -100;

    var_data = mb.load(channels, 1);
    if (var_data.empty())
        return -100;

    bias_data = mb.load(channels, 1);
    if (bias_data.empty())
        return -100;

    a_data.create(channels);
    if (a_data.empty())
        return -100;
    b_data.create(channels);
    if (b_data.empty())
        return -100;
    
    // 通过https://zhuanlan.zhihu.com/p/93643523中的公式,可以很推算出下面的过程
    for (int i = 0; i < channels; i++)
    {
        float sqrt_var = sqrtf(var_data[i] + eps);
        if (sqrt_var == 0.f)
            sqrt_var = 0.0001f; // sanitize divide by zero
        a_data[i] = bias_data[i] - slope_data[i] * mean_data[i] / sqrt_var;
        b_data[i] = slope_data[i] / sqrt_var;
    }
    return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36

通过上面的代码,明白了a_data和b_data的作用,下面就正式开始学习在数据处理过程中的batchnorm吧

int BatchNorm_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    int elembits = bottom_top_blob.elembits();
    int dims = bottom_top_blob.dims;         // 数据维度,最多为4
    int elempack = bottom_top_blob.elempack; // 每个数据可被分为多少组   例如float32x4_t可被分为4组

    if (elempack == 4)    // float32x4_t、int32x4_t、float16x4_t
    {
        // 对一行进行batchnorm
        if (dims == 1 )
        {
            int w = bottom_top_blob.w; 
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < w; i++)
            {
                float* ptr = (float*)bottom_top_blob + i * 4;   // 每间隔4个数据取一次地址

                float32x4_t _a = vld1q_f32((const float*)a_data + i * 4); 
                float32x4_t _b = vld1q_f32((const float*)b_data + i * 4);

                float32x4_t _p = vld1q_f32(ptr);
                _p = vmlaq_f32(_a, _p, _b);   // _a + _p.*b     y = a_data + b_data * x
                vst1q_f32(ptr, _p);
            }
        }
		 // 对每一行batchnorm
        if (dims == 2)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                float32x4_t _a = vld1q_f32((const float*)a_data + i * 4);
                float32x4_t _b = vld1q_f32((const float*)b_data + i * 4);
                float* ptr = bottom_top_blob.row(i);
                for (int j = 0; j < w; j++)  
                {
                    float32x4_t _p = vld1q_f32(ptr);
                    _p = vmlaq_f32(_a, _p, _b);
                    vst1q_f32(ptr, _p);

                    ptr += 4;
                }
            }
        }
		// 针对channel进行batchnorm
        if (dims == 3 || dims == 4)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;
            int d = bottom_top_blob.d;
            int c = bottom_top_blob.c;
            int size = w * h * d;

            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < c; q++)
            {
                float32x4_t _a = vld1q_f32((const float*)a_data + q * 4);
                float32x4_t _b = vld1q_f32((const float*)b_data + q * 4);

                float* ptr = bottom_top_blob.channel(q);

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p = vld1q_f32(ptr);
                    _p = vmlaq_f32(_a, _p, _b);
                    vst1q_f32(ptr, _p);

                    ptr += 4;
                }
            }
        }

        return 0;
    }
    return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79

2.2、bias_arm.cpp

这个章节学习ncnn中的的bias计算

int Bias_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int d = bottom_top_blob.d;
    int channels = bottom_top_blob.c;
    int size = w * h * d;
    const float* bias_ptr = bias_data;
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        float* ptr = bottom_top_blob.channel(q);  // 第q个channel的数据
        float bias = bias_ptr[q];
        int nn = size >> 2; // 处理4的整数倍的数据
        int remain = size - (nn << 2);     // 剩余处理向量
        float32x4_t _bias = vdupq_n_f32(bias);
        for (; nn > 0; nn--)
        {
            float32x4_t _p = vld1q_f32(ptr);
            float32x4_t _outp = vaddq_f32(_p, _bias);  // x + bias
            vst1q_f32(ptr, _outp);   
            ptr += 4;
        }
        // 剩余向量使用c语言计算
        for (; remain > 0; remain--)
        {
            *ptr = *ptr + bias;
            ptr++;
        }
    }
    return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
3、总结

本次学习了NCNN中的batchnorm和bias操作,后续准备学习NCNN时不局限于学习NCNN中的NEON实现,还会关注其他的内容!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/257868
推荐阅读
相关标签
  

闽ICP备14008679号