当前位置:   article > 正文

基于keras与tensorflow手工实现ResNet50网络_keras resnet50

keras resnet50

前言

在文章《基于tensorflow的ResNet50V2网络识别动物》中,我们使用了keras已经提供的神经网络,完成了图像分类。这个时候,小明同学就问了,那么我怎么自己去写一个神经网络来进行训练呢?
本文就基于tensorflow,自己定义一个神经网络。

ResNet50网络

在这里插入图片描述
从结构上看,与我们之前的网络的区别在于,输入的格式变成(3,224,224)(这是通道在前的写法;下文代码采用通道在后的格式,对应的input_shape为(224,224,3))
ResNet50有两个基本的块,分别名为Conv Block和Identity Block

整体架构

在这里插入图片描述

Conv Block架构

在这里插入图片描述

Identity Block架构

在这里插入图片描述

模型训练

手工实现模型

模型代码(resnet50.py)

# 根据模型进行引入
from keras import layers

from keras.layers import Input,Activation,BatchNormalization,Flatten
from keras.layers import Dense,Conv2D,MaxPooling2D,ZeroPadding2D,AveragePooling2D
from keras.models import Model

def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Build a three-convolution residual block with an identity shortcut.

    The main path is 1x1 conv -> BN -> ReLU, then a ``kernel_size`` conv with
    'same' padding -> BN -> ReLU, then 1x1 conv -> BN. The result is added to
    ``input_tensor`` as-is and passed through a final ReLU.

    Args:
        input_tensor: Keras tensor fed to both the main path and the shortcut.
        kernel_size: Kernel size of the middle convolution.
        filters: Sequence of exactly three filter counts, one per convolution,
            in order.
        stage: Stage identifier; converted with ``str`` for the layer names.
        block: Block label string appended to the stage in the layer names.

    Returns:
        The output tensor of the final ReLU activation.
    """
    squeeze_ch, spatial_ch, expand_ch = filters

    # Every layer name in this block shares the prefix "<stage><block>_identity_block_".
    prefix = ''.join((str(stage), block, '_identity_block_'))

    # First 1x1 convolution.
    branch = Conv2D(squeeze_ch, (1, 1), name=prefix + 'conv1')(input_tensor)
    branch = BatchNormalization(name=prefix + 'bn1')(branch)
    branch = Activation('relu', name=prefix + 'relu1')(branch)

    # Middle convolution; 'same' padding keeps the spatial size.
    branch = Conv2D(spatial_ch, kernel_size, padding='same', name=prefix + 'conv2')(branch)
    branch = BatchNormalization(name=prefix + 'bn2')(branch)
    branch = Activation('relu', name=prefix + 'relu2')(branch)

    # Last 1x1 convolution; no activation before the residual add.
    branch = Conv2D(expand_ch, (1, 1), name=prefix + 'conv3')(branch)
    branch = BatchNormalization(name=prefix + 'bn3')(branch)

    # NOTE(review): the add presumably needs input_tensor to already have
    # filters[2] channels, since the shortcut is not projected -- confirm.
    merged = layers.add([branch, input_tensor], name=prefix + 'add')
    return Activation('relu', name=prefix + 'relu4')(merged)

def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    """Build a three-convolution residual block with a projected shortcut.

    The main path is 1x1 conv (with ``strides``) -> BN -> ReLU, then a
    ``kernel_size`` conv with 'same' padding -> BN -> ReLU, then 1x1 conv ->
    BN. The shortcut is a separate 1x1 conv (with the same ``strides``) -> BN
    applied to ``input_tensor``. Both paths are added and passed through a
    final ReLU.

    Args:
        input_tensor: Keras tensor fed to both the main path and the shortcut.
        kernel_size: Kernel size of the middle convolution.
        filters: Sequence of exactly three filter counts; the third is also
            used for the shortcut convolution.
        stage: Stage identifier; converted with ``str`` for the layer names.
        block: Block label string appended to the stage in the layer names.
        strides: Strides of the first main-path convolution and of the
            shortcut convolution.

    Returns:
        The output tensor of the final ReLU activation.
    """
    squeeze_ch, spatial_ch, expand_ch = filters

    stage_block = str(stage) + block
    shortcut_prefix = stage_block + '_conv_block_res_'
    prefix = stage_block + '_conv_block_'

    # Main path: strided 1x1 convolution.
    branch = Conv2D(squeeze_ch, (1, 1), strides=strides, name=prefix + 'conv1')(input_tensor)
    branch = BatchNormalization(name=prefix + 'bn1')(branch)
    branch = Activation('relu', name=prefix + 'relu1')(branch)

    # Main path: middle convolution; 'same' padding keeps the spatial size.
    branch = Conv2D(spatial_ch, kernel_size, padding='same', name=prefix + 'conv2')(branch)
    branch = BatchNormalization(name=prefix + 'bn2')(branch)
    branch = Activation('relu', name=prefix + 'relu2')(branch)

    # Main path: last 1x1 convolution; no activation before the residual add.
    branch = Conv2D(expand_ch, (1, 1), name=prefix + 'conv3')(branch)
    branch = BatchNormalization(name=prefix + 'bn3')(branch)

    # Shortcut: same strides and filter count as the main path output.
    projected = Conv2D(expand_ch, (1, 1), strides=strides, name=shortcut_prefix + 'conv')(input_tensor)
    projected = BatchNormalization(name=shortcut_prefix + 'bn')(projected)

    merged = layers.add([branch, projected], name=prefix + 'add')
    return Activation('relu', name=prefix + 'relu4')(merged)

def ResNet50(input_shape=(224, 224, 3), classes=1000):
    """Assemble the ResNet50 classifier from conv and identity blocks.

    Layout: zero-padding, 7x7/2 convolution, BN, ReLU and 3x3/2 max-pooling,
    followed by four stages (2-5). Each stage starts with one ``conv_block``
    (label 'a') and continues with identity blocks (labels 'b', 'c', ...).
    A 7x7 average pooling, flatten and a softmax ``Dense`` layer form the head.

    Args:
        input_shape: Shape of one input image, passed to ``Input``. The
            default is an immutable tuple so it cannot be altered between
            calls; lists are still accepted.
        classes: Number of units in the final softmax layer.

    Returns:
        A Keras ``Model`` named 'resnet50' mapping the image input to class
        probabilities.
    """
    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3, 3))(img_input)

    # Stem.
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    # (stage, filters, identity-block labels, strides of the leading conv_block).
    # Stage 2 uses strides (1, 1); the others use (2, 2).
    stage_specs = (
        (2, (64, 64, 256), 'bc', (1, 1)),
        (3, (128, 128, 512), 'bcd', (2, 2)),
        (4, (256, 256, 1024), 'bcdef', (2, 2)),
        (5, (512, 512, 2048), 'bc', (2, 2)),
    )
    for stage, filters, identity_labels, strides in stage_specs:
        x = conv_block(x, 3, filters, stage=stage, block='a', strides=strides)
        for label in identity_labels:
            x = identity_block(x, 3, filters, stage=stage, block=label)

    # Fixed 7x7 pooling window.
    # NOTE(review): this presumably assumes a 7x7 final feature map, as obtained
    # with the default 224x224 input -- confirm before using other input sizes.
    x = AveragePooling2D((7, 7), name='avg_pool')(x)

    x = Flatten()(x)
    # The layer keeps the name 'fc1000' regardless of ``classes``.
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')
    return model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88

模型训练(resnet50_model_train.py)

在《基于tensorflow的ResNet50V2网络识别动物》一文的模型训练代码基础上进行少量改造

import os
import pandas as pd

# Model
import keras
from keras.preprocessing.image import ImageDataGenerator

# Callbacks
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Pre-Trained Model
import tensorflow as tf
import resnet50

# Dataset locations: one sub-directory per class inside each split.
root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'
valid_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/'
test_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Testing Data/'

# Only sub-directories count as classes, so stray files in the training
# folder (e.g. OS metadata files) cannot inflate the class count.
class_names = sorted(
    name for name in os.listdir(root_path)
    if os.path.isdir(os.path.join(root_path, name))
)
n_classes = len(class_names)

print(f"Total Number of Classes : {n_classes} \nClass Names : {class_names}")

# Number of entries in each class directory (not used below; kept for inspection).
class_dis = [len(os.listdir(os.path.join(root_path, name))) for name in class_names]

# Pixel values are rescaled to [0, 1]; only the training data is augmented.
train_gen = ImageDataGenerator(rescale=1/255., rotation_range=10, horizontal_flip=True)
valid_gen = ImageDataGenerator(rescale=1/255.)
test_gen = ImageDataGenerator(rescale=1/255.)

# Load Data
# 'sparse' yields 1D integer class labels, which is what
# sparse_categorical_crossentropy expects for a multi-class problem.
train_ds = train_gen.flow_from_directory(root_path, class_mode='sparse', target_size=(224,224), shuffle=True, batch_size=32)
valid_ds = valid_gen.flow_from_directory(valid_path, class_mode='sparse', target_size=(224,224), shuffle=True, batch_size=32)
test_ds = test_gen.flow_from_directory(test_path, class_mode='sparse', target_size=(224,224), shuffle=True, batch_size=32)

with tf.device("/GPU:0"):
    # Hand-written model; the softmax head is sized to the dataset instead of
    # the 1000-class default.
    model = resnet50.ResNet50(classes=n_classes)
    model.summary()

    model_file = "ResNet50_V1.h5"
    # Resume from the checkpoint of a previous run if there is one.
    # NOTE: a checkpoint saved with a different number of output units will
    # not match this model; delete it before re-running.
    if os.path.exists(model_file):
        model.load_weights(model_file)

    # Callbacks: stop after 5 epochs without improvement and keep only the
    # best model on disk.
    cbs = [EarlyStopping(patience=5, restore_best_weights=True), ModelCheckpoint(model_file, save_best_only=True)]

    # Model
    opt = tf.keras.optimizers.Adam(learning_rate=2e-3)
    model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

    # Model Training
    history = model.fit(train_ds, validation_data=valid_ds, callbacks=cbs, epochs=200)

    # Per-epoch metrics as a table.
    data = pd.DataFrame(history.history)
    print(data)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

模型执行

GPU基本跑满
在这里插入图片描述

Model: "resnet50"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 zero_padding2d (ZeroPadding2D)  (None, 230, 230, 3)  0          ['input_1[0][0]']                
                                                                                                  
 conv1 (Conv2D)                 (None, 112, 112, 64  9472        ['zero_padding2d[0][0]']         
                                )                                                                 
                                                                                                  
 bn_conv1 (BatchNormalization)  (None, 112, 112, 64  256         ['conv1[0][0]']                  
                                )                                                                 
                                                                                                  
 activation (Activation)        (None, 112, 112, 64  0           ['bn_conv1[0][0]']               
                                )                                                                 
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 55, 55, 64)   0           ['activation[0][0]']             
                                                                                                  
 2a_conv_block_conv1 (Conv2D)   (None, 55, 55, 64)   4160        ['max_pooling2d[0][0]']          
                                                                                                  
 2a_conv_block_bn1 (BatchNormal  (None, 55, 55, 64)  256         ['2a_conv_block_conv1[0][0]']    
 ization)                                                                                         
                                                                                                  
 2a_conv_block_relu1 (Activatio  (None, 55, 55, 64)  0           ['2a_conv_block_bn1[0][0]']      
 n)                                                                                               
                                                                                                  
 2a_conv_block_conv2 (Conv2D)   (None, 55, 55, 64)   36928       ['2a_conv_block_relu1[0][0]']    
                                                                                                  
 2a_conv_block_bn2 (BatchNormal  (None, 55, 55, 64)  256         ['2a_conv_block_conv2[0][0]']    
 ization)                                                                                         
                                                                                                  
 2a_conv_block_relu2 (Activatio  (None, 55, 55, 64)  0           ['2a_conv_block_bn2[0][0]']      
 n)                                                                                               
                                                                                                  
 2a_conv_block_conv3 (Conv2D)   (None, 55, 55, 256)  16640       ['2a_conv_block_relu2[0][0]']    
                                                                                                  
 2a_conv_block_res_conv (Conv2D  (None, 55, 55, 256)  16640      ['max_pooling2d[0][0]']          
 )                                                                                                
                                                                                                  
 2a_conv_block_bn3 (BatchNormal  (None, 55, 55, 256)  1024       ['2a_conv_block_conv3[0][0]']    
 ization)                                                                                         
                                                                                                  
 2a_conv_block_res_bn (BatchNor  (None, 55, 55, 256)  1024       ['2a_conv_block_res_conv[0][0]'] 
 malization)                                                                                      
                                                                                                  
 2a_conv_block_add (Add)        (None, 55, 55, 256)  0           ['2a_conv_block_bn3[0][0]',      
                                                                  '2a_conv_block_res_bn[0][0]']   
                                                                                                  
 2a_conv_block_relu4 (Activatio  (None, 55, 55, 256)  0          ['2a_conv_block_add[0][0]']      
 n)                                                                                               
                                                                                                  
 2b_identity_block_conv1 (Conv2  (None, 55, 55, 64)  16448       ['2a_conv_block_relu4[0][0]']    
 D)                                                                                               
                                                                                                  
 2b_identity_block_bn1 (BatchNo  (None, 55, 55, 64)  256         ['2b_identity_block_conv1[0][0]']
 rmalization)                                                                                     
                                                                                                  
 2b_identity_block_relu1 (Activ  (None, 55, 55, 64)  0           ['2b_identity_block_bn1[0][0]']  
 ation)                                                                                           
                                                                                                  
 2b_identity_block_conv2 (Conv2  (None, 55, 55, 64)  36928       ['2b_identity_block_relu1[0][0]']
 D)                                                                                               
           
*******略,大致意思是有50层,因为是ResNet50*********                                                                                         
                                                                                                  
 5c_identity_block_bn3 (BatchNo  (None, 7, 7, 2048)  8192        ['5c_identity_block_conv3[0][0]']
 rmalization)                                                                                     
                                                                                                  
 5c_identity_block_add (Add)    (None, 7, 7, 2048)   0           ['5c_identity_block_bn3[0][0]',  
                                                                  '5b_identity_block_relu4[0][0]']
                                                                                                  
 5c_identity_block_relu4 (Activ  (None, 7, 7, 2048)  0           ['5c_identity_block_add[0][0]']  
 ation)                                                                                           
                                                                                                  
 avg_pool (AveragePooling2D)    (None, 1, 1, 2048)   0           ['5c_identity_block_relu4[0][0]']
                                                                                                  
 flatten (Flatten)              (None, 2048)         0           ['avg_pool[0][0]']               
                                                                                                  
 fc1000 (Dense)                 (None, 1000)         2049000     ['flatten[0][0]']                
                                                                                                  
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87

在这里插入图片描述

训练结果

Epoch 1/200
625/625 [==============================] - 245s 376ms/step - loss: 0.3225 - accuracy: 0.8965 - val_loss: 0.7825 - val_accuracy: 0.7670
Epoch 2/200
625/625 [==============================] - 226s 361ms/step - loss: 0.2998 - accuracy: 0.9018 - val_loss: 0.9110 - val_accuracy: 0.7060
Epoch 3/200
625/625 [==============================] - 223s 357ms/step - loss: 0.2644 - accuracy: 0.9113 - val_loss: 1.2283 - val_accuracy: 0.6760
Epoch 4/200
625/625 [==============================] - 223s 357ms/step - loss: 0.2465 - accuracy: 0.9188 - val_loss: 0.9871 - val_accuracy: 0.7500
Epoch 5/200
625/625 [==============================] - 225s 360ms/step - loss: 0.2307 - accuracy: 0.9234 - val_loss: 1.1059 - val_accuracy: 0.6720
Epoch 6/200
625/625 [==============================] - 228s 365ms/step - loss: 0.2016 - accuracy: 0.9341 - val_loss: 0.5819 - val_accuracy: 0.8370
Epoch 7/200
625/625 [==============================] - 227s 363ms/step - loss: 0.1859 - accuracy: 0.9380 - val_loss: 0.8662 - val_accuracy: 0.7740
Epoch 8/200
625/625 [==============================] - 223s 356ms/step - loss: 0.1732 - accuracy: 0.9419 - val_loss: 0.6927 - val_accuracy: 0.8130
Epoch 9/200
625/625 [==============================] - 221s 353ms/step - loss: 0.1631 - accuracy: 0.9446 - val_loss: 0.7033 - val_accuracy: 0.8090
Epoch 10/200
625/625 [==============================] - 220s 352ms/step - loss: 0.1416 - accuracy: 0.9528 - val_loss: 0.8072 - val_accuracy: 0.8130
Epoch 11/200
625/625 [==============================] - 221s 352ms/step - loss: 0.1403 - accuracy: 0.9524 - val_loss: 0.9740 - val_accuracy: 0.7570
        loss  accuracy  val_loss  val_accuracy
0   0.322520   0.89655  0.782460         0.767
1   0.299801   0.90180  0.910993         0.706
2   0.264406   0.91130  1.228291         0.676
3   0.246451   0.91880  0.987062         0.750
4   0.230691   0.92340  1.105917         0.672
5   0.201550   0.93405  0.581927         0.837
6   0.185873   0.93800  0.866195         0.774
7   0.173195   0.94185  0.692714         0.813
8   0.163083   0.94460  0.703307         0.809
9   0.141572   0.95280  0.807219         0.813
10  0.140329   0.95240  0.974021         0.757
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

训练结果

从输出的训练结果来看,效果没有ResNet50V2的好;对参数进行了一些调整,也没有太大的改善。各位如果有兴趣,可以对这样的网络进行修改,从而提升验证的正确率。

模型验证

模型验证代码

from keras.models import load_model
import tensorflow as tf
from tensorflow.keras.utils import load_img, img_to_array
import numpy as np
import os

import matplotlib.pyplot as plt

root_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Training Data/'

# Class names are the sorted sub-directory names of the training folder.
# Plain files are skipped so the index -> name mapping cannot be shifted by
# stray entries in that folder.
class_names = sorted(
    name for name in os.listdir(root_path)
    if os.path.isdir(os.path.join(root_path, name))
)

# Load the full model saved by the training script's checkpoint callback.
model = load_model('./ResNet50_V1.h5')
model.summary()

def load_image(path):
    """Load the image at ``path`` as a float32 tensor resized to 224x224.

    Pixel values are divided by 255 before the resize.
    """
    pixels = img_to_array(load_img(path))
    scaled = pixels / 255.
    resized = tf.image.resize(scaled, (224, 224))
    return tf.cast(resized, tf.float32)

def _predict_and_show(path):
    """Classify the image at ``path``, print the result and plot the image.

    Prints the raw prediction vector, then the predicted class name with its
    rounded score, and finally shows the image with that text as the title.
    """
    image = load_image(path)

    # The model expects a batch, so add a leading axis and take the single row.
    preds = model.predict(image[np.newaxis, ...])[0]
    print(preds)

    best_index = np.argmax(preds)
    pred_class = class_names[best_index]
    confidence_score = np.round(preds[best_index], 2)

    # Configure Title
    title = f"Pred : {pred_class}\nConfidence : {confidence_score:.2}"
    print(title)

    plt.figure(figsize=(25, 8))
    plt.title(title)
    plt.imshow(image)
    plt.show()


# One fixed sample first, then an interactive prompt.
i_path = './animal/Animals_Classification/Animal-Data-V2/Data-V2/Validation Data/Gorilla/Gorilla (3).jpeg'
_predict_and_show(i_path)

while True:
    path = input("input:")
    if path == "q!":
        # Same effect as exit(), without relying on the site-provided helper.
        raise SystemExit
    try:
        _predict_and_show(path)
    except OSError as err:
        # A mistyped or unreadable path should not end the session.
        print(f"Cannot load image '{path}': {err}")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/92803?site
推荐阅读
相关标签
  

闽ICP备14008679号