当前位置:   article > 正文

猫狗识别：基于 TensorFlow 2.0（GPU 版）的自建 CNN 模型 + 数据增强 + Dropout

gpu tensorflow cnn模型

猫狗识别：基于 TensorFlow 2.0（GPU 版）的自建 CNN 模型 + 数据增强 + Dropout

1. 导入库

from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras import optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import matplotlib.pyplot as plt
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

2. 配置GPU

import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving the whole
# card up front. In TF 2.x this is configured via tf.config; the TF1-style
# ConfigProto(allow_growth=True) + InteractiveSession combination is a
# compat.v1 workaround, not the supported TF2 mechanism.
# This must run before any GPU is initialized (i.e. before building/running
# the model), otherwise set_memory_growth raises RuntimeError.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # GPUs were already initialized; memory growth can no longer be changed.
        print(err)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

3. 图像数据加载和增强

# Dataset split directories (flow_from_directory expects one sub-folder per class).
train_dir = './catdogdata/train'
validation_dir = './catdogdata/validation'
test_dir = './catdogdata/test'

# Random geometric augmentation applied to training images only.
augmentation = dict(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Training stream: pixel values scaled by 1/255, augmented, shuffled,
# resized to 224x224, batches of 100 with binary labels.
train_datagen = ImageDataGenerator(rescale=1/255, **augmentation)
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=100,
    class_mode='binary',
    shuffle=True,
)

# Validation stream: same scaling and size, but no augmentation.
validation_datagen = ImageDataGenerator(rescale=1/255)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=100,
    class_mode='binary',
)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
# Draw one (augmented) batch from the training stream to sanity-check the
# batch shape and the labels, and display its first image.
data_batch, labels_batch = next(train_generator)
print('data shape:', data_batch.shape)
print('单张图像:\n')
plt.imshow(data_batch[0])
print('Batch:', labels_batch.shape, labels_batch)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

4. 建立模型

from tensorflow.keras.layers import Dropout

# Small CNN: four Conv2D(3x3, relu) + MaxPooling2D(2x2) stages, then a dense
# head with Dropout and a single sigmoid unit for the binary label.
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)))
model.add(MaxPooling2D((2, 2)))
for filters in (64, 128, 128):
    model.add(Conv2D(filters, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
# Dropout on the dense features to reduce overfitting.
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))

model.summary()

# Pass `learning_rate`, not the deprecated `lr` alias: newer Keras optimizers
# either ignore `lr` (warning only, falling back to the default rate) or
# reject it, so the intended 1e-4 would be lost.
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.RMSprop(learning_rate=1e-4),
              metrics=['acc'])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

在这里插入图片描述

5. 训练模型

from tensorflow.keras import callbacks

# Write TensorBoard logs, with weight histograms computed every epoch.
tensorboard_callback = callbacks.TensorBoard(log_dir="logs/Self_Built_CNN_w_Augment", histogram_freq=1)

# Model.fit accepts generators / keras Sequences directly; fit_generator is
# deprecated (since TF 2.1) and removed in recent TensorFlow releases.
history = model.fit(train_generator,
                    epochs=30,
                    validation_data=validation_generator,
                    callbacks=[tensorboard_callback])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

在这里插入图片描述

6. 模型预测

# Test stream: scaling only, no augmentation. Order does not change aggregate
# metrics, but shuffle=False keeps it deterministic (batches follow
# test_generator.filenames).
test_datagen = ImageDataGenerator(rescale=1/255)
test_generator = test_datagen.flow_from_directory(test_dir, (224, 224), batch_size=100,
                                                  class_mode='binary', shuffle=False)

# Model.evaluate accepts generators directly; evaluate_generator is deprecated
# and removed in recent TensorFlow. With the single 'acc' metric used at
# compile time it returns [loss, acc]; print it so it is visible outside a notebook.
test_loss, test_acc = model.evaluate(test_generator)
print('test loss:', test_loss, 'test acc:', test_acc)
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

7. 训练过程可视化

# Plot per-epoch training/validation accuracy and loss, and save them to PNG.
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Number epochs from 1 to match Keras' "Epoch 1/N" progress output.
epochs = range(1, len(acc) + 1)

# Create fresh figures. plt.figure(num=1) would reactivate figure 1 if it
# already exists (in a plain script, the earlier plt.imshow preview created
# it), drawing these curves on top of that image.
acc_fig, acc_ax = plt.subplots()   # accuracy
acc_ax.plot(epochs, acc, 'bo', label='train_acc')
acc_ax.plot(epochs, val_acc, 'b', label='val_acc')
acc_ax.set_title('accuracy')
acc_ax.set_xlabel('epoch')
acc_ax.legend()
acc_fig.savefig('acc_aug.png')

loss_fig, loss_ax = plt.subplots()   # loss
loss_ax.plot(epochs, loss, 'bo', label='train_loss')
loss_ax.plot(epochs, val_loss, 'b', label='val_loss')
loss_ax.set_title('loss')
loss_ax.set_xlabel('epoch')
loss_ax.legend()
loss_fig.savefig('loss_aug.png')

plt.show()


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

在这里插入图片描述

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号