当前位置:   article > 正文

AlexNet结构 及 pytorch、tensorflow、keras、paddle实现ImageNet识别_keras读取imagenet文件中的图片

keras读取imagenet文件中的图片

背景

AlexNet网络是 Hinton及其学生Alex Krizhevsky在ImageNet ILSVRC-2012竞赛中在 ILSVRC-2010数据上的的冠军网络,论文 "ImageNet Classifification with Deep Convolutional Neural Networks",该网络在大规模对象识别上取得的成功也掀起了深度学习的热潮。

AlexNet亮点:使用ReLU作为激活函数,提出LRN机制,Dropout随机失活,重叠Pooling,CUDA并行训练及数据增强

 

网络结构

原文中使用了3维卷积核,卷积核维度增加

原作者使用双GPU进行并行参数训练,将特征图从通道数层面进行分离,故以下通道数是上图的并行通道数之和

 

input layer:224*224*3 images

conv1 layer:11*11*3(3维卷积核)*  96, 4 conv kernels       55*55*96  ouput ( (224 + 0 - 11) / 4 +1 = 54.25 )

pool1 layer:3*3, 2 overlap maxpool                                   27*27*96  output ( (55 - 3) / 2 + 1 = 27 )

conv2 layer:5*5*48*  256 conv kernels(2pad)                   27*27*256 output ( 27 + 2*2 - 5  + 1 = 27 )  

pool2 layer:3*3, 2 overlap maxpool                                   13*13*256 output ( (27 - 3) / 2 + 1 =13 ) 

conv3 layer:3*3*256*  384 conv kernels(1pad)                 13*13*384 output ( 13 + 1*2 - 3  + 1 = 13 )  

conv4 layer:3*3*192*  384 conv kernels(1pad)                 13*13*384 output ( 13 + 1*2 - 3  + 1 = 13 )  

conv5 layer:3*3*192*  256 conv kernels(1pad)                 13*13*256 output ( 13 + 1*2 - 3  + 1 = 13 )  

pool3 layer:3*3, 2 overlap maxpool                                   6*6*256 output ( (13 - 3) / 2 + 1 =6 )  

fc1 layer: 4096 output( 13*13*256 --> 4096 )

fc2 layer: 4096 output( 4096 --> 4096 )

fc3 layer: 1000 output( 4096 --> 1000 )

 

代码:

pytorch实现

tensorflow实现

keras实现

paddle实现

 

注:

以上代码在alexnet的基础上,实现了:

1)调用框架api读取数据集

2)进行train、val的流程

3)在train时可以输出各层shape

4)保存最优loss模型,并在结束时输出最优loss及对应epoch

5)在训练结束后查看loss、acc变化曲线

 

源网络使用224作为输入,这里使用227作为输入;源网络使用3维卷积,这里仍使用2维卷积

 

源数据ImageNet LSVRC-2010,1000类,120万张训练图片、5万测试、15万验证

mini-imagenet数据:来自:https://blog.csdn.net/weixin_41803874/article/details/92068250

实验数据:从mini-imagenet随机选取10类,每类随机选取100张图片,共1000张图片作为数据集

实验数据在resnet18(pretrained=True)条件下进行迁移学习,训练2epoch可以val达到0.9准确率,在resnet18(pretrained=False)条件下训练,100epoch val acc始终处于0.35acc,无法收敛(尝试调整学习率无果),大致说明数据可以收敛但从零训练效果差;在实现的alexnet代码中,pytorch版本在val acc达到0.30左右后开始减小停止收敛,tensorflow、keras版本无法收敛,paddle版本可以收敛到0.60左右

 

 

文件结构:创建my_utils.py文件存放通用函数

 

从mini-imagenet中提取10*100数据(非必要)

  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/1/21 15:20
  3. # @Author : Zhao HL
  4. # @File : data_process.py
  5. import os,random,shutil
  6. import numpy as np
  7. import pandas as pd
  8. #region 类名转换文档
  9. cls_dict_path = r'D:\__Download\百度\caffe_ilsvrc12\synset_words.txt'
  10. # endregion
  11. #region mimi 数据集文档及信息
  12. # csv_path = r'D:\__Download\百度\mini-imagenet\test.csv'
  13. # csv_path = r'D:\__Download\百度\mini-imagenet\train.csv'
  14. csv_path = r'D:\__Download\百度\mini-imagenet\val.csv'
  15. src_data_path = r'D:\__Download\百度\mini-imagenet\images'
  16. '''
  17. train.csv contain 38400 records, 64 classes
  18. test.csv contain 12000 records, 20 classes
  19. val.csv contain 9600 records, 16 classes
  20. '''
  21. # endregion
  22. # region 目标文件
  23. dst_data_path = r'D:\__Download\百度\my_imagenet'
  24. dst_csv_path = r'D:\__Download\百度\my_imagenet.csv'
  25. # endregion
  26. def get_csvInfo():
  27. df = pd.read_csv(csv_path)
  28. total_num = len(df)
  29. class_num = len(df['label'].unique())
  30. print('{} contain {} records, {} classes '.format(os.path.basename(csv_path),total_num,class_num))
  31. def Extract_Image():
  32. # 从val文件中选取10个类,每个类选取100样本
  33. df = pd.read_csv(csv_path)
  34. cls = df['label'].unique()
  35. dst_cls = random.sample(list(cls),10)
  36. df_list = []
  37. for cls in dst_cls:
  38. print('cls {} :'.format(cls))
  39. df_cls = df[df['label']==cls]
  40. dst_df_cls = df_cls.sample(100)
  41. df_list.append(dst_df_cls)
  42. dst_df = pd.concat(df_list,ignore_index=True)
  43. for i,filename in enumerate(dst_df['filename']):
  44. src_path = os.path.join(src_data_path,filename)
  45. dst_path = os.path.join(dst_data_path,filename)
  46. shutil.copy(src_path,dst_path)
  47. print(i,filename)
  48. dst_df.to_csv(dst_csv_path)
  49. if __name__ == '__main__':
  50. pass
  51. # get_csvInfo()
  52. Extract_Image()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/546195
推荐阅读
相关标签
  

闽ICP备14008679号