赞
踩
在机器学习模型中,需要人工选择的参数称为超参数。比如随机森林中决策树的个数,人工神经网络模型中隐藏层层数和每层的节点个数,正则项中常数大小等等,他们都需要事先指定。超参数选择不恰当,就会出现欠拟合或者过拟合的问题。而在选择超参数的时候,有两个途径,一个是凭经验微调,另一个就是选择不同大小的参数,带入模型中,挑选表现最好的参数。
微调的一种方法是手工调制超参数,直到找到一个好的超参数组合,这么做的话会非常沉长,你也可能没有时间探索多种组合,所以可以使用Scikit-Learn的GridSearchCV来做这项搜索工作
GridSearchCV的名字其实可以拆分为两部分,GridSearch和CV,即网格搜索和交叉验证。这两个名字都非常好理解。网格搜索,搜索的是参数,即在指定的参数范围内,按步长依次调整参数,利用调整的参数训练学习器,从所有的参数中找到在验证集上精度最高的参数,这其实是一个训练和比较的过程。
GridSearchCV可以保证在指定的参数范围内找到精度最高的参数,但是这也是网格搜索的缺陷所在,他要求遍历所有可能参数的组合,在面对大数据集和多参数的情况下,非常耗时
什么是GridSearch
Grid Search:一种调参手段;穷举搜索:在所有候选的参数选择中,通过循环遍历,尝试每一种可能性,表现最好的参数就是最终的结果。其原理就像是在数组里找到最大值。这
种方法的主要缺点是比较耗时!
所以网格搜索适用于三四个(或者更少)的超参数(当超参数的数量增长时,网格搜索的计算复杂度会呈现指数增长,这时候则使用随机搜索),用户列出一个较小的超参数值域,这些超参数至于的笛卡尔积(排列组合)为一组组超参数。网格搜索算法使用每组超参数训练模型并挑选验证集误差最小的的超参数组合
以随机森林为例说明GridSearch网格搜索
- from sklearn.model_selection import GridSearchCV
- from sklearn import RandomForestRegressor
- param_grid = [
- {'n_estimators': [3,10, 30],
- 'max_features':[2,3,4]},
- {'bootstrap': [False],
- 'n_estimators': [3, 10],
- 'max_features': [2,4,6,8]},
- ]
-
- forest_reg = RandomForestRegressor()
- grid_search=GridSearchCV(forest_reg, param_grid, cv=5, scoring='neg_mean_squared_error')
- grid_search.fit(housing_prepared, housing_labels)
-
- grid_search.best_params_
-
- # 输出结果如下:
- # {'max_features': 8, 'n_estimators': 30}
-
- # 得到最好的模型
- # grid_search.best_estimator_
- # 随机森林
- RandomForestRegressor(bootstrap=True,
- criterion='mse',
- max_depth=None,
- max_features=8,
- max_leaf_nodes=None,
- min_impurity_decrease=0.0,
- min_impurity_split=None,
- min_samples_leaf=1,
- min_samples_split=2,
- min_weight_fraction_leaf=0.0,
- n_estimators=30,
- n_jobs=1,
- oob_score=False,
- random_state=None,
- verbose=0,
- warm_start=False)
我们在搜索超参数的时候,如果超参数个数较少(三四个或者更少),那么我们可以采用网格搜索,一种穷尽式的搜索方法。但是当超参数个数比较多的时候,我们仍然采用网格搜索,那么搜索所需时间将会指数级上升
RandomizedSearchCV的使用方法其实是和GridSearchCV一致的,但它以随机在参数空间中采样的方式代替了GridSearchCV对于于参数的网格搜索,在对于有连续变量的参数时,RandomizedSearchCV会将其当做一个分布进行采样进行这是网格搜索做不到的,它的搜索能力取决于设定的n_iter参数
- sklearn.model_selection.GridSearchCV(estimator,
- param_grid,
- scoring=None,
- fit_params=None,
- n_jobs=None,
- iid='warn',
- refit=True,
- CV='warn',
- verbose=0,
- pre_dispatch=2*n_jobs,
- error_score='raise-deprecating',
- return_train_score='warn')
进行预测常用方法和属性
GridSearchCV属性说明
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。