赞
踩
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow import keras
def preprocess(x, y):
    """Convert one MNIST (image, label) pair into model-ready tensors.

    Applied per element through ``Dataset.map``, so ``x`` is a single
    image, not a batch.

    Args:
        x: one image with 28*28 pixel values in [0, 255].
        y: the integer class label of ``x``.

    Returns:
        ``(x, y)`` where ``x`` is a float32 vector of length 784 scaled to
        [0, 1] and ``y`` is a one-hot vector of length 10.
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    # Flatten to a vector: the model's first layer expects 28*28 inputs.
    x = tf.reshape(x, [28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    # One-hot labels to match CategoricalCrossentropy used at compile time.
    y = tf.one_hot(y, depth=10)
    return x, y
自定义Layer
class MyDense(layers.Layer):
    """Fully connected layer without activation: ``inputs @ kernel + bias``."""

    def __init__(self, inp_dim, outp_dim):
        """Create the layer's weights.

        Args:
            inp_dim: size of the last axis of the input.
            outp_dim: number of output units.
        """
        super().__init__()
        # Keyword arguments keep this working on Keras 3, where the first
        # positional parameter of add_weight is `shape`, not `name`.
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
        self.bias = self.add_weight(name='b', shape=[outp_dim])

    def call(self, inputs, training=None):
        """Return ``inputs @ kernel + bias``; ``training`` is unused."""
        return inputs @ self.kernel + self.bias
自定义Model
继承keras.Model可使用compile、fit等方法
class MyModel(keras.Model):
    """Five-layer fully connected classifier built from ``MyDense`` layers.

    Subclassing ``keras.Model`` provides compile(), fit(), evaluate(),
    predict() and summary().
    """

    def __init__(self):
        super().__init__()
        # Custom 5-layer network: 784 -> 256 -> 128 -> 64 -> 32 -> 10.
        self.fc1 = MyDense(28 * 28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        """Forward pass.

        Args:
            inputs: tensor of shape [batch, 784] (flattened images).
            training: unused; accepted for the Keras call signature.

        Returns:
            Raw scores of shape [batch, 10]. No softmax is applied here;
            the loss is compiled with ``from_logits=True``.
        """
        x = inputs
        # Hidden layers: dense followed by ReLU.
        for hidden in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = tf.nn.relu(hidden(x))
        # Output layer: linear logits.
        return self.fc5(x)
batchsz = 128

# Load the MNIST training and validation splits.
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

# Training pipeline. Shuffle the raw (image, label) pairs *before*
# preprocess so the full-dataset shuffle buffer holds raw pixel arrays
# rather than float32 vectors and one-hot labels. A buffer of len(x) gives
# a full random permutation each epoch.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.shuffle(len(x)).map(preprocess).batch(batchsz)

# Validation pipeline: order does not matter, so no shuffle.
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

# Sanity check one batch: expect (128, 784) images and (128, 10) labels.
sample = next(iter(db))
print(sample[0].shape, sample[1].shape)
datasets: (60000, 28, 28) (60000,) 0 255
(128, 784) (128, 10)
创建模型
summary()必须在fit 或 build之后使用
# Create the model. summary() only works after fit() or build() has
# created the weights.
network = MyModel()
network.build(input_shape=(None, 28 * 28))
network.summary()

# Use `learning_rate`, not the deprecated `lr` alias. Newer Keras
# optimizers drop `lr` with only a warning, so training would silently use
# the default rate; Keras 3 rejects it outright.
network.compile(
    optimizer=optimizers.Adam(learning_rate=0.01),
    # MyModel outputs raw logits, hence from_logits=True.
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train for 5 epochs, evaluating on the validation set every 2nd epoch.
network.fit(db, epochs=5, validation_data=ds_val, validation_freq=2)
network.evaluate(ds_val)
Model: "my_model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
my_dense_40 (MyDense)        multiple                  200960
_________________________________________________________________
my_dense_41 (MyDense)        multiple                  32896
_________________________________________________________________
my_dense_42 (MyDense)        multiple                  8256
_________________________________________________________________
my_dense_43 (MyDense)        multiple                  2080
_________________________________________________________________
my_dense_44 (MyDense)        multiple                  330
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
Epoch 1/5
469/469 [==============================] - 10s 22ms/step - loss: 0.3078 - accuracy: 0.9076
Epoch 2/5
469/469 [==============================] - 13s 28ms/step - loss: 0.1409 - accuracy: 0.9600 - val_loss: 0.1318 - val_accuracy: 0.9641
Epoch 3/5
469/469 [==============================] - 13s 27ms/step - loss: 0.1125 - accuracy: 0.9680
Epoch 4/5
469/469 [==============================] - 16s 35ms/step - loss: 0.0984 - accuracy: 0.9724 - val_loss: 0.1196 - val_accuracy: 0.9673
Epoch 5/5
469/469 [==============================] - 15s 32ms/step - loss: 0.0875 - accuracy: 0.9760
79/79 [==============================] - 3s 33ms/step - loss: 0.1224 - accuracy: 0.9704
[0.12237951139141393, 0.9704]
模型预测
# Predict on the first validation batch and compare with the true labels.
sample = next(iter(ds_val))
x = sample[0]
y = sample[1]  # one-hot labels
pred = network.predict(x)  # [b, 10] logits
# Convert one-hot labels and logits back to class indices.
y = tf.argmax(y, axis=1)
pred = tf.argmax(pred, axis=1)
print(pred)
print(y)
tf.Tensor([7 2 1 0 4 1 4 9 6 9 0 6 9 0 1 5 9 7 3 4 9 6 6 5 4 0 7 4 0 1 3 1 3 4 7 2 7 1 2 1 1 7 4 2 3 5 1 2 4 4 6 3 5 5 6 0 4 1 9 5 7 8 9 3 7 4 6 4 3 0 7 0 2 9 1 7 3 2 9 7 7 6 2 7 8 4 7 3 6 1 3 6 9 3 1 4 1 7 6 9 6 0 5 4 9 9 2 1 9 4 8 7 3 9 7 9 4 4 9 2 5 4 7 6 7 9 0 5], shape=(128,), dtype=int64)
tf.Tensor([7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4 9 6 6 5 4 0 7 4 0 1 3 1 3 4 7 2 7 1 2 1 1 7 4 2 3 5 1 2 4 4 6 3 5 5 6 0 4 1 9 5 7 8 9 3 7 4 6 4 3 0 7 0 2 9 1 7 3 2 9 7 7 6 2 7 8 4 7 3 6 1 3 6 9 3 1 4 1 7 6 9 6 0 5 4 9 9 2 1 9 4 8 7 3 9 7 4 4 4 9 2 5 4 7 6 7 9 0 5], shape=(128,), dtype=int64)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。