当前位置:   article > 正文

数据挖掘(5)模型调优(sklearn中的网格搜索)_grid.predict

grid.predict

使用网格搜索法对 5 个模型(逻辑回归、支持向量机、决策树、随机森林、XGBoost)进行调优(调参时采用五折交叉验证),并对调优后的模型进行评估。

可以使用sklearn中的网格搜索

  from sklearn.linear_model import LogisticRegression
  from sklearn.model_selection import GridSearchCV

  # Minimal grid-search example: tune C and max_iter of a logistic regression
  # with 5-fold cross-validation, then predict with the best model found.
  # The estimator is passed unfitted: GridSearchCV clones it for every
  # candidate, so fitting it beforehand would only waste a training pass.
  clf = LogisticRegression(C=1.0, max_iter=1000)
  parameters = {'C': [1.0, 2.0, 3.0, 4.0, 5.0], 'max_iter': [100, 200, 500, 1000, 1500, 2000]}
  grid = GridSearchCV(clf, parameters, cv=5)
  # fit() returns the search object itself; with the default refit=True the
  # best parameter combination is refitted on the whole training set.
  grid = grid.fit(train_data, train_label)
  grid_test = grid.predict(test_data)
  grid_train = grid.predict(train_data)
  1. from sklearn import metrics
  2. import pandas as pd
  3. import numpy as np
  4. from sklearn.model_selection import StratifiedKFold
  5. from sklearn.linear_model import LogisticRegression
  6. from sklearn.ensemble import RandomForestClassifier
  7. from sklearn.tree import DecisionTreeClassifier
  8. from sklearn import svm
  9. from xgboost.sklearn import XGBClassifier
  10. from sklearn.feature_selection import RFE
  11. from sklearn.metrics import roc_curve, auc
  12. import matplotlib.pyplot as plt
  13. from sklearn.model_selection import GridSearchCV
  def LR_classifier(train_data, train_label, test_data, test_label):
      """Tune a logistic regression with a 5-fold grid search and predict.

      The search covers ``C`` and ``max_iter``. GridSearchCV (default
      ``refit=True``) refits the best combination on the full training set,
      and that refitted model produces the returned predictions.

      Args:
          train_data: Training features, passed to ``fit`` and ``predict``.
          train_label: Training labels, passed to ``fit``.
          test_data: Test features, passed to ``predict``.
          test_label: Unused by this function; kept for interface
              compatibility with existing callers.

      Returns:
          Tuple ``(grid_test, grid_train)``: predictions for ``test_data``
          and for ``train_data``, in that order.
      """
      # Pass an unfitted estimator: GridSearchCV clones it for every
      # candidate, so fitting it beforehand only wastes a training pass.
      clf = LogisticRegression(C=1.0, max_iter=1000)
      parameters = {'C': [1.0, 2.0, 3.0, 4.0, 5.0], 'max_iter': [100, 200, 500, 1000, 1500, 2000]}
      grid = GridSearchCV(clf, parameters, cv=5)
      grid.fit(train_data, train_label)
      grid_test = grid.predict(test_data)
      grid_train = grid.predict(train_data)
      return grid_test, grid_train
  def svm_classifier(train_data, train_label, test_data, test_label):
      """Tune a linear-kernel SVC with a 5-fold grid search and predict.

      The search covers ``C`` and ``gamma``. GridSearchCV (default
      ``refit=True``) refits the best combination on the full training set,
      and that refitted model produces the returned predictions.

      Args:
          train_data: Training features, passed to ``fit`` and ``predict``.
          train_label: Training labels, passed to ``fit``.
          test_data: Test features, passed to ``predict``.
          test_label: Unused by this function; kept for interface
              compatibility with existing callers.

      Returns:
          Tuple ``(grid_test, grid_train)``: predictions for ``test_data``
          and for ``train_data``, in that order.
      """
      # Pass an unfitted estimator: GridSearchCV clones it for every
      # candidate, so fitting it beforehand only wastes a training pass.
      clf = svm.SVC(C=1.0, kernel='linear', gamma=20)
      # NOTE(review): per the SVC docs, gamma is the coefficient for the
      # 'rbf', 'poly' and 'sigmoid' kernels; with kernel='linear' the gamma
      # candidates should all score the same. Confirm whether a non-linear
      # kernel was intended before changing the grid.
      parameters = {'C': [1.0, 2.0, 3.0], 'gamma': [5, 10, 15, 20, 25]}
      grid = GridSearchCV(clf, parameters, cv=5)
      grid.fit(train_data, train_label)
      grid_test = grid.predict(test_data)
      grid_train = grid.predict(train_data)
      return grid_test, grid_train
  def dt_classifier(train_data, train_label, test_data, test_label):
      """Tune a decision tree's ``max_depth`` with a 5-fold grid search and predict.

      GridSearchCV (default ``refit=True``) refits the best depth on the full
      training set, and that refitted model produces the returned predictions.

      Args:
          train_data: Training features, passed to ``fit`` and ``predict``.
          train_label: Training labels, passed to ``fit``.
          test_data: Test features, passed to ``predict``.
          test_label: Unused by this function; kept for interface
              compatibility with existing callers.

      Returns:
          Tuple ``(grid_test, grid_train)``: predictions for ``test_data``
          and for ``train_data``, in that order.
      """
      # Pass an unfitted estimator: GridSearchCV clones it for every
      # candidate, so fitting it beforehand only wastes a training pass.
      clf = DecisionTreeClassifier(max_depth=5)
      parameters = {'max_depth': [2, 5, 8, 10, 15]}
      grid = GridSearchCV(clf, parameters, cv=5)
      grid.fit(train_data, train_label)
      grid_test = grid.predict(test_data)
      grid_train = grid.predict(train_data)
      return grid_test, grid_train
  def rf_classifier(train_data, train_label, test_data, test_label):
      """Tune a random forest with a 5-fold grid search and predict.

      The search covers ``n_estimators``, ``random_state``, ``max_depth`` and
      ``min_samples_split``. GridSearchCV (default ``refit=True``) refits the
      best combination on the full training set, and that refitted model
      produces the returned predictions.

      Args:
          train_data: Training features, passed to ``fit`` and ``predict``.
          train_label: Training labels, passed to ``fit``.
          test_data: Test features, passed to ``predict``.
          test_label: Unused by this function; kept for interface
              compatibility with existing callers.

      Returns:
          Tuple ``(grid_test, grid_train)``: predictions for ``test_data``
          and for ``train_data``, in that order.
      """
      # Pass an unfitted estimator: GridSearchCV clones it for every
      # candidate, so fitting it beforehand only wastes a training pass.
      clf = RandomForestClassifier(
          n_estimators=8, random_state=5, max_depth=6, min_samples_split=2
      )
      # NOTE(review): 5 * 5 * 6 * 5 = 750 candidates, i.e. 3750 fits at cv=5.
      # random_state is a seed rather than a model hyperparameter; searching
      # over it picks the luckiest seed. Kept as-is to preserve results.
      parameters = {
          'n_estimators': [3, 5, 8, 10, 14],
          'random_state': [2, 3, 5, 7, 9],
          'max_depth': [5, 6, 8, 9, 10, 15],
          'min_samples_split': [2, 3, 4, 5, 6],
      }
      grid = GridSearchCV(clf, parameters, cv=5)
      grid.fit(train_data, train_label)
      grid_test = grid.predict(test_data)
      grid_train = grid.predict(train_data)
      return grid_test, grid_train
  def xgb_classifier(train_data, train_label, test_data, test_label):
      """Tune an XGBoost classifier with a 5-fold grid search and predict.

      The search covers ``n_estimators``, ``learning_rate``, ``max_depth``,
      ``gamma`` and ``seed``. GridSearchCV (default ``refit=True``) refits the
      best combination on the full training set, and that refitted model
      produces the returned predictions.

      Args:
          train_data: Training features, passed to ``fit`` and ``predict``.
          train_label: Training labels, passed to ``fit``.
          test_data: Test features, passed to ``predict``.
          test_label: Unused by this function; kept for interface
              compatibility with existing callers.

      Returns:
          Tuple ``(grid_test, grid_train)``: predictions for ``test_data``
          and for ``train_data``, in that order.
      """
      # Pass an unfitted estimator: GridSearchCV clones it for every
      # candidate, so fitting it beforehand only wastes a training pass.
      # NOTE(review): num_class is normally only set for multi-class
      # objectives -- confirm num_class=1 is intended here.
      clf = XGBClassifier(
          n_estimators=8, learning_rate=0.25, max_depth=20, subsample=1,
          gamma=13, seed=1000, num_class=1,
      )
      # NOTE(review): 5 * 6 * 5 * 5 * 3 = 2250 candidates, i.e. 11250 fits at
      # cv=5. seed is a random seed rather than a model hyperparameter.
      parameters = {
          'n_estimators': [3, 5, 8, 10, 14],
          'learning_rate': [0.1, 0.2, 0.25, 0.3, 0.35, 0.4],
          'max_depth': [5, 10, 15, 20, 25],
          'gamma': [6, 9, 12, 13, 15],
          'seed': [500, 1000, 1500],
      }
      grid = GridSearchCV(clf, parameters, cv=5)
      grid.fit(train_data, train_label)
      grid_test = grid.predict(test_data)
      grid_train = grid.predict(train_data)
      return grid_test, grid_train
