赞
踩
本题爬取了校园常见的8种植物作为数据集,每种植物600张图片,其中500张作为训练集,100张作为测试集。
因为百度图片爬取下的的图片,有一部分是广告,有一部分跟真实的差别很大,所以这里做一个人工清理,删除这些图片。
这里使用了一个图片批量生成器,可以直接将文件中的图片,转换大小,并生成数据,也会生成对应的标签。
train_dir = 'D:\编程\数据集\训练集' val_dir = 'D:\编程\数据集\测试集' from keras.preprocessing.image import ImageDataGenerator train_datagen = ImageDataGenerator(rescale=1./255)#进行缩放 test_datagen = ImageDataGenerator(rescale=1./255) train_generator = train_datagen.flow_from_directory( train_dir,#训练集所在目录 target_size=(256,256),#将图片转换为目标大小 batch_size=20,#每一批的数量 class_mode='binary' # ) valid_generator = test_datagen.flow_from_directory( val_dir, target_size=(256,256), batch_size=20, class_mode = 'binary' )
输入的大小是256*256 的3通道数据,全连接层最后输出的大小为8,因为为8分类的任务。
model = tf.keras.models.Sequential([ # 第一层卷积层 tf.keras.layers.Conv2D(filters=6, kernel_size=(5,5), padding='valid', activation=tf.nn.relu, input_shape=(256,256,3)), # 第一池化层 tf.keras.layers.AveragePooling2D(pool_size=(2,2), strides=(2,2), padding='same'), # 第二卷积层 tf.keras.layers.Conv2D(filters=16, kernel_size=(5,5), padding='valid', activation=tf.nn.relu), # 第二池化层 tf.keras.layers.AveragePooling2D(pool_size=(2,2), strides=(2,2), padding='same'), # 扁平化层,将多维数据转换为一维数据。 tf.keras.layers.Flatten(), # 全连接层 tf.keras.layers.Dense(units=120, activation=tf.nn.relu), # 全连接层 tf.keras.layers.Dense(units=84, activation=tf.nn.relu), # 输出层,全连接 tf.keras.layers.Dense(units=8, activation=tf.nn.softmax) ]) model.summary()#打印网络结构
这里学习率为0.001
from tensorflow.keras import optimizers
model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,
optimizer = optimizers.RMSprop(learning_rate=0.001),
metrics=['acc'])
#损失函数交叉熵损失函数,优化方法RMSprop,评价指标acc
for data_batch,labels_batch in train_generator:
print(data_batch.shape)
print(labels_batch)
break
image = data_batch[0]
image.shape
plt.imshow(image)
plt.show()
history = model.fit_generator(
train_generator,#第一个参数必须是python生成器
# steps_per_epoch=200,#批量数
epochs = 10,#迭代次数
validation_data = valid_generator,#待验证的数据集
validation_steps = 20
)
image = np.array(train_generator[0][0])
image[0]
pred = model.predict(image[6].reshape(1,256,256,3))
xx = {0: '桂树', 1: '垂柳',2:'红继木',3:'金心吊兰',4:'蜡梅',5:'南天竹',6:'珊瑚树',7:'栀子花'}
print(xx[pred.argmax()])
plt.imshow(image[6].reshape(256,256,3))
plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。