当前位置:   article > 正文

GridSearchCV(网格搜索)

gridsearchcv

了解网格搜索,可以阅读:Python机器学习笔记 Grid SearchCV(网格搜索)

  1. GridSearchCV其实可以拆分为GridSearch和CV,即网格搜索和交叉验证。网格搜索,搜索的是参数,即在指定的参数范围内,按步长依次调整参数,利用调整的参数训练学习器,从所有的参数中找到在验证集上精度最高的参数
  2. 以随机森林为例说明GridSearch网格搜索中运行print(gsearch1.grid_scores_)出现的问题
# -*- coding: utf-8 -*-
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
#自适应数据加载函数
def loadDataSet(fileName):
    '''
    这个函数用来加载训练数据集
    输入:存储数据的文件名
    输出:数据集列表 以及 类别标签列表
    '''
    numFeat = len(open(fileName).readline().split('\t')) #get number of fields
    dataMat = []; labelMat = []
    fr = open(fileName)
    for line in fr.readlines():
        lineArr =[]
        curLine = line.strip().split('\t')
        for i in range(numFeat-1):
            lineArr.append(float(curLine[i]))
        dataMat.append(lineArr)
        labelMat.append(float(curLine[-1]))
    return dataMat,labelMat
# 读取数据
x_train, y_train = loadDataSet('horseColicTraining2.txt')
#随机森林
rf0 = RandomForestClassifier(random_state=10)
rf0.fit(x_train,y_train)

# 网格搜索
param_test1 = {
   'n_estimators':range(10,100,10)}
gsearch1 = GridSearchCV(estimator = RandomForestClassifier(min_samples_split=100,
                                  min_samples_leaf=20,max_depth=8,max_features='sqrt' ,random_state=10), 
                       param_grid = param_test1, scoring='roc_auc'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/463725
推荐阅读
相关标签
  

闽ICP备14008679号