
Solving a Multiclass Problem with XGBoost

The following walks through a worked example of using XGBoost to solve a multiclass classification problem.

1. Download the dataset. We use the UCI wheat seeds dataset, which has 3 classes. Each seed is described by 7 features: area, perimeter, compactness, kernel length, kernel width, asymmetry coefficient, and kernel groove length. The class labels are 1, 2, and 3.

Download the dataset on Linux:

wget https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt

Download on Windows:
Paste the URL into a browser to download the file:

https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt

The file is named seeds_dataset.txt and its contents look like this:

15.26	14.84	0.871	5.763	3.312	2.221	5.22	1
14.88	14.57	0.8811	5.554	3.333	1.018	4.956	1
14.29	14.09	0.905	5.291	3.337	2.699	4.825	1
13.84	13.94	0.8955	5.324	3.379	2.259	4.805	1
16.14	14.99	0.9034	5.658	3.562	1.355	5.175	1
14.38	14.21	0.8951	5.386	3.312	2.462	4.956	1
14.69	14.49	0.8799	5.563	3.259	3.586	5.219	1
14.11	14.1	0.8911	5.42	3.302	2.7	5	1
16.63	15.46	0.8747	6.053	3.465	2.04	5.877	1
16.44	15.25	0.888	5.884	3.505	1.969	5.533	1
15.26	14.85	0.8696	5.714	3.242	4.543	5.314	1
14.03	14.16	0.8796	5.438	3.201	1.717	5.001	1
13.89	14.02	0.888	5.439	3.199	3.986	4.738	1
.....
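Note that the class column in the raw file runs 1..3, while XGBoost expects class labels numbered 0..num_class-1 — this is why the loading code below shifts every label down by one with a `converters` entry. A minimal sketch of that trick, with two sample rows inlined (column positions as in the file):

```python
import io
import pandas as pd

# Two sample rows from seeds_dataset.txt (class column last)
raw = "15.26\t14.84\t0.871\t5.763\t3.312\t2.221\t5.22\t1\n" \
      "12.30\t13.34\t0.8684\t5.243\t2.974\t5.637\t5.063\t3\n"

# Shift labels from 1..3 down to 0..2 while reading, since XGBoost
# requires classes numbered from 0.
df = pd.read_csv(io.StringIO(raw), header=None, sep=r"\s+",
                 converters={7: lambda x: int(x) - 1})
print(df[7].tolist())  # [0, 2]
```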
# -*- coding: utf-8 -*-

import pandas as pd
import xgboost as xgb
import numpy as np
import warnings
warnings.filterwarnings('ignore')
from sklearn.model_selection import train_test_split

data_path='./datasets/seeds_dataset.txt'
data=pd.read_csv(data_path,header=None,sep=r'\s+',converters={7:lambda x:int(x)-1})
data.rename(columns={7:'label'},inplace=True)
print(data)

# # Alternative split: draw a uniform random number per row and keep
# # rows below 0.8 for training, the rest for testing
# mask=np.random.rand(len(data))<0.8
# train=data[mask]
# test=data[~mask]
#
# # Build DMatrix objects (all 7 feature columns)
# xgb_train=xgb.DMatrix(train.iloc[:,:7],label=train.label)
# xgb_test=xgb.DMatrix(test.iloc[:,:7],label=test.label)



# All 7 feature columns; the last column is the class label
X=data.iloc[:,:7]
Y=data.iloc[:,7]
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.25, random_state=100)

xgb_train=xgb.DMatrix(X_train,label=y_train)
xgb_test=xgb.DMatrix(X_test,label=y_test)



# Set the model parameters

params={
    'objective':'multi:softmax',
    'eta':0.1,
    'max_depth':5,
    'num_class':3
}

watchlist=[(xgb_train,'train'),(xgb_test,'test')]
# Set the number of boosting rounds, 60 here
num_round=60
bst=xgb.train(params,xgb_train,num_round,watchlist)

# Predict with the trained model

pred=bst.predict(xgb_test)
print(pred)

# Evaluate the model

# error_rate=np.sum(pred!=test.label)/test.label.shape[0]
error_rate=np.sum(pred!=y_test)/y_test.shape[0]

print('Test set error rate (softmax): {}'.format(error_rate))

accuracy=1-error_rate
print('Test set accuracy: %.4f' % accuracy)


# Save the model
bst.save_model("./datasets/002.model")


# Load the model back and verify it predicts the same labels
bst=xgb.Booster()
bst.load_model("./datasets/002.model")
pred=bst.predict(xgb_test)
print(pred)
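Because the script uses the 'multi:softmax' objective, predict() returns the hard class index per sample directly. If you instead set 'multi:softprob', predict() yields per-class probabilities (one row of num_class values per sample; older XGBoost versions may return this flat, to be reshaped to (n_samples, num_class)), and you recover hard labels with an argmax. A minimal sketch of that post-processing on a hypothetical probability matrix:

```python
import numpy as np

# Hypothetical output of bst.predict() under 'multi:softprob':
# one row of class probabilities per sample (3 classes here).
prob = np.array([[0.80, 0.15, 0.05],
                 [0.10, 0.20, 0.70],
                 [0.05, 0.90, 0.05]])

# Recover hard class labels by taking the most probable class per row
pred = np.argmax(prob, axis=1)
print(pred)  # [0 2 1]
```

The softprob variant is useful when you need confidence scores or want to apply a custom decision threshold rather than always picking the top class.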

Output:

       0      1       2      3      4      5      6  label
0    15.26  14.84  0.8710  5.763  3.312  2.221  5.220      0
1    14.88  14.57  0.8811  5.554  3.333  1.018  4.956      0
2    14.29  14.09  0.9050  5.291  3.337  2.699  4.825      0
3    13.84  13.94  0.8955  5.324  3.379  2.259  4.805      0
4    16.14  14.99  0.9034  5.658  3.562  1.355  5.175      0
..     ...    ...     ...    ...    ...    ...    ...    ...
205  12.19  13.20  0.8783  5.137  2.981  3.631  4.870      2
206  11.23  12.88  0.8511  5.140  2.795  4.325  5.003      2
207  13.20  13.66  0.8883  5.236  3.232  8.315  5.056      2
208  11.84  13.21  0.8521  5.175  2.836  3.598  5.044      2
209  12.30  13.34  0.8684  5.243  2.974  5.637  5.063      2

[210 rows x 8 columns]
[0]	train-merror:0.012739	test-merror:0.075472
[1]	train-merror:0.012739	test-merror:0.056604
[2]	train-merror:0.006369	test-merror:0.075472
[3]	train-merror:0.012739	test-merror:0.075472
[4]	train-merror:0.006369	test-merror:0.075472
[5]	train-merror:0	test-merror:0.075472
...
[59]	train-merror:0	test-merror:0.075472
[0. 2. 2. 2. 1. 0. 2. 1. 2. 2. 1. 1. 1. 0. 1. 0. 2. 1. 0. 2. 1. 0. 1. 0.
 2. 2. 1. 2. 0. 0. 2. 0. 2. 2. 1. 2. 2. 2. 2. 1. 1. 0. 1. 1. 0. 2. 0. 2.
 2. 1. 0. 2. 2.]
Test set error rate (softmax): 0.07547169811320754
Test set accuracy: 0.9245
[0. 2. 2. 2. 1. 0. 2. 1. 2. 2. 1. 1. 1. 0. 1. 0. 2. 1. 0. 2. 1. 0. 1. 0.
 2. 2. 1. 2. 0. 0. 2. 0. 2. 2. 1. 2. 2. 2. 2. 1. 1. 0. 1. 1. 0. 2. 0. 2.
 2. 1. 0. 2. 2.]

Process finished with exit code 0
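A single error rate hides which of the three seed varieties get confused with each other. A per-class confusion matrix makes that visible; here is a minimal numpy sketch on hypothetical labels (the `confusion_matrix` helper and the sample labels are illustrative, not taken from the run above):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_class):
    """Count predictions per (true, predicted) class pair.
    Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_class, num_class), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[int(t), int(p)] += 1
    return cm

# Hypothetical true labels and predictions for the 3 seed classes
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 2]
print(confusion_matrix(y_true, y_pred, 3))
# [[1 1 0]
#  [0 2 0]
#  [0 0 2]]
```

The diagonal holds correct predictions; off-diagonal cells show which class pairs are being confused, which is more actionable than the overall accuracy alone.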
