赞
踩
导入模块
- import os
- import numpy as np
- import tensorflow as tf
- import random
- import seaborn as sns
- import matplotlib.pyplot as plt
-
- from keras.models import Sequential, Model
- from keras.layers import Dense, Dropout, Activation, Flatten, Input
- from keras.layers.convolutional import Conv2D, MaxPooling2D
- from keras.optimizers import RMSprop, Adam, SGD
- from keras.preprocessing import image
- from keras.preprocessing.image import ImageDataGenerator
- from keras.utils import np_utils
- from keras.applications.vgg16 import VGG16, preprocess_input
- from sklearn.model_selection import train_test_split
读取图片函数
def read_and_process_image(data_dir, width=32, height=32, channels=3, preprocess=False):
    """Load every image under ``data_dir`` into one array, labelled by class folder.

    Expects the layout ``data_dir/<class_name>/<image file>``, where each
    ``<class_name>`` names one of the ten CIFAR-10 classes. The (image, label)
    pairs are shuffled with the global :mod:`random` state before loading.

    Args:
        data_dir: Root directory holding one sub-directory per class. A
            trailing path separator is optional. Plain files directly inside
            ``data_dir`` are ignored.
        width: Target image width in pixels after resizing.
        height: Target image height in pixels after resizing.
        channels: Channel count of the output array. Images are loaded with
            ``image.load_img``'s default colour mode, so this must match what
            that produces (3 for RGB) or filling the array fails.
        preprocess: If True, apply VGG16 ``preprocess_input`` to each image.

    Returns:
        ``(X, labels)``:
            - ``X`` is a float32 array of shape ``(n_images, height, width, channels)``.
            - ``labels`` is a list of ``n_images`` integer class indices aligned with ``X``.

    Raises:
        ValueError: If a sub-directory name does not identify a CIFAR-10 class.
    """
    # A class's position in this tuple is its integer label.
    class_names = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

    def class_label(dir_name):
        """Map a class folder name to its integer label."""
        if dir_name in class_names:
            return class_names.index(dir_name)
        # Fallback: the first class name contained in the folder name. This
        # mirrors the original substring rule, but applies it to the folder name only.
        for idx, name in enumerate(class_names):
            if name in dir_name:
                return idx
        raise ValueError("cannot infer CIFAR-10 class from directory %r" % dir_name)

    # Collect (path, label) pairs. Labels come from the class folder, never from
    # data_dir itself, so a class word in the root path cannot mislabel images.
    samples = []
    for dir_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, dir_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files sitting next to the class folders
        label = class_label(dir_name)
        samples.extend((os.path.join(class_dir, file_name), label)
                       for file_name in sorted(os.listdir(class_dir)))

    random.shuffle(samples)

    X = np.ndarray((len(samples), height, width, channels), dtype=np.float32)
    for i, (path, _) in enumerate(samples):
        img = image.load_img(path, target_size=(height, width))
        # preprocess_input is applied to a batch of one, as in the original code.
        batch = np.expand_dims(image.img_to_array(img), axis=0)
        if preprocess:
            batch = preprocess_input(batch)
        X[i] = batch[0]
    labels = [label for _, label in samples]

    print("Train shape: {}".format(X.shape))

    return X, labels
读取训练集测试集
# Load the CIFAR-10 train and test splits (one sub-folder per class), resizing
# every image to HEIGHT x WIDTH with CHANNELS colour channels.
# NOTE(review): preprocess is left at its default (False), so pixels stay as raw
# RGB values while the ImageNet VGG16 weights expect preprocess_input'd data —
# confirm this is intended.
WIDTH, HEIGHT, CHANNELS = 48, 48, 3
_load_kwargs = dict(width=WIDTH, height=HEIGHT, channels=CHANNELS)
X, y = read_and_process_image('D:/Python Project/cifar-10/train/', **_load_kwargs)
test_X, test_y = read_and_process_image('D:/Python Project/cifar-10/test/', **_load_kwargs)
查看训练集和测试集的类别分布
# Plot how many images each class has in the train and test splits.
# Each split gets its own axes; previously both were drawn onto the same axes
# and overlaid each other. The vector is passed as x= because seaborn >= 0.12
# interprets a positional first argument as `data`.
fig, (ax_train, ax_test) = plt.subplots(1, 2, figsize=(12, 4))
sns.countplot(x=y, ax=ax_train)
ax_train.set_title('train')
sns.countplot(x=test_y, ax=ax_test)
ax_test.set_title('test')
plt.show()
将训练集和测试集标签one-hot编码
# One-hot encode the integer class labels for categorical_crossentropy.
# The width is pinned to the 10 CIFAR-10 classes. Otherwise to_categorical
# infers it from the largest label present, and a split missing the highest
# class would get narrower vectors than the other split.
NUM_CLASSES = 10
train_y = np_utils.to_categorical(y, num_classes=NUM_CLASSES)
test_y = np_utils.to_categorical(test_y, num_classes=NUM_CLASSES)
查看几张图片
# Display a few training images.
def show_picture(X, idx):
    """Show image ``X[idx]`` in a new matplotlib figure.

    Args:
        X: Image batch indexed as ``X[idx, row, col, channel]``. Assumes raw
            RGB pixel values in [0, 255] (i.e. images not passed through
            preprocess_input) — TODO confirm for other inputs.
        idx: Index of the image to display.
    """
    plt.figure(figsize=(10, 5), frameon=True)
    # keras load_img yields RGB, which is the order imshow expects, so the
    # channels are not reversed (reversing them swapped red and blue). Scale to
    # [0, 1] for float input to imshow and clip any out-of-range values.
    img = np.clip(X[idx] / 255.0, 0.0, 1.0)
    plt.imshow(img)
    plt.show()


for idx in range(3):
    show_picture(X, idx)
定义基于预训练VGG16的迁移学习网络
def vgg16_model(input_shape=(HEIGHT, WIDTH, CHANNELS), num_classes=10):
    """Build a classifier on top of a frozen, ImageNet-pretrained VGG16 base.

    Args:
        input_shape: ``(height, width, channels)`` of the input images. Per the
            Keras docs, VGG16 without its top needs 3 channels and a height and
            width of at least 32.
        num_classes: Number of output classes of the softmax head. The default
            of 10 matches CIFAR-10; previously this came from an undefined
            global and raised NameError.

    Returns:
        An uncompiled Keras ``Model`` mapping images to class probabilities.
    """
    base = VGG16(include_top=False, weights='imagenet', input_shape=input_shape)

    # Freeze the pretrained convolutional layers so only the new head trains.
    for layer in base.layers:
        layer.trainable = False

    # Custom classification head: two dropout-regularised dense layers, then softmax.
    x = Flatten()(base.output)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs=base.input, outputs=outputs)
创建模型
# Build the transfer-learning network, print a per-layer summary, and compile it
# for multi-class classification against one-hot targets (Adam, lr = 1e-4).
model_vgg16 = vgg16_model()
model_vgg16.summary()
model_vgg16.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(0.0001),
    metrics=['accuracy'],
)
训练模型
# Train for 20 epochs in batches of 100, reporting metrics on the test split
# after every epoch.
history = model_vgg16.fit(
    X,
    train_y,
    validation_data=(test_X, test_y),
    epochs=20,
    batch_size=100,
    verbose=True,
)

# Print the final test error rate as a percentage (100% minus accuracy).
# score[1] is the first compiled metric, i.e. accuracy.
score = model_vgg16.evaluate(test_X, test_y, verbose=0)
error_pct = 100 - score[1] * 100
print("Large CNN Error: %.2f%%" % error_pct)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。