当前位置:   article > 正文

利用InceptionV3实现图像分类_inceptionv3图像分类

inceptionv3图像分类

最近在做一个机审的项目,初步希望实现图像的四分类,即:正常(neutral)、涉政(political)、涉黄(porn)、涉恐(terrorism)。有朋友给推荐了个github上面的文章,浏览量还挺大的。地址如下:

https://github.com/xqtbox/generalImageClassification

我导入试了一下,发现博主没有放他训练的模型文件my_model.h5,所以代码trainMyDataWithKerasModel.py不能直接运行。必须先自己训练个模型才行,所以只好自己搞了。我开发电脑上安装的python版本是3.9.12,这个版本通常会遇到兼容性的问题,所以我决定先搭建个虚拟环境来测试一下。虚拟环境就用3.7.16了。

1、执行:conda create -n InceptionV3 python=3.7

在C:\Users\用户名\anaconda3\envs目录下创建虚拟环境InceptionV3目录。

2、执行:conda activate InceptionV3

启动InceptionV3虚拟环境。

3、执行:pip install -i https://pypi.douban.com/simple/ tensorflow==1.14.0

我的显卡是Nvidia GeForce RTX 3060的,CUDA是11.8,Cudnn是8.7.0,查了一下对应的。查了一下对应tensorflow版本是1.14.0,所以就安装这个。

4、执行:pip install -i https://pypi.douban.com/simple/ protobuf==3.19.0

5、执行:pip install -i https://pypi.douban.com/simple/ tensorflow_hub==0.9.0

6、执行:pip install -i https://pypi.douban.com/simple/ opencv-python

7、执行:pip install -i https://pypi.douban.com/simple/ scikit-learn

8、执行:pip install -i https://pypi.douban.com/simple/ albumentations==1.2.0

9、执行:pip install -i https://pypi.douban.com/simple/ h5py==2.10.0

10、执行:pip install -i https://pypi.douban.com/simple/ matplotlib

11、执行:pip install -i https://pypi.douban.com/simple/ Tensorflow-gpu==2.4.0

12、执行:pip install -i https://pypi.douban.com/simple/ keras==2.6.0

