
Neural Networks: Training on the MNIST and CIFAR-10 Datasets



1. Building a feedforward neural network and training it on MNIST

Environment: PyCharm + Windows 10 + conda3 + Python 3.6

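First, create a neural network class, NeuralNetwork. Save it as neural_network.py, because the training script below imports it from that module: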

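import numpy as np

# A simple fully connected feedforward neural network
class NeuralNetwork:
    # Initialize the weight matrices, layer sizes and learning rate
    def __init__(self, layers, alpha=0.1):
        self.Weight = []
        self.layers = layers
        self.alpha = alpha

        # Loop from the first layer up to (but not including) the last two layers
        for i in np.arange(0, len(layers) - 2):
            # Build an MxN weight matrix connecting every node of layer i to layer i+1,
            # with one extra row/column for the bias term. Dividing by the square root
            # of the current layer's node count keeps the variance of the outputs stable.
            weight = np.random.randn(layers[i] + 1, layers[i + 1] + 1)
            self.Weight.append(weight / np.sqrt(layers[i]))

        # Last two layers: the input connections need a bias term, the output layer does not
        weight = np.random.randn(layers[-2] + 1, layers[-1])
        self.Weight.append(weight / np.sqrt(layers[-2]))

    # Return a summary of the architecture, e.g. "NeuralNetwork:64-32-16-10"
    def __repr__(self):
        return "NeuralNetwork:{}".format("-".join(str(l) for l in self.layers))

    # Sigmoid activation function
    def sigmoid(self, x):
        return 1.0 / (1 + np.exp(-x))

    # Derivative of the sigmoid. x must already be a sigmoid OUTPUT,
    # because sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z))
    def sigmoid_deriv(self, x):
        return x * (1 - x)

    # Train the network. X: training data; y: class label(s) for each row of X;
    # epochs: number of passes over the data; displayUpdate: print the loss every N epochs
    def fit(self, X, y, epochs=1000, displayUpdate=100):
        # Append a column of ones so the bias is learned as part of the weight matrix
        X = np.c_[X, np.ones(X.shape[0])]

        for epoch in np.arange(0, epochs):
            # Per-sample (stochastic) gradient descent
            for (x, target) in zip(X, y):
                self.fit_partial(x, target)

            if epoch == 0 or (epoch + 1) % displayUpdate == 0:
                loss = self.calculate_loss(X, y)
                print("epoch={}, loss={:.6f}".format(epoch + 1, loss))

    def fit_partial(self, x, y):
        # Activations of every layer as x propagates forward; the first entry is the input itself
        A = [np.atleast_2d(x)]

        # Forward pass
        for layer in np.arange(0, len(self.Weight)):
            # Dot product of the previous layer's activations and the weight matrix...
            net = A[layer].dot(self.Weight[layer])
            # ...passed through the sigmoid
            out = self.sigmoid(net)
            A.append(out)

        # Backward pass: error and delta at the output layer
        error = A[-1] - y
        D = [error * self.sigmoid_deriv(A[-1])]

        # Propagate the deltas back through the hidden layers (chain rule)
        for layer in np.arange(len(A) - 2, 0, -1):
            delta = D[-1].dot(self.Weight[layer].T)
            delta = delta * self.sigmoid_deriv(A[layer])
            D.append(delta)

        # Put the deltas back in layer order and update the weights with the learning rate
        D = D[::-1]
        for layer in np.arange(0, len(self.Weight)):
            self.Weight[layer] += -self.alpha * A[layer].T.dot(D[layer])

    # Predict the outputs for X
    def predict(self, X, addBias=True):
        p = np.atleast_2d(X)

        # Add the bias column unless X already contains it
        if addBias:
            p = np.c_[p, np.ones((p.shape[0]))]

        for layer in np.arange(0, len(self.Weight)):
            p = self.sigmoid(np.dot(p, self.Weight[layer]))

        return p

    # Loss: half the sum of squared errors over all samples and outputs
    def calculate_loss(self, X, targets):
        targets = np.atleast_2d(targets)
        # X already has the bias column (added in fit), so do not add it again
        predictions = self.predict(X, addBias=False)
        loss = 0.5 * np.sum((predictions - targets) ** 2)
        return loss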


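fit_partial never writes the gradient out explicitly, so a sign or transpose mistake would be easy to miss. Below is a minimal sketch, not part of the original code, that checks the backward pass. It rebuilds the same forward pass, bias column and delta recursion for one sample on a hypothetical 2-3-1 network. It then compares the resulting weight gradients with central finite differences of the same 0.5·Σ(error)² loss. The helper names (forward, backprop_grads, numeric_grads) are illustrative assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def forward(weights, x):
    # x already carries the trailing bias 1, as after np.c_[X, ones] in fit()
    A = [np.atleast_2d(x)]
    for W in weights:
        A.append(sigmoid(A[-1].dot(W)))
    return A

def loss(weights, x, y):
    # same 0.5 * sum of squared errors as calculate_loss()
    return 0.5 * np.sum((forward(weights, x)[-1] - y) ** 2)

def backprop_grads(weights, x, y):
    # mirrors fit_partial(): output delta, then deltas for the hidden layers
    A = forward(weights, x)
    D = [(A[-1] - y) * A[-1] * (1 - A[-1])]
    for layer in range(len(A) - 2, 0, -1):
        D.append(D[-1].dot(weights[layer].T) * A[layer] * (1 - A[layer]))
    D = D[::-1]
    # dLoss/dW for each layer; fit_partial subtracts alpha times this
    return [A[i].T.dot(D[i]) for i in range(len(weights))]

def numeric_grads(weights, x, y, eps=1e-6):
    # central finite differences, perturbing one weight at a time in place
    grads = []
    for W in weights:
        G = np.zeros_like(W)
        for idx in np.ndindex(*W.shape):
            old = W[idx]
            W[idx] = old + eps
            plus = loss(weights, x, y)
            W[idx] = old - eps
            minus = loss(weights, x, y)
            W[idx] = old
            G[idx] = (plus - minus) / (2 * eps)
        grads.append(G)
    return grads

rng = np.random.default_rng(0)
layers = [2, 3, 1]  # hypothetical tiny network, same weight shapes as NeuralNetwork.__init__
weights = [rng.standard_normal((layers[0] + 1, layers[1] + 1)),
           rng.standard_normal((layers[1] + 1, layers[2]))]
x = np.array([0.5, -1.0, 1.0])  # two features plus the bias 1
y = np.array([1.0])

analytic = backprop_grads(weights, x, y)
numeric = numeric_grads(weights, x, y)
max_diff = max(np.max(np.abs(a - n)) for a, n in zip(analytic, numeric))
print("max |analytic - numeric| =", max_diff)
```

A tiny difference confirms that the deltas are the gradient of the squared-error loss. The same check can be pointed at a real NeuralNetwork instance by passing its Weight list.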

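Next, train the network on the handwritten digits dataset that ships with scikit-learn. This is a small MNIST-style dataset of 1,797 8×8 images, not the full 28×28 MNIST. First install the scikit-learn package: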

conda install scikit-learn

Training code:

from neural_network import NeuralNetwork
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn import datasets
import numpy as np

def mnist_train():
    # Load the 8x8 handwritten digits dataset bundled with scikit-learn
    print("loading MNIST dataset...")
    digits = datasets.load_digits()
    data = digits.data.astype("float")
    # Min-max scale the pixel values into [0, 1]
    data = (data - data.min()) / (data.max() - data.min())
    print("samples:{}, dimension:{}".format(data.shape[0], data.shape[1]))

    # Split into training (75%) and test (25%) sets
    (trainX, testX, trainY, testY) = train_test_split(data, digits.target, test_size=0.25)

    # Convert integer labels to one-hot vectors. Fit the encoder on the training
    # labels only and reuse it for the test labels so the columns line up.
    lb = LabelBinarizer()
    trainY = lb.fit_transform(trainY)
    testY = lb.transform(testY)

    # Train a 64-32-16-10 network
    print("training network...")
    nn = NeuralNetwork([trainX.shape[1], 32, 16, 10])
    print("{}".format(nn))
    nn.fit(trainX, trainY, epochs=1000)

    # Evaluate: the predicted class is the output unit with the highest activation
    print("evaluating network...")
    predictions = nn.predict(testX)
    predictions = predictions.argmax(axis=1)
    print(classification_report(testY.argmax(axis=1), predictions))


if __name__ == "__main__":
    mnist_train()

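LabelBinarizer turns each integer label into a one-hot row, and argmax(axis=1) turns a row of scores back into a label. The sketch below shows that round trip in plain numpy. one_hot and decode are illustrative helpers, not scikit-learn API. The sketch also shows why the encoder should be fitted on the training labels and then only applied to the test labels. If a split is encoded against just the classes it happens to contain, the number of columns can change.

```python
import numpy as np

def one_hot(labels, classes):
    # one column per entry of `classes`, in that order (like LabelBinarizer for >2 classes)
    labels = np.asarray(labels)
    classes = np.asarray(classes)
    return (labels[:, None] == classes[None, :]).astype(int)

def decode(scores, classes):
    # inverse mapping: take the highest-scoring column, as argmax(axis=1) does in the scripts
    return np.asarray(classes)[np.argmax(scores, axis=1)]

classes = np.arange(10)            # digits 0-9, as learned from the training labels
test_labels = np.array([3, 0, 7])

Y = one_hot(test_labels, classes)
print(Y.shape)                     # (3, 10)
print(decode(Y, classes))          # [3 0 7]

# Encoding against only the labels present in a small split gives the wrong width
Y_bad = one_hot(test_labels, np.unique(test_labels))
print(Y_bad.shape)                 # (3, 3)
```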

Results:

loading MNIST dataset...
samples:1797, dimension:64
training network...
NeuralNetwork:64-32-16-10
epoch=1, loss=605.801722
epoch=100, loss=7.885670
epoch=200, loss=4.572620
epoch=300, loss=3.601540
epoch=400, loss=2.981110
epoch=500, loss=2.804141
epoch=600, loss=2.730295
epoch=700, loss=2.684942
epoch=800, loss=2.654072
epoch=900, loss=2.631717
epoch=1000, loss=2.614806
evaluating network...
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        46
           1       0.93      0.97      0.95        39
           2       1.00      1.00      1.00        52
           3       0.93      0.98      0.96        44
           4       0.98      0.98      0.98        52
           5       1.00      0.90      0.95        51
           6       1.00      0.98      0.99        54
           7       1.00      1.00      1.00        41
           8       0.90      0.88      0.89        42
           9       0.91      1.00      0.95        29

   micro avg       0.97      0.97      0.97       450
   macro avg       0.97      0.97      0.97       450
weighted avg       0.97      0.97      0.97       450


2. Training MNIST with Keras

Environment:

  1. Ubuntu 16.04
  2. PyCharm
  3. conda3
  4. Python 3.6
  5. tensorflow 1.8.0
  6. tensorflow-gpu 1.8.0
  7. scikit-learn 0.21.3
  8. numpy 1.13.3
  9. cudnn 7.0.5
  10. cudatoolkit 8.0
  11. matplotlib 2.2.2

Training code:

from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np

EPOCHS = 100

def mnist_train():
    # Download the full MNIST dataset (about 55 MB)
    print("loading MNIST dataset...")
    dataset = datasets.fetch_mldata("MNIST Original")

    # Scale pixel values into [0, 1]
    data = dataset.data.astype("float") / 255.0
    # Split into training/test sets
    (trainX, testX, trainY, testY) = train_test_split(data, dataset.target, test_size=0.25)

    # One-hot encode the labels: fit on the training labels, reuse for the test labels
    lb = LabelBinarizer()
    trainY = lb.fit_transform(trainY)
    testY = lb.transform(testY)

    # Define a 784-256-128-10 network
    model = Sequential()
    model.add(Dense(256, input_shape=(784,), activation="sigmoid"))
    model.add(Dense(128, activation="sigmoid"))
    model.add(Dense(10, activation="softmax"))

    # Train with stochastic gradient descent (learning rate 0.01)
    print("training network...")
    sgd = SGD(0.01)
    model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=["accuracy"])
    H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=EPOCHS, batch_size=128)

    # Evaluate the network
    print("evaluating network...")
    predictions = model.predict(testX, batch_size=128)
    print(classification_report(testY.argmax(axis=1),
                                predictions.argmax(axis=1),
                                target_names=[str(x) for x in lb.classes_]))

    # Plot training/validation loss and accuracy. This Keras version stores accuracy
    # under "acc"/"val_acc"; newer versions use "accuracy"/"val_accuracy".
    plt.style.use("ggplot")
    plt.figure()
    plt.plot(np.arange(0, EPOCHS), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, EPOCHS), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, EPOCHS), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, EPOCHS), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Loss/Accuracy")
    plt.legend()
    plt.savefig("./mnist_keras.png")


if __name__ == "__main__":
    mnist_train()

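The model ends in a 10-way softmax and is compiled with categorical_crossentropy. The numpy sketch below illustrates what those two pieces compute. It is not Keras's own implementation, and the eps clipping is an assumption added here to avoid log(0). It also explains the starting point of the training log: a model that predicts a uniform distribution over 10 classes has a loss of ln 10 ≈ 2.303, which is about where the loss sits early in epoch 1.

```python
import numpy as np

def softmax(logits):
    # subtract the row max for numerical stability; each row then sums to 1
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def categorical_crossentropy(y_true, y_prob, eps=1e-7):
    # batch mean of -sum(y_true * log(p)); eps clipping avoids log(0)
    y_prob = np.clip(y_prob, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_prob), axis=1)))

# A batch of 4 one-hot targets and all-equal logits -> uniform prediction over 10 classes
y_true = np.eye(10)[[3, 1, 4, 1]]
uniform = softmax(np.zeros((4, 10)))
print(round(categorical_crossentropy(y_true, uniform), 4))  # 2.3026 (= ln 10)
```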

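Running this reports: DeprecationWarning: Function mldata_filename is deprecated; mldata_filename was deprecated in version 0.20 and will be removed in version 0.22. Please use fetch_openml.
Workaround: the mldata.org server that fetch_mldata downloads from is no longer online, so download the MNIST dataset manually:
https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat
Then put the downloaded mnist-original.mat file in /home/ubuntu/scikit_learn_data/mldata, which is scikit-learn's default data directory (~/scikit_learn_data) for the user ubuntu.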
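The path /home/ubuntu/... above is specific to that machine's user. The sketch below shows how to work out the equivalent location elsewhere. It assumes scikit-learn's default data-home rule: the SCIKIT_LEARN_DATA environment variable if set, otherwise ~/scikit_learn_data. It also assumes that fetch_mldata("MNIST Original") looks for mldata/mnist-original.mat under that directory. mldata_mnist_path is a hypothetical helper and only prints the path; it does not create or download anything.

```python
import os

def mldata_mnist_path(data_home=None):
    # Assumption: scikit-learn's default data home is $SCIKIT_LEARN_DATA if set,
    # otherwise ~/scikit_learn_data
    if data_home is None:
        data_home = os.environ.get("SCIKIT_LEARN_DATA",
                                   os.path.join("~", "scikit_learn_data"))
    data_home = os.path.expanduser(data_home)
    # Assumption: fetch_mldata("MNIST Original") reads <data_home>/mldata/mnist-original.mat
    return os.path.join(data_home, "mldata", "mnist-original.mat")

path = mldata_mnist_path()
print(path, "exists:", os.path.exists(path))
```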

Results:

Using TensorFlow backend.
loading MNIST dataset...
training network...
Train on 52500 samples, validate on 17500 samples
Epoch 1/100
  128/52500 [..............................] - ETA: 11:43 - loss: 2.4822 - acc: 0.1016
 2688/52500 [>.............................] - ETA: 32s - loss: 2.4141 - acc: 0.0911  
 4992/52500 [=>............................] - ETA: 17s - loss: 2.3748 - acc: 0.0972
 7168/52500 [===>..........................] - ETA: 11s - loss: 2.3545 - acc: 0.1018
 9088/52500 [====>.........................] - ETA: 9s - loss: 2.3433 - acc: 0.1039 
11264/52500 [=====>........................] - ETA: 7s - loss: 2.3349 - acc: 0.1014
13440/52500 [======>.......................] - ETA: 5s - loss: 2.3283 - acc: 0.1026
15360/52500 [=======>......................] - ETA: 5s - loss: 2.3239 - acc: 0.1054
17280/52500 [========>.....................] - ETA: 4s - loss: 2.3199 - acc: 0.1080
...
Epoch 100/100

52500/52500 [==============================] - 2s 30us/step - loss: 0.2735 - acc: 0.9201 - val_loss: 0.2829 - val_acc: 0.9186
evaluating network...
              precision    recall  f1-score   support

         0.0       0.94      0.96      0.95      1722
         1.0       0.95      0.97      0.96      2046
         2.0       0.91      0.89      0.90      1682
         3.0       0.92      0.89      0.90      1814
         4.0       0.92      0.93      0.93      1734
         5.0       0.87      0.88      0.87      1504
         6.0       0.94      0.96      0.95      1705
         7.0       0.93      0.93      0.93      1816
         8.0       0.90      0.88      0.89      1753
         9.0       0.89      0.89      0.89      1724

    accuracy                           0.92     17500
   macro avg       0.92      0.92      0.92     17500
weighted avg       0.92      0.92      0.92     17500


Process finished with exit code 0

The training curves are shown below:
[Figure: training/validation loss and accuracy on MNIST (mnist_keras.png)]

3. Training CIFAR-10 with Keras

Environment:

  1. Ubuntu 16.04
  2. PyCharm
  3. conda3
  4. Python 3.6
  5. tensorflow 1.8.0
  6. tensorflow-gpu 1.8.0
  7. scikit-learn 0.21.3
  8. numpy 1.13.3
  9. cudnn 7.0.5
  10. cudatoolkit 8.0
  11. matplotlib 2.2.2

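The CIFAR-10 dataset contains 60,000 RGB images of size 32×32, so each image has 32×32×3 = 3072 values. There are 10 classes with 6,000 images each. 50,000 images are used for training and 10,000 for testing.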
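Before the Dense layers can use the images, the training code below scales pixels to [0, 1] and flattens each 32×32×3 image into a 3072-dimensional vector. The numpy sketch here shows that preprocessing on random stand-in data shaped like the output of cifar10.load_data(); it does not use the real dataset.

```python
import numpy as np

# Stand-in for cifar10.load_data() images: uint8 array shaped (N, 32, 32, 3)
# (random pixels here, not real CIFAR-10 data)
images = np.random.randint(0, 256, size=(5, 32, 32, 3), dtype=np.uint8)

# Scale to [0, 1], then flatten each image into one 32*32*3 = 3072-long row
flat = (images.astype("float") / 255.0).reshape((images.shape[0], 3072))
print(flat.shape)  # (5, 3072)

# Row-major reshape keeps a pixel's channels together:
# the first three values of a row are the R, G, B values of pixel (0, 0)
assert np.allclose(flat[0, :3], images[0, 0, 0] / 255.0)
```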
Training code:

from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import SGD
from keras.datasets import cifar10
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import numpy as np

EPOCHS = 100

def cifar10_train():
    # Download the CIFAR-10 dataset (about 170 MB)
    print("loading CIFAR10 dataset...")
    ((trainX, trainY), (testX, testY)) = cifar10.load_data()

    # Scale pixel values into [0, 1]
    trainX = trainX.astype("float") / 255.0
    testX = testX.astype("float") / 255.0
    # Flatten each 32x32x3 image into a 3072-dimensional row vector
    trainX = trainX.reshape((trainX.shape[0], 3072))
    testX = testX.reshape((testX.shape[0], 3072))

    # One-hot encode the labels: fit on the training labels, reuse for the test labels
    lb = LabelBinarizer()
    trainY = lb.fit_transform(trainY)
    testY = lb.transform(testY)

    # Class names, in label order 0-9
    labelNames = ["airplane", "automobile", "bird", "cat", "deer",
                  "dog", "frog", "horse", "ship", "truck"]

    # Define a 3072-1024-512-10 network
    model = Sequential()
    model.add(Dense(1024, input_shape=(3072,), activation="relu"))
    model.add(Dense(512, activation="relu"))
    model.add(Dense(10, activation="softmax"))

    # Train with stochastic gradient descent (learning rate 0.01)
    print("training network...")
    sgd = SGD(0.01)
    model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=["accuracy"])
    H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=EPOCHS, batch_size=128)

    # Evaluate the network
    print("evaluating network...")
    predictions = model.predict(testX, batch_size=128)
    print(classification_report(testY.argmax(axis=1),
                                predictions.argmax(axis=1),
                                target_names=labelNames))

    # Plot training/validation loss and accuracy
    plt.style.use("ggplot")
    plt.figure()
    plt.plot(np.arange(0, EPOCHS), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, EPOCHS), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, EPOCHS), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, EPOCHS), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Loss/Accuracy")
    plt.legend()
    plt.savefig("./cifar10_keras.png")


if __name__ == "__main__":
    cifar10_train()

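It helps to count the trainable parameters of the two Keras models to see how large they are. A Dense layer with n_in inputs and n_out units has n_in·n_out weights plus n_out biases. The small sketch below applies that rule to both layer stacks; dense_params is an illustrative helper. The CIFAR-10 network has roughly 3.7 million parameters for 50,000 training images, about 15 times as many as the MNIST network. That helps explain how it can fit the training set far better than the validation set.

```python
def dense_params(layer_sizes):
    # each Dense layer: n_in * n_out weights + n_out biases
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

mnist_mlp = [784, 256, 128, 10]
cifar_mlp = [3072, 1024, 512, 10]
print(dense_params(mnist_mlp))  # 235146
print(dense_params(cifar_mlp))  # 3676682
```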

Results:

Using TensorFlow backend.
loading CIFAR10 dataset...
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

170500096/170498071 [==============================] - 4585s 27us/step
training network...

Train on 50000 samples, validate on 10000 samples
Epoch 1/100
  128/50000 [..............................] - ETA: 7:31 - loss: 2.3651 - acc: 0.0781
 1664/50000 [..............................] - ETA: 35s - loss: 2.2779 - acc: 0.1436
 3072/50000 [>.............................] - ETA: 19s - loss: 2.2363 - acc: 0.1735
 4480/50000 [=>............................] - ETA: 13s - loss: 2.2142 - acc: 0.1833
 5888/50000 [==>...........................] - ETA: 10s - loss: 2.1908 - acc: 0.1967
...
47872/50000 [===========================>..] - ETA: 0s - loss: 0.6568 - acc: 0.7823
49280/50000 [============================>.] - ETA: 0s - loss: 0.6572 - acc: 0.7820
50000/50000 [==============================] - 2s 42us/step - loss: 0.6585 - acc: 0.7816 - val_loss: 1.4283 - val_acc: 0.5421
Epoch 100/100
50000/50000 [==============================] - 2s 42us/step - loss: 0.6532 - acc: 0.7810 - val_loss: 1.4354 - val_acc: 0.5334
evaluating network...
              precision    recall  f1-score   support

    airplane       0.64      0.53      0.58      1000
  automobile       0.61      0.69      0.64      1000
        bird       0.55      0.36      0.43      1000
         cat       0.41      0.33      0.37      1000
        deer       0.58      0.37      0.46      1000
         dog       0.58      0.32      0.41      1000
        frog       0.60      0.65      0.63      1000
       horse       0.45      0.76      0.57      1000
        ship       0.46      0.84      0.59      1000
       truck       0.62      0.48      0.54      1000

    accuracy                           0.53     10000
   macro avg       0.55      0.53      0.52     10000
weighted avg       0.55      0.53      0.52     10000


Process finished with exit code 0


Result plot:
[Figure: training/validation loss and accuracy on CIFAR-10 (cifar10_keras.png)]

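When a standard feedforward network is trained on CIFAR-10, the validation loss increases during training, which means the model is overfitting. By epoch 100, training accuracy is about 78% while validation accuracy is only about 53%, and the validation loss (1.44) is more than twice the training loss (0.65). Tuning hyperparameters such as the learning rate and the number of nodes and layers did not improve the results much. Substantially higher accuracy requires a convolutional neural network.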
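The overfitting described above can be read directly from the history object returned by model.fit. Find the epoch with the lowest validation loss and compare it with the last epoch. The pure-Python sketch below does this on a made-up curve, not the article's data; for the real run you would pass H.history["val_loss"]. The helper names are illustrative. If the best epoch comes long before the final one, the later epochs mostly improved the fit to the training set.

```python
def best_epoch(val_losses):
    # 1-based epoch with the lowest validation loss (first one on ties)
    best = min(range(len(val_losses)), key=lambda i: val_losses[i])
    return best + 1

def epochs_since_best(val_losses):
    # how many final epochs did not improve on the best validation loss
    return len(val_losses) - best_epoch(val_losses)

# Hypothetical curve: improves for 3 epochs, then rises (not the article's data)
val_loss = [1.9, 1.6, 1.5, 1.55, 1.6, 1.7]
print(best_epoch(val_loss))         # 3
print(epochs_since_best(val_loss))  # 3
```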

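Disclaimer: this article was contributed by a user and does not represent the position of the wpsshop blog. Copyright belongs to the original author, and the site assumes no legal liability for it. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/知新_RL/article/detail/350243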