赞
踩
卷积神经网络包含一个或多个卷积层(Convolutional Layer)、池化层(Pooling Layer)和全连接层(Fully-connected Layer)。
卷积神经网络的一个实现如下所示,新加入了一些卷积层和池化层。当然这个网络可以增加、删除或调整 CNN 的网络结构和参数,以达到更好效果。
class CNN(tf.keras.Model):
def __init__(self):
super().__init__()
self.conv1 = tf.keras.layers.Conv2D(
filters=32, # 卷积层神经元(卷积核)数目
kernel_size=[5, 5], # 感受野大小
padding='same', # padding策略(vaild 或 same)
activation=tf.nn.relu # 激活函数
)
self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
self.conv2 = tf.keras.layers.Conv2D(
filters=64,
kernel_size=[5, 5],
padding='same',
activation=tf.nn.relu
)
self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
self.flatten = tf.keras.layers.Reshape(target_shape=(7 * 7 * 64,))
self.dense1 = tf.keras.layers.Dense(units=1024, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(units=10)
def call(self, inputs):
x = self.conv1(inputs) # [batch_size, 28, 28, 32]
x = self.pool1(x) # [batch_size, 14, 14, 32]
x = self.conv2(x) # [batch_size, 14, 14, 64]
x = self.pool2(x) # [batch_size, 7, 7, 64]
x = self.flatten(x) # [batch_size, 7 * 7 * 64]
x = self.dense1(x) # [batch_size, 1024]
x = self.dense2(x) # [batch_size, 10]
output = tf.nn.softmax(x)
return output
将前节的 model = MLP()
更换成 model = CNN()
,训练结束以及预测输出:
批次 4682: 损失 0.010545
批次 4683: 损失 0.003783
批次 4684: 损失 0.000980
测试准确率: 0.990600
可以发现准确率相较于之前的多层感知机有非常显著的提高。
我们来看一个问题,如果我们要做一个具体场景的计算机视觉任务,那么从头开始训练一个网络是合适的选择吗?怎么样才能避免浪费过多的计算时间?
定义:
迁移学习到底在什么情况下使用呢?有两个方面需要我们考虑的
最常见的称呼叫做fine tuning,即微调
已训练好的模型,称之为Pre-trained model
通常我们需要加载以训练好的模型,这些可以是一些机构或者公司在ImageNet等类似比赛上进行训练过的模型。TensorFlow同样也提供了相关模型地址以及API:
https://www.tensorflow.org/api_docs/python/tf/keras/applications,下图是其中包含的一些模型:
这里我们举一个例子,假设有两个任务A和B,任务 A 拥有海量的数据资源且已训练好,但并不是我们的目标任务,任务 B 是我们的目标任务。下面的网络模型假设是已训练好的1000个类别模型
而B任务假设是某个具体场景如250个类别的食物识别,那么该怎么去做
tf.keras.applications 中有一些预定义好的经典卷积神经网络结构,如 VGG16 、 VGG19 、 ResNet 、 MobileNet 等。我们可以直接调用这些经典的卷积神经网络结构(甚至载入预训练的参数),而无需手动定义网络结构。
支持以下结构:
我们可以使用以下代码来实例化一个 MobileNetV2 网络结构:
当执行以上代码时,TensorFlow 会自动从网络上下载 MobileNetV2 网络结构,因此在第一次执行代码时需要具备网络连接。
可以使用 MobileNetV2 网络对相关数据集进行训练看看效果
model = tf.keras.applications.MobileNetV2(weights=None, classes=5)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。