当前位置:   article > 正文

Python使用tensorflow读取numpy数据训练DNN模型(手写体识别案例)_tensorflow训练使用npy数据

tensorflow训练使用npy数据

本文主要使用tensorflow、numpy、matplotlib、jupyternotebook进行训练
1.导入库

import numpy as np
import tensorflow as tf
  • 1
  • 2

2.从npz文件读取numpy数组

#numpy文件地址
filename="./datas/mnist/mnist.npz"
with np.load(filename) as data:
    train_examples=data['x_train']
    train_labels=data['y_train']
    test_examples=data['x_test']
    test_labels=data['y_test']
print(type(train_example),type(train_labels))
print(train_examples.ndim,train_labels.ndim)
print(train_examples.shape,train_labels.shape)
print(train_examples.dtype,train_labels.dtype)
train_examples[1]
train_examples[0].shape
train_labels[0]
import matplotlib.pyplot as plt
plt.imshow(train_examples[0])
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

在这里插入图片描述
3.加载Numpy数组到tf.data.Dataset
tf.data.Dataset.from_tensor_slices可以接收元祖,特征矩阵、标签向量,要求它们行数(样本数)相等,会按行匹配组合

train_dataset=tf.data.Dataset.from_tensor_slices((train_examples,train_labels))
test_dataset=tf.data.Dataset.from_tensor_slices((test_examples,test_labels))
#查看数据集的一个样本(这时包含了所有特征列、标签列)
train_dataset.as_numpy_iterator().next()

  • 1
  • 2
  • 3
  • 4
  • 5

4.打乱和批次化数据集

BATCH_SIZE=64
SHUFFLE_BUFFER_SIZE=100
shuffle_ds=train_dataset.shuffle(SHUFFLE_BUFFER_SIZE)
train_dataset=shuffle_ds.batch(BATCH_SIZE)
test_dataset=test_dataset.batch(BATCH_SIZE)
  • 1
  • 2
  • 3
  • 4
  • 5

5.建立和训练模型

#input_shape要省略输入数据的第一维度,(60000,28,28),只需要输入(28,28)
#这里的input_shape其实就等于train_examples.shape[1:]
first_layer=tf.keras.layers.Flatten(input_shape=(28,28))
#搭建模型
model=tf.keras.Sequential([
    first_layer,
    tf.keras.layers.Dense(128,activation='relu'),
    tf.keras.layers.Dense(10,activation='softmax')
])
#模型编译
model.compile(optimizer=tf.keras.optimizers.RMSprop(),
             loss=tf.keras.losses.SparseCategoricalCrossentropy(),
             metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
 #查看模型信息
model.summary()    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

在这里插入图片描述
6.训练和评估

model.fit(train_dataset,epochs=10)
  • 1

在这里插入图片描述
在这里插入图片描述

model.evaluate(test_dataset)
  • 1

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/123899
推荐阅读
相关标签
  

闽ICP备14008679号