当前位置:   article > 正文

【TensorFlow】从 TF Hub 加载预训练模型训练图像分类模型(使用 TensorFlow Hub 模型进行分类)

使用 TensorFlow Hub 模型进行分类

从 tfhub 加载预训练模型,从本地加载数据集,训练分类模型

import itertools
import os

import matplotlib.pylab as plt
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

# Configure GPU visibility and memory growth *before* anything asks TensorFlow
# about devices: these environment variables are read when the GPU runtime is
# first initialised, and the device query below triggers that initialisation,
# so setting them afterwards has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"  # allocate GPU memory on demand instead of reserving all of it

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
# tf.test.is_gpu_available() is deprecated; list_physical_devices is the TF2 replacement.
print("GPU is", "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")

############################ settings ################################
# Training hyper-parameters.
BATCH_SIZE = 128
num_epochs = 20

# Dataset split directories, each read with flow_from_directory (which expects
# one sub-directory per class).
train_data_dir = './head_1203_top_1w_dataset/train/'
valid_data_dir = './head_1203_top_1w_dataset/valid/'
test_data_dir = './head_1203_top_1w_dataset/test/'

# Output directory for the exported model and the training-curve image.
# os.makedirs(..., exist_ok=True) also creates missing parent directories and
# avoids the check-then-create race of os.path.exists + os.mkdir.
# NOTE(review): "r50" suggests ResNet-50, but the active TF Hub module is
# EfficientNet-B4 -- confirm the intended naming.
saved_model_path = "./models_head_1203_top_1w_dataset/r50_e20/"
os.makedirs(saved_model_path, exist_ok=True)
loss_acc_img_name = os.path.join(saved_model_path, 'acc_loss_r50_e20.jpg')

############################ pretrained model ################################   
# TF Hub feature-vector module (hub.tensorflow.google.cn mirror) used as the
# backbone. Alternative backbone, to be paired with IMAGE_SIZE = (224, 224):
#   'https://hub.tensorflow.google.cn/google/imagenet/resnet_v2_50/feature_vector/4'
Tfhub_Module = 'https://hub.tensorflow.google.cn/tensorflow/efficientnet/b4/feature-vector/1'

# Input resolution (height, width) fed to the backbone; 380x380 is the
# EfficientNet-B4 input resolution.
IMAGE_SIZE = (380, 380)

############################ dataset ################################
# Training data: pixel values rescaled to [0, 1] plus random geometric
# augmentation (rotation, shear, zoom, shifts, horizontal flips).
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=40,
    shear_range=0.2,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
)


def _flow(datagen, directory):
    """Return a batch iterator over the class sub-directories of *directory*.

    Every split uses the same target size, batch size and shuffling. Labels
    use class_mode='categorical', i.e. 2-D one-hot arrays.
    """
    return datagen.flow_from_directory(
        directory,
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        shuffle=True,
        class_mode='categorical',
    )


train_generator = _flow(train_datagen, train_data_dir)

# Validation and test data are only rescaled, never augmented.
valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
valid_generator = _flow(valid_datagen, valid_data_dir)

test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
test_generator = _flow(test_datagen, test_data_dir)
        
############################ create a model ################################ 
# When True, the pretrained TF Hub backbone is trained together with the new head.
do_fine_tuning = False  #@param {type:"boolean"}

# Pretrained TF Hub feature extractor followed by a new classification head:
# dropout, then one softmax unit per class found in the training directory,
# with an L2 penalty on the head's kernel.
model = tf.keras.Sequential()
model.add(tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)))
model.add(hub.KerasLayer(Tfhub_Module, trainable=do_fine_tuning))
model.add(tf.keras.layers.Dropout(rate=0.3))
model.add(
    tf.keras.layers.Dense(
        train_generator.num_classes,
        activation='softmax',
        kernel_regularizer=tf.keras.regularizers.l2(0.001),
    )
)
model.build((None,) + IMAGE_SIZE + (3,))

# To resume training from a previous export, load it instead:
# model = tf.keras.models.load_model(saved_model_path)

model.summary()

############################ complie the model ################################ 
# The head ends in softmax, so the loss consumes probabilities (from_logits=False).
# Labels are one-hot because the generators use class_mode='categorical'.
# label_smoothing=0.2 softens the one-hot targets.
# `learning_rate` replaces the deprecated `lr` alias, which Keras 3 rejects.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    # Alternative: tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9)
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False, label_smoothing=0.2),
    metrics=['accuracy'],
)

############################ train ################################
# len() of a Keras directory iterator is ceil(samples / batch_size). Using it
# keeps the final partial batch. It also gives at least one step for a split
# smaller than BATCH_SIZE, where floor division would give zero steps and no
# validation metrics.
steps_per_epoch = len(train_generator)
validation_steps = len(valid_generator)

# Keep only the per-epoch metric history (dict of metric name -> list).
hist = model.fit(
    train_generator,
    epochs=num_epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=valid_generator,
    validation_steps=validation_steps,
).history

############################ save the model ################################
# Export the trained model to saved_model_path in SavedModel format.
tf.saved_model.save(model, saved_model_path)

# evaluate() yields [loss, accuracy]: the loss plus the single compiled metric.
test_loss_and_acc = model.evaluate(test_generator)
loss0, accuracy0 = test_loss_and_acc
print(f"test_dataset loss: {loss0:.2f}")
print(f"test_dataset accuracy: {accuracy0:.2f}")
         
############################ draw the loss.jpg acc.jpg ################################
# Learning curves: accuracy (top) and loss (bottom) per epoch, saved as one image.
acc = hist['accuracy']
val_acc = hist['val_accuracy']

loss = hist['loss']
val_loss = hist['val_loss']

plt.figure(figsize=(8, 8))

plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Loss')
# With label_smoothing=0.2 the loss cannot reach 0. With many classes its floor
# can sit above 1.0, so a fixed [0, 1] range could hide both curves entirely.
# Extend the upper limit to the largest recorded loss when needed.
plt.ylim([0, max([1.0, *loss, *val_loss])])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.savefig(loss_acc_img_name)


############################ predict a img ################################
def get_class_string_from_index(index, class_indices=None):
    """Return the class name whose label index equals *index*.

    Args:
        index: Integer label index, e.g. the result of ``np.argmax`` over a
            one-hot label or a prediction vector.
        class_indices: Mapping of class name -> label index (the format of a
            ``flow_from_directory`` iterator's ``class_indices``). Defaults to
            ``valid_generator.class_indices`` for backward compatibility.

    Returns:
        The first class name (in mapping order) whose index equals *index*,
        or ``None`` if there is no such class.
    """
    if class_indices is None:
        class_indices = valid_generator.class_indices
    return next(
        (class_string
         for class_string, class_index in class_indices.items()
         if class_index == index),
        None,
    )

# Sanity check: classify the first image of one validation batch and compare
# the predicted class with its ground-truth class.
x, y = next(valid_generator)
image = x[0]  # first image of the batch: (height, width, 3) with height, width = IMAGE_SIZE
true_index = np.argmax(y[0])  # labels are one-hot (class_mode='categorical'), so argmax is the class index

# The model expects a batch, so add a leading batch axis: (1, height, width, 3).
prediction_scores = model.predict(np.expand_dims(image, axis=0))
print('prediction_scores = ', prediction_scores)

predicted_index = np.argmax(prediction_scores)
# str() keeps the print from raising TypeError if the index lookup returns None.
print("True label: " + str(get_class_string_from_index(true_index)))
print("Predicted label: " + str(get_class_string_from_index(predicted_index)))
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/532012
推荐阅读
相关标签
  

闽ICP备14008679号