当前位置:   article > 正文

Keras：利用 VGG16 进行十分类模型训练

10分类模型训练

导入模块

  1. import os
  2. import numpy as np
  3. import tensorflow as tf
  4. import random
  5. import seaborn as sns
  6. import matplotlib.pyplot as plt
  7. from keras.models import Sequential, Model
  8. from keras.layers import Dense, Dropout, Activation, Flatten, Input
  9. from keras.layers.convolutional import Conv2D, MaxPooling2D
  10. from keras.optimizers import RMSprop, Adam, SGD
  11. from keras.preprocessing import image
  12. from keras.preprocessing.image import ImageDataGenerator
  13. from keras.utils import np_utils
  14. from keras.applications.vgg16 import VGG16, preprocess_input
  15. from sklearn.model_selection import train_test_split

读取图片函数

  # CIFAR-10 class names in label order: a name's index is its integer label.
  _CIFAR10_CLASSES = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                      'dog', 'frog', 'horse', 'ship', 'truck')


  def _class_label(class_dir_name):
      """Return the integer label encoded by a class directory name.

      The first class name (in label order) that occurs as a substring of
      ``class_dir_name`` wins. Only the class directory name is inspected, so
      the dataset root path or an image file name cannot affect the label.

      Raises:
          ValueError: if no known class name occurs in ``class_dir_name``.
      """
      for label, name in enumerate(_CIFAR10_CLASSES):
          if name in class_dir_name:
              return label
      raise ValueError(
          "cannot infer a CIFAR-10 label from directory %r" % class_dir_name)


  def _load_image(file_path, width, height, preprocess):
      """Load one image file, resized to (height, width), as a float array."""
      img = image.load_img(file_path, target_size=(height, width))
      x = np.expand_dims(image.img_to_array(img), axis=0)
      if preprocess:
          # preprocess_input is applied to a batch of one image; the batch
          # axis is dropped again on return.
          x = preprocess_input(x)
      return x[0]


  def read_and_process_image(data_dir, width=32, height=32, channels=3, preprocess=False):
      """Load every image under ``data_dir/<class_dir>/`` into one array.

      Each sub-directory of ``data_dir`` is one class; its label is the index
      of the CIFAR-10 class name contained in the directory name. Samples are
      shuffled with the global ``random`` generator before loading.

      Args:
          data_dir: dataset root directory; a trailing slash is optional.
          width, height: size every image is resized to.
          channels: channel count of the output array. NOTE(review): load_img
              is called without a colour mode, so images presumably come back
              as 3-channel RGB; other values will likely fail on assignment.
          preprocess: if True, apply ``preprocess_input`` to each image.

      Returns:
          (X, labels): a float32 array of shape (count, height, width, channels)
          and a list of int labels aligned with X's first axis.

      Raises:
          ValueError: if a class directory name contains no CIFAR-10 class name.
      """
      samples = []
      # Sorted listings make the pre-shuffle order (and therefore the result
      # for a given random seed) independent of the filesystem listing order.
      for class_name in sorted(os.listdir(data_dir)):
          class_dir = os.path.join(data_dir, class_name)
          if not os.path.isdir(class_dir):
              continue  # stray files in the root are not classes
          # Resolve the label before loading anything so a bad directory
          # fails fast instead of inheriting a previous sample's label.
          label = _class_label(class_name)
          samples.extend((os.path.join(class_dir, file_name), label)
                         for file_name in sorted(os.listdir(class_dir)))
      # Shuffle (path, label) pairs together so labels stay aligned.
      random.shuffle(samples)

      X = np.empty((len(samples), height, width, channels), dtype=np.float32)
      for i, (file_path, _) in enumerate(samples):
          X[i] = _load_image(file_path, width, height, preprocess)
      labels = [label for _, label in samples]
      print("Train shape: {}".format(X.shape))
      return X, labels

