This post first runs the multiclass classification demo that ships with xgboost, dumps the generated tree structure, and then uses that output to understand how xgboost implements multiclass classification. I find this order easier to follow.
The demo lives in the xgboost source tree at /demo/multiclass_classification/train.py. The dataset it uses (dermatology.data) can be downloaded from https://archive.ics.uci.edu/ml/machine-learning-databases/dermatology/dermatology.data. The downloaded file has a .data extension; rename it to .csv or .txt and it can be used directly. I renamed my copy to 'data.txt'.
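If you prefer to fetch the file from a script, something like the following minimal sketch should work (the only assumption is the output filename data.txt, which is simply the name used throughout this post):

import urllib.request

# download the UCI dermatology dataset and save it locally as data.txt
url = ('https://archive.ics.uci.edu/ml/machine-learning-databases/'
       'dermatology/dermatology.data')
urllib.request.urlretrieve(url, 'data.txt')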
Now let's look at the code in train.py. It is listed below: the labels in this dataset fall into 6 classes, and I set the number of boosting rounds to 2.
import numpy as np
import xgboost as xgb

# labels need to be 0 to num_class - 1
# the UCI file is comma-separated, so ',' is used as the delimiter (adjust if
# your copy uses tabs); column 33 ('?' marks a missing age) and column 34
# (the 1-based class label) need converters
data = np.loadtxt('data.txt', delimiter=',',
                  converters={33: lambda x: int(x == '?'), 34: lambda x: int(x) - 1})
sz = data.shape

# 70% / 30% train / test split
train = data[:int(sz[0] * 0.7), :]
test = data[int(sz[0] * 0.7):, :]

train_X = train[:, :33]
train_Y = train[:, 34]

test_X = test[:, :33]
test_Y = test[:, 34]

xg_train = xgb.DMatrix(train_X, label=train_Y)
xg_test = xgb.DMatrix(test_X, label=test_Y)

# setup parameters for xgboost
param = {}
# use softmax multi-class classification
param['objective'] = 'multi:softmax'
# learning rate (shrinkage)
param['eta'] = 0.1
param['max_depth'] = 6
param['silent'] = 1
param['nthread'] = 4
param['num_class'] = 6

watchlist = [(xg_train, 'train'), (xg_test, 'test')]
num_round = 2  # number of boosting rounds set to 2
bst = xgb.train(param, xg_train, num_round, watchlist)

# get prediction
pred = bst.predict(xg_test)
error_rate = np.sum(pred != test_Y) / test_Y.shape[0]
print('Test error using softmax = {}'.format(error_rate))
After training, the key step is to dump the trained trees so they can be inspected. The line below saves the tree structure in text form, which I find much easier to read than the graphical plot.
bst.dump_model('multiclass_model')
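Two optional variations, sketched here rather than part of the demo: dump_model takes a with_stats flag that adds per-split gain and cover, and recent versions of the Python package can, as far as I know, export the same trees as a pandas DataFrame via trees_to_dataframe.

# text dump with per-split gain and cover statistics
bst.dump_model('multiclass_model_with_stats', with_stats=True)

# the same trees as a pandas DataFrame; the 'Tree' column matches the
# booster index in the text dump
df = bst.trees_to_dataframe()
print(df.head())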
Now open that file. Each booster in it is one tree; this model has 12 trees in total, booster[0] through booster[11].
booster[0]:
0:[f19<0.5] yes=1,no=2,missing=1
    1:[f21<0.5] yes=3,no=4,missing=3
        3:leaf=-0.0587906
        4:leaf=0.0906977
    2:[f6<0.5] yes=5,no=6,missing=5
        5:leaf=0.285523
        6:leaf=0.0906977
booster[1]:
0:[f27<1.5] yes=1,no=2,missing=1
    1:[f12<0.5] yes=3,no=4,missing=3
        3:[f31<0.5] yes=7,no=8,missing=7
            7:leaf=-1.67638e-09
            8:leaf=-0.056044
        4:[f4<0.5] yes=9,no=10,missing=9
            9:leaf=0.132558
            10:leaf=-0.0315789
    2:[f4<0.5] yes=5,no=6,missing=5
        5:[f11<0.5] yes=11,no=12,missing=11
            11:[f10<0.5] yes=15,no=16,missing=15
                15:leaf=0.264427
                16:leaf=0.0631579
            12:leaf=-0.0428571
        6:[f15<1.5] yes=13,no=14,missing=13
            13:leaf=-0.00566038
            14:leaf=-0.0539326
booster[2]:
0:[f32<1.5] yes=1,no=2,missing=1
    1:leaf=-0.0589339
    2:[f9<0.5] yes=3,no=4,missing=3
        3:leaf=0.280919
        4:leaf=0.0631579
booster[3]:
0:[f4<0.5] yes=1,no=2,missing=1
    1:[f0<1.5] yes=3,no=4,missing=3
        3:[f3<0.5] yes=7,no=8,missing=7
            7:[f27<0.5] yes=13,no=14,missing=13
                13:leaf=-0.0375
                14:leaf=0.0631579
            8:leaf=-0.0515625
        4:leaf=-0.058371
    2:[f2<1.5] yes=5,no=6,missing=5
        5:[f32<0.5] yes=9,no=10,missing=9
            9:[f15<0.5] yes=15,no=16,missing=15
                15:leaf=-0.0348837
                16:leaf=0.230097
            10:leaf=-0.0428571
        6:[f3<0.5] yes=11,no=12,missing=11
            11:leaf=0.0622641
            12:[f16<1.5] yes=17,no=18,missing=17
                17:leaf=-1.67638e-09
                18:[f3<1.5] yes=19,no=20,missing=19
                    19:leaf=-0.00566038
                    20:leaf=-0.0554622
booster[4]:
0:[f14<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0590296
    2:leaf=0.255665
booster[5]:
0:[f30<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0591241
    2:leaf=0.213253
booster[6]:
0:[f19<0.5] yes=1,no=2,missing=1
    1:[f21<0.5] yes=3,no=4,missing=3
        3:leaf=-0.0580493
        4:leaf=0.0831786
    2:leaf=0.214441
booster[7]:
0:[f27<1.5] yes=1,no=2,missing=1
    1:[f12<0.5] yes=3,no=4,missing=3
        3:[f31<0.5] yes=7,no=8,missing=7
            7:leaf=0.000227226
            8:leaf=-0.0551713
        4:[f15<1.5] yes=9,no=10,missing=9
            9:leaf=-0.0314418
            10:leaf=0.121289
    2:[f4<0.5] yes=5,no=6,missing=5
        5:[f11<0.5] yes=11,no=12,missing=11
            11:[f10<0.5] yes=15,no=16,missing=15
                15:leaf=0.206326
                16:leaf=0.0587528
            12:leaf=-0.0420568
        6:[f15<1.5] yes=13,no=14,missing=13
            13:leaf=-0.00512865
            14:leaf=-0.0531389
booster[8]:
0:[f32<1.5] yes=1,no=2,missing=1
    1:leaf=-0.0581933
    2:[f11<0.5] yes=3,no=4,missing=3
        3:leaf=0.0549185
        4:leaf=0.218241
booster[9]:
0:[f4<0.5] yes=1,no=2,missing=1
    1:[f0<1.5] yes=3,no=4,missing=3
        3:[f3<0.5] yes=7,no=8,missing=7
            7:[f27<0.5] yes=13,no=14,missing=13
                13:leaf=-0.0367718
                14:leaf=0.0600201
            8:leaf=-0.0506891
        4:leaf=-0.0576147
    2:[f27<0.5] yes=5,no=6,missing=5
        5:[f3<0.5] yes=9,no=10,missing=9
            9:leaf=0.0238016
            10:leaf=-0.054874
        6:[f5<1] yes=11,no=12,missing=11
            11:leaf=0.200442
            12:leaf=-0.0508502
booster[10]:
0:[f14<0.5] yes=1,no=2,missing=1
    1:leaf=-0.058279
    2:leaf=0.201977
booster[11]:
0:[f30<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0583675
    2:leaf=0.178016
Don't forget that this is a 6-class problem and that num_round was set to 2. Of the 12 trees, the first round produced 6 (booster[0] to booster[5], one per class) and the second round produced another 6 (booster[6] to booster[11]). The class-n tree of the second round continues learning on top of the class-n tree of the first round; for a given sample, the leaf values of the two class-n trees are added to form the score for class n, and passing the 6 class scores through the softmax function yields the predicted probability of each class. So this is how xgboost trains a multiclass model: in every boosting round it fits one tree per class, 6 trees here. You can check this numerically with the sketch right after this paragraph. If softmax is unfamiliar, see my post on the softmax function and loss derivation for multiclass problems.
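To verify the relationship, you can ask the trained booster for its raw per-class scores (the summed leaf values, before softmax) and apply softmax by hand; the probabilities should match a model trained with the multi:softprob objective, and their argmax should match the multi:softmax predictions. The sketch below reuses bst, param, xg_train, xg_test, watchlist, num_round and pred from the demo above; names like margin and prob are mine, introduced only for illustration.

# raw per-class scores: for each sample, column n is the sum of the leaf
# values of the two class-n trees (booster[n] and booster[n + 6])
margin = bst.predict(xg_test, output_margin=True)    # shape (n_samples, 6)

# softmax over the 6 class scores gives the class probabilities
def softmax(m):
    e = np.exp(m - m.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

prob = softmax(margin)

# the same probabilities come out of a model trained with multi:softprob;
# multi:softmax simply returns the argmax of these probabilities
param_prob = dict(param, objective='multi:softprob')
bst_prob = xgb.train(param_prob, xg_train, num_round, watchlist)
prob_direct = bst_prob.predict(xg_test)               # shape (n_samples, 6)

print(np.abs(prob - prob_direct).max())      # should be close to 0
print((prob.argmax(axis=1) == pred).all())   # matches the multi:softmax labels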
The formula derivation behind xgboost's multiclass training is not covered in this post; I plan to write a separate post devoted to it.
That's all for now. If anything here is wrong, please leave a comment.