13、下面是训练代码,文件名是train1.py

  1. import numpy as np
  2. from tensorflow.keras.optimizers import Adam
  3. import cv2
  4. from tensorflow.keras.preprocessing.image import img_to_array
  5. from sklearn.model_selection import train_test_split
  6. from tensorflow.python.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
  7. from tensorflow.keras.applications import InceptionV3
  8. import os
  9. import tensorflow as tf
  10. from tensorflow.python.keras.layers import Dense
  11. from tensorflow.python.keras.models import Sequential
  12. import albumentations
  13. norm_size = 224
  14. datapath = 'data/train'
  15. EPOCHS = 20
  16. INIT_LR = 3e-4
  17. labelList = []
  18. # 这里是分类详情
  19. dicClass = {'neutral':0, 'political':1, 'porn':2, 'terrorism':3}
  20. # 这是分类个数
  21. classnum = 4
  22. batch_size = 2
  23. np.random.seed(42)
  24. # tf.config.list_physical_devices('GPU')
  25. # tf.test.is_gpu_available()
  26. def loadImageData():
  27. imageList = []
  28. listClasses = os.listdir(datapath) # 类别文件夹
  29. print(listClasses)
  30. for class_name in listClasses:
  31. label_id = dicClass[class_name]
  32. class_path = os.path.join(datapath, class_name)
  33. image_names = os.listdir(class_path)
  34. for image_name in image_names:
  35. image_full_path = os.path.join(class_path, image_name)
  36. labelList.append(label_id)
  37. imageList.append(image_full_path)
  38. return imageList
  39. print("开始加载数据")
  40. imageArr = loadImageData()
  41. labelList = np.array(labelList)
  42. print("加载数据完成")
  43. print(labelList)
  44. trainX, valX, trainY, valY = train_test_split(imageArr, labelList, test_size=0.3, random_state=42)
  45. train_transform = albumentations.Compose([
  46. albumentations.OneOf([
  47. albumentations.RandomGamma(gamma_limit=(60, 120), p=0.9),
  48. albumentations.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.9),
  49. albumentations.CLAHE(clip_limit=4.0, tile_grid_size=(4, 4), p=0.9),
  50. ]),
  51. albumentations.HorizontalFlip(p=0.5),
  52. albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=20,
  53. interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT, p=1),
  54. albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)
  55. ])
  56. val_transform = albumentations.Compose([
  57. albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)
  58. ])
  59. def generator(file_pathList, labels, batch_size, train_action=False):
  60. L = len(file_pathList)
  61. while True:
  62. input_labels = []
  63. input_samples = []
  64. for row in range(0, batch_size):
  65. temp = np.random.randint(0, L)
  66. X = file_pathList[temp]
  67. Y = labels[temp]
  68. image = cv2.imdecode(np.fromfile(X, dtype=np.uint8), -1)
  69. if image.shape[2] > 3:
  70. image = image[:,:,:3]
  71. if train_action:
  72. image = train_transform(image=image)['image']
  73. else:
  74. image = val_transform(image=image)['image']
  75. image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
  76. image = img_to_array(image)
  77. input_samples.append(image)
  78. input_labels.append(Y)
  79. batch_x = np.asarray(input_samples)
  80. batch_y = np.asarray(input_labels)
  81. yield (batch_x, batch_y)
  82. checkpointer = ModelCheckpoint(filepath='best_model.hdf5',
  83. monitor='val_acc', verbose=1, save_best_only=True, mode='max')
  84. reduce = ReduceLROnPlateau(monitor='val_acc', patience=10,
  85. verbose=1,
  86. factor=0.5,
  87. min_lr=1e-6)
  88. model = Sequential()
  89. model.add(InceptionV3(include_top=False, pooling='avg', weights='imagenet'))
  90. model.add(Dense(classnum, activation='softmax'))
  91. optimizer = Adam(learning_rate=INIT_LR)
  92. model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['acc'])
  93. # print('trainX = ' + str(trainX))
  94. # print('trainY = ' + str(trainY))
  95. model.add(tf.keras.layers.BatchNormalization())
  96. history = model.fit(generator(trainX, trainY, batch_size, train_action=True),
  97. steps_per_epoch=len(trainX) / batch_size,
  98. validation_data=generator(valX, valY, batch_size, train_action=False),
  99. epochs=EPOCHS,
  100. validation_steps=len(valX) / batch_size,
  101. callbacks=[checkpointer, reduce])
  102. model.save('my_model.h5')
  103. print(history)
  104. loss_trend_graph_path = r"WW_loss.jpg"
  105. acc_trend_graph_path = r"WW_acc.jpg"
  106. import matplotlib.pyplot as plt
  107. print("Now,we start drawing the loss and acc trends graph...")
  108. # summarize history for acc
  109. fig = plt.figure(1)
  110. plt.plot(history.history["acc"])
  111. plt.plot(history.history["val_acc"])
  112. plt.title("Model acc")
  113. plt.ylabel("acc")
  114. plt.xlabel("epoch")
  115. plt.legend(["train", "test"], loc="upper left")
  116. plt.savefig(acc_trend_graph_path)
  117. plt.close(1)
  118. # summarize history for loss
  119. fig = plt.figure(2)
  120. plt.plot(history.history["loss"])
  121. plt.plot(history.history["val_loss"])
  122. plt.title("Model loss")
  123. plt.ylabel("loss")
  124. plt.xlabel("epoch")
  125. plt.legend(["train", "test"], loc="upper left")
  126. plt.savefig(loss_trend_graph_path)
  127. plt.close(2)
  128. print("We are done, everything seems OK...")

