赞
踩
1.import模块
2.导入数据集,划分train,test
3.定义网络结构, class MyModel(Model):
4.model.compile(…)
5.断点续训、模型保存、model.fit(…)
6.可训练参数保存到txt文件中
7.acc/loss可视化
(注:断点续训、模型保存、将可训练参数保存到txt文件这三步可以省略)
以下以cifar10数据集为例进行演示
(cifar10数据集包含5万张32×32像素的彩色图片,用于训练;
另有1万张32×32像素的彩色图片,用于测试)
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    MaxPool2D,
)

# Print arrays in full (no "..." truncation) so that the weight dump written
# to weights.txt below contains every value.
np.set_printoptions(threshold=np.inf)

# Load CIFAR-10 and split it into train / test sets.
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Divide by 255 so that 8-bit pixel intensities end up in [0, 1].
x_train, x_test = x_train / 255.0, x_test / 255.0


class Baseline(Model):
    """Small CNN baseline: one conv stage followed by two dense layers.

    Layer order: Conv2D -> BatchNormalization -> ReLU -> MaxPool2D -> Dropout
    -> Flatten -> Dense(128, relu) -> Dropout -> Dense(10, softmax).
    """

    def __init__(self):
        """Create the layers; they are wired together in ``call``."""
        super().__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')  # convolution layer
        self.b1 = BatchNormalization()  # batch-normalization layer
        self.a1 = Activation('relu')  # activation layer
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')  # pooling layer
        self.d1 = Dropout(0.2)  # dropout layer

        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')

    def call(self, x):
        """Run the forward pass.

        Args:
            x: Input batch fed to the first convolution layer.

        Returns:
            The output of the final 10-unit softmax layer.
        """
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y


model = Baseline()

# from_logits=False because the last layer already applies softmax.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume from a checkpoint if one exists, and keep saving checkpoints.
checkpoint_save_path = "./checkpoint/Baseline.ckpt"
# The presence of "<path>.index" is used as the signal that an earlier run
# already saved weights at this path.
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Save weights only (not the whole model), and only when the monitored
# metric improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

# NOTE(review): the test split doubles as validation data here.
history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

# Dump every trainable variable (name, shape, values) to a text file.
# The context manager guarantees the file is closed even if a write fails.
with open('./weights.txt', 'w') as weights_file:
    for v in model.trainable_variables:
        weights_file.write(str(v.name) + '\n')
        weights_file.write(str(v.shape) + '\n')
        weights_file.write(str(v.numpy()) + '\n')
###############################################    show   ###############################################

# Plot the training and validation accuracy / loss curves recorded by
# model.fit() in `history`.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Left panel: accuracy per epoch.
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

# Right panel: loss per epoch.
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
以下为运行两次的结果
注:本文内容来自中国大学MOOC上北京大学的课程《人工智能实践:TensorFlow笔记》,在此感谢北京大学的曹健老师。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。