赞
踩
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
代码如下:
# 1.导入模块---import import tensorflow as tf from tensorflow.keras import Model from tensorflow.keras.layers import Flatten, Dense from tensorflow.keras.preprocessing.image import ImageDataGenerator # 2.指定数据集和训练集---(x_train,y_train) mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train, x_test = x_train/255, x_test/255 # 给数据增加一个维度,使数据和网络结构匹配 x_train = x_train.reshape(x_train.shape[0], 28, 28, 1) # x_train.shape:(60000,28,28) x_test = x_test.reshape(x_test.shape[0], 28, 28, 1) # 数据增强函数参数设置----数据增强函数输入为4维,升维 image_gen_train = ImageDataGenerator(rescale=1./1, rotation_range=45, width_shift_range=.15, height_shift_range=.15, horizontal_flip=True, zoom_range=0.5 ) # 输入数据(图像像素)进行数据增强 image_gen_train.fit(x_train) # 3.搭建网络模型----class class MnistModel(Model): def __init__(self): super(MnistModel, self).__init__() self.flatten = Flatten() self.d1 = Dense(128, activation='relu') self.d2 = Dense(10, activation='softmax') def call(self, x): x = self.flatten(x) x = self.d1(x) y = self.d2(x) return y model = MnistModel() # 4。为网络模型配置训练方法----compile model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['sparse_categorical_accuracy'] ) # 5.网络模型传入数据集开始训练---fit model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5, validation_data=(x_test, y_test), validation_freq=1) # 6.打印网络莫模型和结构----summary model.summary()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。