13.1、norm_size = 224 设置输入图像的大小,InceptionV3默认的图片尺寸是224×224。但是我的图片有300px以上的,好像也没什么问题

13.2、datapath = ‘data/train’ 设置图片存放的路径

13.3、EPOCHS = 20 epochs的数量,关于epoch的设置多少合适,这个问题很纠结,一般情况设置300足够了,如果感觉没有训练好,再载入模型训练。

13.4、INIT_LR = 1e-3 学习率,一般情况从0.001开始逐渐降低,也别太小了到1e-6就可以了。

13.5、classnum = 12 类别数量,数据集有两个类别,所有就分为两类。

13.6、batch_size = 4 batchsize,根据硬件的情况和数据集的大小设置,太小了loss浮动太大,太大了收敛不好,根据经验来,一般设置为2的次方。windows可以通过任务管理器查看显存的占用情况。

14、工程目录的文件如下图:

其中train1.py是训练程序;test.py是检测程序,本文后面会再详细讲怎么用;FormatImages.py是格式化图片的程序,功能就是把从网上爬下来比较大的图片等比压缩成300px以内。

data目录存放的就是训练用的数据,如下图:

其中train存放的是训练图片,test存放的是测试图片。train下的目录如下图:

可以看到,图中的train目录中的文件夹名要与train1.py中dicClass的值对应起来,训练数据放到对应目录下就可以了。如下图:

15、下面开始训练了,在训练之前有几个事情要做一下。

首先检查一下自己的cuda安装好没有,方法是在cmd下面输入命令nvcc -V,如果显示版本号就没问题了,如下图:

如果还没有安装也没关系,先看看自己显卡的cuda版本,如下图:

然后去https://developer.nvidia.com/cuda-toolkit-archive下载显卡对应版本的cuda工具包。如下图:

下载完成后安装到默认目录就行,一般是安装在C:\Program Files\NVIDIA GPU Computing Toolkit,如下图:

安装完成后在到https://developer.nvidia.com/rdp/cudnn-download去下载cudnn

下载完成后解压缩,把解压缩后的目录cudnn-windows-x86_64-8.8.0.121_cuda12-archive下的bin、include、lib三个目录里的文件分别复制到C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8的bin、include、lib三个目录里。如下图:

最后到https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#install-zlib-windows下载ZLIB.DLL。如下图:

下载完成后解压缩,把解压后zlib123dllx64\dll_x64\zlibwapi.dll文件复制到C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin目录下

现在,在train1.py目录下执行:python train1.py

可以看一下任务管理器,压力应该都在GPU上:

16、训练完成后,可以看到train1.py目录下多了几个文件,如下图:

其中my_model.h5就是咱们训练出来的模型文件。WW_acc.jpg和WW_loss.jpg是训练结果保存的图,看了一下觉得还不错。

17、接下来要验证一下模型的效果,现在data\test\放一张用于预测的图。如下图:

18、下面是测试代码,文件名是test.py:

  1. import cv2
  2. import numpy as np
  3. from tensorflow.keras.preprocessing.image import img_to_array
  4. from tensorflow.keras.models import load_model
  5. import time
  6. import albumentations
  7. norm_size = 224
  8. imagelist = []
  9. emotion_labels = {
  10. 0: 'neutral',
  11. 1: 'political',
  12. 2: 'porn',
  13. 3: 'terrorism',
  14. }
  15. val_transform = albumentations.Compose([
  16. albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)
  17. ])
  18. emotion_classifier = load_model("best_model.hdf5")
  19. t1 = time.time()
  20. image = cv2.imdecode(np.fromfile('data/test/01.jpg', dtype=np.uint8), -1)
  21. image = val_transform(image=image)['image']
  22. image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
  23. image = img_to_array(image)
  24. imagelist.append(image)
  25. imageList = np.array(imagelist, dtype="float")
  26. out = emotion_classifier.predict(imageList)
  27. print(out)
  28. pre = np.argmax(out)
  29. emotion = emotion_labels[pre]
  30. t2 = time.time()
  31. print(emotion)
  32. t3 = t2 - t1
  33. print(t3)