读取训练集和测试集

  # Load the training and test images, resized to HEIGHT x WIDTH on load.
  # These constants also supply the default input shape of vgg16_model.
  # NOTE(review): preprocess is not passed, so X / test_X presumably hold raw
  # 0-255 RGB values; Keras documents vgg16.preprocess_input as the expected
  # input transform for the ImageNet weights -- consider preprocess=True
  # (and adjust any code that displays X).
  HEIGHT = 48
  WIDTH = 48
  CHANNELS = 3
  X, y = read_and_process_image(
      'D:/Python Project/cifar-10/train/',
      width=WIDTH,
      height=HEIGHT,
      channels=CHANNELS,
  )
  test_X, test_y = read_and_process_image(
      'D:/Python Project/cifar-10/test/',
      width=WIDTH,
      height=HEIGHT,
      channels=CHANNELS,
  )

查看训练集和测试集的类别分布

  # Plot the per-class label counts of the training and test sets.
  # Each countplot gets its own axes: two bare countplot() calls both draw on
  # the current axes, which overlays the test bars on the training bars.
  fig, (ax_train, ax_test) = plt.subplots(1, 2, figsize=(12, 4))
  sns.countplot(x=y, ax=ax_train)
  ax_train.set_title('train')
  sns.countplot(x=test_y, ax=ax_test)
  ax_test.set_title('test')
  plt.show()

对训练集和测试集标签进行 one-hot 编码

  # One-hot encode the integer labels (0-9) for categorical_crossentropy.
  # num_classes is fixed at 10 so both matrices always have 10 columns, even
  # if some class is absent from a split; without it, to_categorical infers
  # the width from the largest label present.
  train_y = np_utils.to_categorical(y, num_classes=10)
  test_y = np_utils.to_categorical(test_y, num_classes=10)

查看几张图片

  # Display a few training images as a visual sanity check.
  def show_picture(X, idx):
      """Show image ``idx`` of the image batch ``X`` in its own figure.

      Assumes X holds un-preprocessed pixel values in 0-255, so dividing by
      255 maps them into imshow's float range [0, 1]. The channel axis is
      used as-is: keras' load_img converts to RGB by default, which is the
      order imshow expects (reversing it would display BGR, swapping the
      red and blue channels).
      """
      plt.figure(figsize=(10, 5), frameon=True)
      plt.imshow(X[idx] / 255)
      plt.show()


  for idx in range(3):
      show_picture(X, idx)

定义基于 VGG16 的迁移学习网络（冻结 VGG16 卷积基，在其后接自定义的全连接分类层）

  def vgg16_model(input_shape=(HEIGHT, WIDTH, CHANNELS), num_classes=10):
      """Build a classifier head on top of a frozen, ImageNet-pretrained VGG16.

      Args:
          input_shape: (height, width, channels) of the input images. The
              default reads the module-level HEIGHT/WIDTH/CHANNELS when this
              function is defined.
          num_classes: number of softmax outputs (10 for the CIFAR-10 labels).

      Returns:
          An uncompiled keras ``Model`` mapping images to class probabilities.
      """
      base = VGG16(include_top=False, weights='imagenet', input_shape=input_shape)
      # Freeze the convolutional base so only the new dense head is trained.
      for layer in base.layers:
          layer.trainable = False

      # Custom classification head on the VGG16 feature maps.
      x = Flatten()(base.output)
      x = Dense(256, activation='relu')(x)
      x = Dropout(0.5)(x)
      x = Dense(256, activation='relu')(x)
      x = Dropout(0.5)(x)
      x = Dense(num_classes, activation='softmax')(x)
      return Model(inputs=base.input, outputs=x)

创建模型

  # Build the VGG16-based classifier, print its layer summary, then compile it
  # with Adam (learning rate 1e-4) against one-hot targets.
  model_vgg16 = vgg16_model()
  model_vgg16.summary()
  model_vgg16.compile(
      optimizer=Adam(0.0001),
      loss='categorical_crossentropy',
      metrics=['accuracy'],
  )

训练模型

  # Train for 20 epochs in batches of 100. The test set doubles as
  # validation_data, so the per-epoch val_* metrics are test-set metrics.
  history = model_vgg16.fit(
      X,
      train_y,
      batch_size=100,
      epochs=20,
      verbose=True,
      validation_data=(test_X, test_y),
  )
  score = model_vgg16.evaluate(test_X, test_y, verbose=0)
  # score[1] is presumably accuracy (assuming the model was compiled with
  # metrics=['accuracy']); report the error rate in percent.
  print("Large CNN Error: %.2f%%" % (100 - score[1] * 100))

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/103923
推荐阅读
相关标签
  

闽ICP备14008679号