当前位置:   article > 正文

卷积神经网络(CNN)讲解及代码_cnn代码

cnn代码

相关文章:
1. 经典反向传播算法公式详细推导
2. 卷积神经网络(CNN)反向传播算法公式详细推导

网上有很多关于CNN的教程讲解,在这里我们抛开长篇大论,只针对代码来谈。本文用的是matlab编写的deeplearning toolbox,包括NN、CNN、DBN、SAE、CAE。在这里我们感谢作者编写了这样一个简单易懂,适用于新手学习的代码。由于本文直接针对代码,这就要求读者有一定的CNN基础,可以参考Lecun的Gradient-Based Learning Applied to Document Recognitiontornadomeet的博文
首先把Toolbox下载下来,解压缩到某位置。然后打开Matlab,把文件夹内的util和data利用Set Path添加至路径中。接着打开tests文件夹的test_example_CNN.m。最后在文件夹CNN中运行该代码。

下面是test_example_CNN.m中的代码及注释,比较简单。

load mnist_uint8;  %读取数据

% 把图像的灰度值变成0~1,因为本代码采用的是sigmoid激活函数
train_x = double(reshape(train_x',28,28,60000))/255;
test_x = double(reshape(test_x',28,28,10000))/255;
train_y = double(train_y');
test_y = double(test_y');

%% 卷积网络的结构为 6c-2s-12c-2s 
% 1 epoch 会运行大约200s, 错误率大约为11%。而 100 epochs 的错误率大约为1.2%。

rand('state',0) %指定状态使每次运行产生的随机结果相同

cnn.layers = {
    struct('type', 'i') % 输入层
    struct('type', 'c', 'outputmaps', 6, 'kernelsize', 5) % 卷积层
    struct('type', 's', 'scale', 2) % pooling层
    struct('type', 'c', 'outputmaps', 12, 'kernelsize', 5) % 卷积层
    struct('type', 's', 'scale', 2) % pooling层
};


opts.alpha = 1;  % 梯度下降的步长
opts.batchsize = 50; % 每次批处理50张图
opts.numepochs = 1; % 所有图片循环处理一次

cnn = cnnsetup(cnn, train_x, train_y); % 初始化CNN
cnn = cnntrain(cnn, train_x, train_y, opts); % 训练CNN

[er, bad] = cnntest(cnn, test_x, test_y); % 测试CNN

%plot mean squared error
figure; plot(cnn.rL);
assert(er<0.12, 'Too big error');
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

下面是cnnsetup.m中的代码及注释。

function net = cnnsetup(net, x, y)
    assert(~isOctave() || compare_versions(OCTAVE_VERSION, '3.8.0', '>='), ['Octave 3.8.0 or greater is required for CNNs as there is a bug in convolution in previous versions. See http://savannah.gnu.org/bugs/?39314. Your version is ' myOctaveVersion]);  %判断版本
    inputmaps = 1;  % 由于网络的输入为1张特征图,因此inputmaps为1
    mapsize = size(squeeze(x(:, :, 1)));  %squeeze():除去x中为1的维度,即得到28*28

    for l = 1 : numel(
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Guff_9hys/article/detail/776287
推荐阅读
  

闽ICP备14008679号