当前位置:   article > 正文

深度学习:Keras快速开发入门(二)Keras内置Scikit-Learn接口包装器:卷积网络识别Mnist手写数字,并用网格搜索方法自动调优一些超参数_keras.wrappers.scikit_learn

keras.wrappers.scikit_learn

这本书基本抄袭了keras中文文档,建议配合keras中文文档一起使用
github下载keras中文文档中的示例程序mnist_sklearn_wrapper.py,自己添加注释。
将源码第79行dense_size_candidates = [[32], [64], [32, 32], [64, 64]]改为dense_size_candidates = [[32], [64], [32, 32], [64, 64], [32, 64]. [64, 32]]

Keras内置Scikit-Learn接口包装器

Scikit-Learn机器学习库的简约易用与Keras框架非常相似。如果用过Scikit-Learn,并且想要在Keras中使用Scikit-Learn的一些特有功能(如交叉验证、网格搜索等),可以使用Keras内置的Scikit-Learn接口包装器。
  通过包装器将Sequential模型(贯序模型)(仅有一个输入)作为Scikit-Learn工作流的一部分,相关的包装器定义在keras.wrappers.scikit_learn.py中。
  目前,有两个包装器可用:

keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params)
  • 1

实现了sklearn的分类器接口

keras.wrappers.scikit_learn.KerasRegressor(build_fn=None, **sk_params)
  • 1

实现了sklearn的回归器接口

参数如下

build_fn:可调用的函数或类对象

build_fn应构造、编译并返回一个Keras模型。该模型将稍后用于训练/测试。build_fn的值可能为下列3中之一:
  一个函数。
  一个具有call方法的类实例
  None,代表类继承自KerasClassifier或KerasRegressor,其call方法是其父类的call方法

sk_params:模型参数和训练参数

sk_params合法的模型参数为build_fn的参数。注意,“build_fn”应提供其参数的默认值。所以我们不传递任何值给sk_params也可以创建一个分类器/回归器。sk_params还接受用于调用fit、predict、predict_proba和score方法的参数,如nb_epoch、batch_size等。这些用于训练或预测的参数按如下先后顺序选择:
  传递给fit、predict、predict_proba和score的字典参数。
  传递给sk_params的参数。
  keras.models.Sequential、fit、predict、predict_proba和score的默认值
  当使用scikit-learn的网格搜索(grid_search)接口时,合法的可调参数是可以传递给sk_params的参数,包括训练参数。即可以使用grid_search来搜索最佳的batch_size或nb_epoch以及其他模型参数。

网格搜索

自动超优化参数是Scikit-Learn接口中非常实用的功能之一。这种优化方式无需人工调参,其中最常用的两种方法是网格搜索(grid_search)和随机化搜索(randomized_search)。网格搜索会尝试给出的所有超参数组合:而随机搜索更多地使用在超参数较多的情况,会在超参数空间的一个特定分布上随机抽样一些超参数组合,最后返回抽样到的最好超参数组合。
在这里插入图片描述

代码

下面展示了如何使用Scikit-Learn的网格搜索(grid_search)自动寻找合适超参数的过程。
整个任务是建立一个简单的卷积网络识别mnist手写数字,并且用网格搜索方法自动调优一些超参数,返回最优超参数。

# 下面展示了如何使用Scikit-Learn的网格搜索(grid_search)自动寻找合适超参数的过程

# 整个任务是建立一个简单的卷积网络识别mnist手写数字,并且用网格搜索方法自动调优
# 一些超参数,返回最优超参数

# 首先,代码中我们导入一些必须的库,
# 包括前面提到的Keras包装器KerasClassifier和Scikit-Learn的GridSearchCV方法
from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.wrappers.scikit_learn import KerasClassifier
from keras import backend as K
from sklearn.model_selection import GridSearchCV


num_classes = 10

# 输入图片的尺寸
img_rows, img_cols = 28, 28

# 加载数据集,进行基本数据归一化
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

# 将类向量转换为二进制类矩阵
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)


def make_model(dense_layer_sizes, filters, kernel_size, pool_size):
    '''创建由2个卷积层和密集层组成的模型

    dense_layer_sizes: 图层大小列表,这个列表中的每一层都有一个数字
    filters: 每个卷积层中卷积滤波器的个数
    kernel_size: 卷积核大小
    pool_size: 最大池化的池化窗口大小
    '''

    model = Sequential()
    model.add(Conv2D(filters, kernel_size,
                     padding='valid',
                     input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(Conv2D(filters, kernel_size))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=pool_size))
    model.add(Dropout(0.25))

    model.add(Flatten())
    for layer_size in dense_layer_sizes:
        model.add(Dense(layer_size))
        model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adadelta',
                  metrics=['accuracy'])

    return model


# 我们构造了一个简单的有二层卷积的网络,注意我们把全连接层留空,
# 假设需要添加到全连接层的数量和输出未知。即dense_layer_sizes是我们想要调优的超参数
# 我们想要考虑这样一些情况
# (1) 再添加1层32输出到额全连接层
# (2) 再添加1层64输出到额全连接层
# (3) 再添加2层32输出到额全连接层
# (4) 再添加2层63输出到额全连接层
# (5) 先添加1层32输出的全连接层,再添加1层64输出的全连接层
# (6) 先添加1层64输出的全连接层,再添加1层32输出的全连接层
# 因此,所有情况列表为[[32], [64], [32, 32], [64, 64],[32,64],[64,32]]
#                   根据这些情况我们进行网格搜索
# 通过Keras的KerasClassifier接口获得经过包装的Scikit-Learn模型实例
# 把模型实例放在Scikit-Learn的网格搜索GridSearchCV接口中进行模型的超参数搜索和验证
dense_size_candidates = [[32], [64], [32, 32], [64, 64], [32, 64], [64, 32]]
# KerasClassifier的第一个参数为可调用的函数或类对象
# 参数1应构造、编译并返回一个Keras模型。该模型将稍后用于训练/测试
my_classifier = KerasClassifier(make_model, batch_size=32)
# 注意,虽然epochs、filters、kernel_size、pool_size等参数没有在构造模型时明确指定
# 这些参数默认在Scikit-Learn接口中是可调超参数
# 可以直接在使用GridSearchCV时指定调优的组合
validator = GridSearchCV(my_classifier,
                         param_grid={'dense_layer_sizes': dense_size_candidates,
                                     # epochs即使不是模型构建函数的参数也可用于调优
                                     'epochs': [3, 6],
                                     'filters': [8],
                                     'kernel_size': [3],
                                     'pool_size': [2]},
                         scoring='neg_log_loss',
                         n_jobs=1)
# 最后,调用.fit方法进行超参数搜索,输出结果
validator.fit(x_train, y_train)

print('The parameters of the best model are: ')
print(validator.best_params_)

# 此处validator.best_estimator_返回的是被Keras包装过的Scikit-Learn模型对象
# validator.best_estimator_.model返回的是未包装过的Keras模型对象
best_model = validator.best_estimator_.model
metric_names = best_model.metrics_names
metric_values = best_model.evaluate(x_test, y_test)
for metric, value in zip(metric_names, metric_values):
    print(metric, ': ', value)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/350186
推荐阅读
相关标签
  

闽ICP备14008679号