当前位置:   article > 正文

tiny-cnn开源库的使用(MNIST)_tinycnnmodel

tinycnnmodel

tiny-cnn是一个基于CNN的开源库,它的License是BSD 3-Clause。作者也一直在维护更新,对进一步掌握CNN很有帮助,因此下面介绍下tiny-cnn在windows7 64bit vs2013的编译及使用。

1.      从https://github.com/nyanp/tiny-cnn下载源码:

$ git clone https://github.com/nyanp/tiny-cnn.git  版本号为77d80a8,更新日期2016.01.22

2.      源文件中已经包含了vs2013工程,vc/tiny-cnn.sln,默认是win32的,examples/main.cpp需要OpenCV的支持,这里新建一个x64的控制台工程tiny-cnn;

3.      仿照源工程,将相应.h文件加入到新控制台工程中,新加一个test_tiny-cnn.cpp文件;

4.      将examples/mnist中test.cpp和train.cpp文件中的代码复制到test_tiny-cnn.cpp文件中;

  1. #include <iostream>
  2. #include <string>
  3. #include <vector>
  4. #include <algorithm>
  5. #include <tiny_cnn/tiny_cnn.h>
  6. #include <opencv2/opencv.hpp>
  7. using namespace tiny_cnn;
  8. using namespace tiny_cnn::activation;
  9. // rescale output to 0-100
  10. template <typename Activation>
  11. double rescale(double x)
  12. {
  13. Activation a;
  14. return 100.0 * (x - a.scale().first) / (a.scale().second - a.scale().first);
  15. }
  16. void construct_net(network<mse, adagrad>& nn);
  17. void train_lenet(std::string data_dir_path);
  18. // convert tiny_cnn::image to cv::Mat and resize
  19. cv::Mat image2mat(image<>& img);
  20. void convert_image(const std::string& imagefilename, double minv, double maxv, int w, int h, vec_t& data);
  21. void recognize(const std::string& dictionary, const std::string& filename, int target);
  22. int main()
  23. {
  24. //train
  25. std::string data_path = "D:/Download/MNIST";
  26. train_lenet(data_path);
  27. //test
  28. std::string model_path = "D:/Download/MNIST/LeNet-weights";
  29. std::string image_path = "D:/Download/MNIST/";
  30. int target[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
  31. for (int i = 0; i < 10; i++) {
  32. char ch[15];
  33. sprintf(ch, "%d", i);
  34. std::string str;
  35. str = std::string(ch);
  36. str += ".png";
  37. str = image_path + str;
  38. recognize(model_path, str, target[i]);
  39. }
  40. std::cout << "ok!" << std::endl;
  41. return 0;
  42. }
  43. void train_lenet(std::string data_dir_path) {
  44. // specify loss-function and learning strategy
  45. network<mse, adagrad> nn;
  46. construct_net(nn);
  47. std::cout << "load models..." << std::endl;
  48. // load MNIST dataset
  49. std::vector<label_t> train_labels, test_labels;
  50. std::vector<vec_t> train_images, test_images;
  51. parse_mnist_labels(data_dir_path + "/train-labels.idx1-ubyte",
  52. &train_labels);
  53. parse_mnist_images(data_dir_path + "/train-images.idx3-ubyte",
  54. &train_images, -1.0, 1.0, 2, 2);
  55. parse_mnist_labels(data_dir_path + "/t10k-labels.idx1-ubyte",
  56. &test_labels);
  57. parse_mnist_images(data_dir_path + "/t10k-images.idx3-ubyte",
  58. &test_images, -1.0, 1.0, 2, 2);
  59. std::cout << "start training" << std::endl;
  60. progress_display disp(train_images.size());
  61. timer t;
  62. int minibatch_size = 10;
  63. int num_epochs = 30;
  64. nn.optimizer().alpha *= std::sqrt(minibatch_size);
  65. // create callback
  66. auto on_enumerate_epoch = [&](){
  67. std::cout << t.elapsed() << "s elapsed." << std::endl;
  68. tiny_cnn::result res = nn.test(test_images, test_labels);
  69. std::cout << res.num_success << "/" << res.num_total << std::endl;
  70. disp.restart(train_images.size());
  71. t.restart();
  72. };
  73. auto on_enumerate_minibatch = [&](){
  74. disp += minibatch_size;
  75. };
  76. // training
  77. nn.train(train_images, train_labels, minibatch_size, num_epochs,
  78. on_enumerate_minibatch, on_enumerate_epoch);
  79. std::cout << "end training." << std::endl;
  80. // test and show results
  81. nn.test(test_images, test_labels).print_detail(std::cout);
  82. // save networks
  83. std::ofstream ofs("D:/Download/MNIST/LeNet-weights");
  84. ofs << nn;
  85. }
// Build the LeNet-5 architecture [Y.LeCun, 1998] into nn:
// 32x32 grayscale input -> C1 -> S2 -> C3 (sparse wiring) -> S4 -> C5 -> F6
// producing a 10-way digit score. All layers use tanh activation.
void construct_net(network<mse, adagrad>& nn) {
    // connection table [Y.Lecun, 1998 Table.1]: rows = the 6 S2 feature
    // maps, columns = the 16 C3 maps; 'O' marks a connection. The sparse
    // wiring breaks symmetry and reduces the parameter count.
#define O true
#define X false
    static const bool tbl[] = {
        O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,
        O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,
        O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,
        X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,
        X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,
        X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O
    };
#undef O
#undef X
    // construct nets (5x5 kernels for every convolutional layer)
    nn << convolutional_layer<tan_h>(32, 32, 5, 1, 6)   // C1, 1@32x32-in, 6@28x28-out
       << average_pooling_layer<tan_h>(28, 28, 6, 2)    // S2, 6@28x28-in, 6@14x14-out
       << convolutional_layer<tan_h>(14, 14, 5, 6, 16,
              connection_table(tbl, 6, 16))             // C3, 6@14x14-in, 16@10x10-out
       << average_pooling_layer<tan_h>(10, 10, 16, 2)   // S4, 16@10x10-in, 16@5x5-out
       << convolutional_layer<tan_h>(5, 5, 5, 16, 120)  // C5, 16@5x5-in, 120@1x1-out
       << fully_connected_layer<tan_h>(120, 10);        // F6, 120-in, 10-out
}
  109. void recognize(const std::string& dictionary, const std::string& filename, int target) {
  110. network<mse, adagrad> nn;
  111. construct_net(nn);
  112. // load nets
  113. std::ifstream ifs(dictionary.c_str());
  114. ifs >> nn;
  115. // convert imagefile to vec_t
  116. vec_t data;
  117. convert_image(filename, -1.0, 1.0, 32, 32, data);
  118. // recognize
  119. auto res = nn.predict(data);
  120. std::vector<std::pair<double, int> > scores;
  121. // sort & print top-3
  122. for (int i = 0; i < 10; i++)
  123. scores.emplace_back(rescale<tan_h>(res[i]), i);
  124. std::sort(scores.begin(), scores.end(), std::greater<std::pair<double, int>>());
  125. for (int i = 0; i < 3; i++)
  126. std::cout << scores[i].second << "," << scores[i].first << std::endl;
  127. std::cout << "the actual digit is: " << scores[0].second << ", correct digit is: "<<target<<std::endl;
  128. // visualize outputs of each layer
  129. //for (size_t i = 0; i < nn.depth(); i++) {
  130. // auto out_img = nn[i]->output_to_image();
  131. // cv::imshow("layer:" + std::to_string(i), image2mat(out_img));
  132. //}
  133. visualize filter shape of first convolutional layer
  134. //auto weight = nn.at<convolutional_layer<tan_h>>(0).weight_to_image();
  135. //cv::imshow("weights:", image2mat(weight));
  136. //cv::waitKey(0);
  137. }
  138. // convert tiny_cnn::image to cv::Mat and resize
  139. cv::Mat image2mat(image<>& img) {
  140. cv::Mat ori(img.height(), img.width(), CV_8U, &img.at(0, 0));
  141. cv::Mat resized;
  142. cv::resize(ori, resized, cv::Size(), 3, 3, cv::INTER_AREA);
  143. return resized;
  144. }
  145. void convert_image(const std::string& imagefilename,
  146. double minv,
  147. double maxv,
  148. int w,
  149. int h,
  150. vec_t& data) {
  151. auto img = cv::imread(imagefilename, cv::IMREAD_GRAYSCALE);
  152. if (img.data == nullptr) return; // cannot open, or it's not an image
  153. cv::Mat_<uint8_t> resized;
  154. cv::resize(img, resized, cv::Size(w, h));
  155. // mnist dataset is "white on black", so negate required
  156. std::transform(resized.begin(), resized.end(), std::back_inserter(data),
  157. [=](uint8_t c) { return (255 - c) * (maxv - minv) / 255.0 + minv; });
  158. }

5.      编译时会提示几个错误,解决方法是:

(1)、error C4996,解决方法:将宏_SCL_SECURE_NO_WARNINGS添加到属性的预处理器定义中;

(2)、调用for_函数时,error C2668,对重载函数的调用不明确,解决方法:将for_中的第三个参数强制转化为size_t类型;

6.      运行程序,train时,运行结果如下图所示:


7.      对生成的model进行测试,通过画图工具,每个数字生成一张图像,共10幅,如下图:


通过导入train时生成的model,对这10张图像进行识别,识别结果如下图,其中6和9被误识为5和1:


GitHub:https://github.com/fengbingchun/NN

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/805853
推荐阅读
  

闽ICP备14008679号