| 模型 | Accuracy | Precision | Recall | F1_score | AUC | ROC |
| --- | --- | --- | --- | --- | --- | --- |
| Logistic Regression | train: 0.7913, test: 0.787 | train: 0.7351, test: 0.7195 | train: 0.2668, test: 0.2552 | train: 0.3915, test: 0.3759 | train: 0.6173, test: 0.6105 | |
| Support Vector Machine | train: 0.7793, test: 0.7762 | train: 0.8025, test: 0.7783 | train: 0.1632, test: 0.1554 | train: 0.2712, test: 0.2588 | train: 0.5748, test: 0.5702 | |
| Decision Tree | train: 0.78, test: 0.7732 | train: 0.8241, test: 0.7474 | train: 0.1611, test: 0.1537 | train: 0.2632, test: 0.2483 | train: 0.5746, test: 0.5676 | |
| Random Forest | train: 0.8311, test: 0.778 | train: 0.9326, test: 0.7077 | train: 0.3504, test: 0.2058 | train: 0.502, test: 0.3172 | train: 0.6716, test: 0.5881 | |
| XGBoost | train: 0.839, test: 0.7843 | train: 0.8456, test: 0.6484 | train: 0.4402, test: 0.3127 | train: 0.5786, test: 0.4215 | train: 0.7067, test: 0.6278 | |

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/412086
推荐阅读
相关标签
  

闽ICP备14008679号