当前位置:   article > 正文

MNIST手写数字识别准确度提升最全、最实用的方法_minist数据集手写图片由模糊变清晰

minist数据集手写图片由模糊变清晰

MNIST手写数字识别是所有学习AI同学的入门必经过程,MNIST识别准确率提升修炼是精通AI模型的必经课程,MNIST识别准确率开刚始大家一般都能达到90%左右,再往上提高还需要费较大的精力去修改模型、调优参数,MNIST识别率究竟能达到多少,对于初学者还是很难搞清楚,刚开始也没有经验去提升得很高,我在第一遍学习时,通过参数和训练次数调整,利用了很多模型,达到了99.3%的精度,再往上提升时当时那台电脑的计算能力不够,也没有找到新的模型,没有再做研究了,就学习其它内容去了。

上个月旧电脑出现故障,不能启动,硬盘使用了Bitlocker,数据也不能恢复,学习代码也丢了,这次更换了一台性能非常好的电脑后重新学习,就把MNIST的Lenet-5算法重新学习了一遍,结合Kaggle网站上的方案,做了个模型和方案,最终把MNIST识别准确度提升到了99.96%的水平,据了解由于MNIST数据本身有缺陷,除非将评估数据集加入训练准确率应该是达不到100%的。我想能达到3个9的准度也应该足够了,特将实现方案总结出来,分享新学习的同学们,希望对大家有所帮助。

一、Lenet-5通用模型方案

下面方案是Lenet-5最通用的方案,初始是用3X3卷积盒,模型使用Sequential构建,训练也是用自己写的循环代码一步一步构建,特别适合于新手引用。此外,该案例还有日志保存、模型可在代码。

"""
LetNet-5 实战1:网上使用最多的模型, 测试用例精确度能达到99%
"""
import datetime
import tensorflow as tf
from tensorflow.keras import Sequential, layers, losses, datasets


def preprocess(x, y):
    """
    预处理函数
    """
    # [b, 28, 28], [b]
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


# 设置LOG与模型保存相关参数
current_time = datetime.datetime.now().strftime(('%Y%m%d-%H%M%S'))
log_dir = 'logs/' + current_time
summary_writer = tf.summary.create_file_writer(log_dir)
(x, y), (x_test, y_test) = datasets.mnist.load_data()  # 加载手写数据集数据
batchsz = 128   # 此模型下batch size在128下比较好
train_db = tf.data.Dataset.from_tensor_slices((x, y))  # 转化为Dataset对象
train_db = train_db.shuffle(100000)  # 随机打散
train_db = train_db.batch(batchsz)  # 批训练
train_db = train_db.map(preprocess)  # 数据预处理

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(1000).batch(batchsz).map(preprocess)
# 通过Sequnentia容器创建LeNet-5
network = Sequential([
    layers.Conv2D(6, kernel_size=3, strides=1),  # 第一个卷积核,6个3X3的卷积核,
    layers.MaxPooling2D(pool_size=2, strides=2),  # 高宽各减半的池化层
    layers.ReLU(),  # 激活函数
    layers.Conv2D(16, kernel_size=3, strides=1),  # 第二个卷积核,16个3X3的卷积核,
    layers.MaxPooling2D(pool_size=2, strides=2),  # 高宽各减半的池化层
    layers.ReLU(),  # 激活函数
    layers.Flatten(),  # 打平层,方便全连接层处理
    layers.Dense(120, activation='relu'),  # 全连接层,120个结点
    layers.Dense(84, activation='relu'),  # 全连接层,84个结点
    layers.Dense(10, activation='relu')  # 全连接层,10个结点
])
# build 一次网络模型,给输入X的形状
network.build(input_shape=(batchsz, 28, 28, 1))
# 统计网络信息
network.summary()

# 创建损失函数的类,在实际计算时直接调用类实例
criteon = losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.95)  # batch,128,lr=0.01,acc:0.9914
# optimizer = tf.keras.optimizers.Nadam(learning_rate=0.002)  # batch,128,0,89
# optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.01)

# 训练30个epoch
epoch = 30
steps = 0
for n in range(epoch):
    for step, (x, y) in enumerate(train_db):
        with tf.GradientTape() as tape:
            # 插入通道维度 =》[b,28,28,1]
            x = tf.expand_dims(x, axis=3)
            # 前向计算,获取10类别的概率分布 [b,784]=>[b,10]
            out = network(x)
            # 计算交叉熵损失函数,标量
            loss = criteon(y, out)

        # 自动计算梯度
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/350202
推荐阅读
相关标签
  

闽ICP备14008679号