其中emotion_labels是分类,填上与训练文件中一致的内容。

在image = cv2.imdecode(np.fromfile('data/test/01.jpg', dtype=np.uint8), -1)这行修改路径,指向到用于预测的图片位置。

19、执行python test.py

可以看到,data/test/01.jpg被预测成为terrorism,验证正确。至此大功告成。

后记:我是python的领域的新兵,在开发过程中遇到最麻烦的事情就是版本的问题。tensorflow最新版本已经2.11.0了,但是使用起来会有各种问题。我尝试了很多版本,查了不少资料,最后才确定了能用的这个组合。尤其是过程中gpu一直利用不上,程序总是使用cpu在训练,经过一顿折腾总算是能用了,但是为什么这么组合,我也没有找到一个清晰的说明,希望能有大神能给解释一下CUDA、Cudnn、tensorflow、tensorflow-gpu的版本怎么组合最合理。下面把我虚拟环境的配置发上来供大家参考:

  1. Package Version
  2. ----------------------- ---------
  3. absl-py 0.15.0
  4. albumentations 1.2.0
  5. astor 0.8.1
  6. astunparse 1.6.3
  7. cachetools 5.3.0
  8. certifi 2022.12.7
  9. charset-normalizer 3.0.1
  10. cycler 0.11.0
  11. flatbuffers 1.12
  12. fonttools 4.38.0
  13. gast 0.3.3
  14. google-auth 2.16.1
  15. google-auth-oauthlib 0.4.6
  16. google-pasta 0.2.0
  17. grpcio 1.32.0
  18. h5py 2.10.0
  19. idna 3.4
  20. imageio 2.25.1
  21. importlib-metadata 6.0.0
  22. joblib 1.2.0
  23. keras 2.6.0
  24. Keras-Applications 1.0.8
  25. Keras-Preprocessing 1.1.2
  26. kiwisolver 1.4.4
  27. Markdown 3.4.1
  28. MarkupSafe 2.1.2
  29. matplotlib 3.5.3
  30. networkx 2.6.3
  31. numpy 1.19.5
  32. oauthlib 3.2.2
  33. opencv-python 4.7.0.68
  34. opencv-python-headless 4.7.0.68
  35. opt-einsum 3.3.0
  36. packaging 23.0
  37. Pillow 9.4.0
  38. pip 22.3.1
  39. protobuf 3.19.0
  40. pyasn1 0.4.8
  41. pyasn1-modules 0.2.8
  42. pyparsing 3.0.9
  43. python-dateutil 2.8.2
  44. PyWavelets 1.3.0
  45. PyYAML 6.0
  46. qudida 0.0.4
  47. requests 2.28.2
  48. requests-oauthlib 1.3.1
  49. rsa 4.9
  50. scikit-image 0.18.3
  51. scikit-learn 1.0.2
  52. scipy 1.7.3
  53. setuptools 65.6.3
  54. six 1.15.0
  55. tensorboard 2.11.2
  56. tensorboard-data-server 0.6.1
  57. tensorboard-plugin-wit 1.8.1
  58. tensorflow 1.14.0
  59. tensorflow-estimator 2.4.0
  60. tensorflow-gpu 2.4.0
  61. tensorflow-hub 0.9.0
  62. termcolor 1.1.0
  63. threadpoolctl 3.1.0
  64. tifffile 2021.11.2
  65. typing-extensions 3.7.4.3
  66. urllib3 1.26.14
  67. Werkzeug 2.2.3
  68. wheel 0.38.4
  69. wincertstore 0.2
  70. wrapt 1.12.1
  71. zipp 3.14.0
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/681126
推荐阅读
相关标签
  

闽ICP备14008679号