当前位置:   article > 正文

机器学习-交叉验证和网格搜索(带案例)_四倍交叉验证

四倍交叉验证

八、交叉验证和网格搜索

1.什么是交叉验证?

就是将拿到的训练数据,分成训练集和验证集,比如将一份数据分成4份,其中一份作为验证集。然后经过过4次测试,每次都更换不同的验证集。即得到4次模型的结果,取平均值作为最终结果。又称4折交叉验证。

2.为什么要做交叉验证?

交叉验证的目的:为了让被评估的模型更加准确可信。

问题:这个只是让被评估的模型更加准确可信,那么怎么选择或者调优参数呢?

3.什么是网格搜索

通常情况下,**有很多参数是需要手动指定的,这种叫超参数。**但是手动过程繁杂,所以需要对模型预设几种参数组合。每组超参数都采用交叉验证来评估。最后选出最优参数组合建立模型。

4.交叉验证(模型准确可信),网格搜素(模型调优)API:
  • sklearn.model_selection.GridSearchCV(estimator,param_grid=None,cv=None)
    • 对估计器的指定参数值进行详尽搜索
    • estimator:估计器对象
    • param_grid:估计器参数(dict) {‘n_neighbors’:[1,3,5]}
    • cv:指定几折交叉验证
    • fit:输入训练数据
    • score:准确率
    • 结果分析
      • bestscore_:在交叉验证中验证的最好结果
      • bestestimator:最好的参数模型
      • cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果
5.鸢尾花案例增加K值调优
# 1.获取数据集
from sklearn.datasets import load_iris()
iris = load_iris()

# 2.数据基本处理--划分数据集
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=1)

# 3.特征工程--标准化
from sklearn.preprocessing import StandardScale
transfer = StandardScale()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

# 4.KNN估计器
from sklearn.neighbors import KNeighborsClassifier
estimator = KNeighborsClassifier()
# 交叉验证和网络搜索
from sklearn.model_selection import GridSearchCV
param_dict = {'n_neighbors':[1,3,5]}
estimator = GridSerachCV(estimator,param_grid=param_dict,cv=3)
estimator.fit(x_train,y_train)

# 5.模型评估
# 方案一,对比真实值和预测值
y_predict = estimator.predict(x_test)
print("预测值为:",y_predict)
print("真实值和预测值的对比:",y_predict==y_test)
# 方案二,直接计算准确率
score = estimator.score(x_test,y_test)
print("准确率为:",score)

# 6.查看交叉验证和网络搜索的结果
print("交叉验证中验证的最好结果",estimator.bestscore_)
print("最好的参数模型",estimator.best_estimator_)
print("每次验证后的准确率结果",estimator.cv_results_)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

九、案例3:预测facebook将要签到的位置

1.项目描述

本次比赛的目的是预测一个人将要签到的地方。 为了本次比赛,Facebook创建了一个虚拟世界,其中包括10公里*10公里共100平方公里的约10万个地方。 对于给定的坐标集,您的任务将根据用户的位置,准确性和时间戳等预测用户下一次的签到位置。 数据被制作成类似于来自移动设备的位置数据。 请注意:您只能使用提供的数据进行预测。

2.数据集介绍
文件说明 train.csv, test.csv
  row id:签入事件的id
  x y:坐标
  accuracy: 准确度,定位精度
  time: 时间戳
  place_id: 签到的位置,这也是你需要预测的内容
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
3.步骤分析
  • 对于数据做一些基本处理(这里所做的一些处理不一定达到很好的效果,我们只是简单尝试,有些特征我们可以根据一些特征选择的方式去做处理)

    • 1 缩小数据集范围 DataFrame.query()
    • 2 选取有用的时间特征
    • 3 将签到位置少于n个用户的删除
  • 分割数据集

  • 标准化处理

  • k-近邻预测

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/954069
推荐阅读
相关标签
  

闽ICP备14008679号