当前位置:   article > 正文

xgboost 实现多分类问题demo以及原理_xgboost demo

xgboost demo

本文先把xgboost支持的多分类问题的demo写起来,打印出生成的树结构,然后理解xgboost实现多分类问题的原理。这个顺序比较好理解一些。

xgboost 多分类问题 demo

这个demo从xgboost的源代码中就可以看到。在这个位置:/demo/multiclass_classification/train.py。train.py文件里的数据(dermatology.data)可以在  https://archive.ics.uci.edu/ml/machine-learning-databases/dermatology/dermatology.data  这个网址下载。下载下来的文件的后缀是.data,改成  .csv  或者  .txt  就可以直接用了。我把数据改成了  'data.txt'  。

现在来看看train.py 里的代码吧~

我把代码直接写在下边:这份数据的标签有6类,下边的代码我设置迭代了2轮。

  1. import numpy as np
  2. import xgboost as xgb
  3. # label need to be 0 to num_class -1
  4. data = np.loadtxt('data.txt', delimiter='\t',
  5. converters={33: lambda x:int(x == '?'), 34: lambda x:int(x) - 1})
  6. sz = data.shape
  7. train = data[:int(sz[0] * 0.7), :]
  8. test = data[int(sz[0] * 0.7):, :]
  9. train_X = train[:, :33]
  10. train_Y = train[:, 34]
  11. test_X = test[:, :33]
  12. test_Y = test[:, 34]
  13. xg_train = xgb.DMatrix(train_X, label=train_Y)
  14. xg_test = xgb.DMatrix(test_X, label=test_Y)
  15. # setup parameters for xgboost
  16. param = {}
  17. # use softmax multi-class classification
  18. param['objective'] = 'multi:softmax'
  19. # scale weight of positive examples
  20. param['eta'] = 0.1
  21. param['max_depth'] = 6
  22. param['silent'] = 1
  23. param['nthread'] = 4
  24. param['num_class'] = 6
  25. watchlist = [(xg_train, 'train'), (xg_test, 'test')]
  26. num_round = 2 # 轮数设置成2轮
  27. bst = xgb.train(param, xg_train, num_round, watchlist)
  28. # get prediction
  29. pred = bst.predict(xg_test)
  30. error_rate = np.sum(pred != test_Y) / test_Y.shape[0]
  31. print('Test error using softmax = {}'.format(error_rate))

xgboost 多分类问题实现原理

训练完之后,关键的一步是把训练好的树打印出来便于查看,下面的代码可以把树结构存为文本形式。我觉得文本形式比图形式好看很多。

bst.dump_model('multiclass_model')

然后我们打开这个文件。其中的每一个booster代表一棵树,这个模型一共有12棵树,booster从0到11。

  1. booster[0]:
  2. 0:[f19<0.5] yes=1,no=2,missing=1
  3. 1:[f21<0.5] yes=3,no=4,missing=3
  4. 3:leaf=-0.0587906
  5. 4:leaf=0.0906977
  6. 2:[f6<0.5] yes=5,no=6,missing=5
  7. 5:leaf=0.285523
  8. 6:leaf=0.0906977
  9. booster[1]:
  10. 0:[f27<1.5] yes=1,no=2,missing=1
  11. 1:[f12<0.5] yes=3,no=4,missing=3
  12. 3:[f31<0.5] yes=7,no=8,missing=7
  13. 7:leaf=-1.67638e-09
  14. 8:leaf=-0.056044
  15. 4:[f4<0.5] yes=9,no=10,missing=9
  16. 9:leaf=0.132558
  17. 10:leaf=-0.0315789
  18. 2:[f4<0.5] yes=5,no=6,missing=5
  19. 5:[f11<0.5] yes=11,no=12,missing=11
  20. 11:[f10<0.5] yes=15,no=16,missing=15
  21. 15:leaf=0.264427
  22. 16:leaf=0.0631579
  23. 12:leaf=-0.0428571
  24. 6:[f15<1.5] yes=13,no=14,missing=13
  25. 13:leaf=-0.00566038
  26. 14:leaf=-0.0539326
  27. booster[2]:
  28. 0:[f32<1.5] yes=1,no=2,missing=1
  29. 1:leaf=-0.0589339
  30. 2:[f9<0.5] yes=3,no=4,missing=3
  31. 3:leaf=0.280919
  32. 4:leaf=0.0631579
  33. booster[3]:
  34. 0:[f4<0.5] yes=1,no=2,missing=1
  35. 1:[f0<1.5] yes=3,no=4,missing=3
  36. 3:[f3<0.5] yes=7,no=8,missing=7
  37. 7:[f27<0.5] yes=13,no=14,missing=13
  38. 13:leaf=-0.0375
  39. 14:leaf=0.0631579
  40. 8:leaf=-0.0515625
  41. 4:leaf=-0.058371
  42. 2:[f2<1.5] yes=5,no=6,missing=5
  43. 5:[f32<0.5] yes=9,no=10,missing=9
  44. 9:[f15<0.5] yes=15,no=16,missing=15
  45. 15:leaf=-0.0348837
  46. 16:leaf=0.230097
  47. 10:leaf=-0.0428571
  48. 6:[f3<0.5] yes=11,no=12,missing=11
  49. 11:leaf=0.0622641
  50. 12:[f16<1.5] yes=17,no=18,missing=17
  51. 17:leaf=-1.67638e-09
  52. 18:[f3<1.5] yes=19,no=20,missing=19
  53. 19:leaf=-0.00566038
  54. 20:leaf=-0.0554622
  55. booster[4]:
  56. 0:[f14<0.5] yes=1,no=2,missing=1
  57. 1:leaf=-0.0590296
  58. 2:leaf=0.255665
  59. booster[5]:
  60. 0:[f30<0.5] yes=1,no=2,missing=1
  61. 1:leaf=-0.0591241
  62. 2:leaf=0.213253
  63. booster[6]:
  64. 0:[f19<0.5] yes=1,no=2,missing=1
  65. 1:[f21<0.5] yes=3,no=4,missing=3
  66. 3:leaf=-0.0580493
  67. 4:leaf=0.0831786
  68. 2:leaf=0.214441
  69. booster[7]:
  70. 0:[f27<1.5] yes=1,no=2,missing=1
  71. 1:[f12<0.5] yes=3,no=4,missing=3
  72. 3:[f31<0.5] yes=7,no=8,missing=7
  73. 7:leaf=0.000227226
  74. 8:leaf=-0.0551713
  75. 4:[f15<1.5] yes=9,no=10,missing=9
  76. 9:leaf=-0.0314418
  77. 10:leaf=0.121289
  78. 2:[f4<0.5] yes=5,no=6,missing=5
  79. 5:[f11<0.5] yes=11,no=12,missing=11
  80. 11:[f10<0.5] yes=15,no=16,missing=15
  81. 15:leaf=0.206326
  82. 16:leaf=0.0587528
  83. 12:leaf=-0.0420568
  84. 6:[f15<1.5] yes=13,no=14,missing=13
  85. 13:leaf=-0.00512865
  86. 14:leaf=-0.0531389
  87. booster[8]:
  88. 0:[f32<1.5] yes=1,no=2,missing=1
  89. 1:leaf=-0.0581933
  90. 2:[f11<0.5] yes=3,no=4,missing=3
  91. 3:leaf=0.0549185
  92. 4:leaf=0.218241
  93. booster[9]:
  94. 0:[f4<0.5] yes=1,no=2,missing=1
  95. 1:[f0<1.5] yes=3,no=4,missing=3
  96. 3:[f3<0.5] yes=7,no=8,missing=7
  97. 7:[f27<0.5] yes=13,no=14,missing=13
  98. 13:leaf=-0.0367718
  99. 14:leaf=0.0600201
  100. 8:leaf=-0.0506891
  101. 4:leaf=-0.0576147
  102. 2:[f27<0.5] yes=5,no=6,missing=5
  103. 5:[f3<0.5] yes=9,no=10,missing=9
  104. 9:leaf=0.0238016
  105. 10:leaf=-0.054874
  106. 6:[f5<1] yes=11,no=12,missing=11
  107. 11:leaf=0.200442
  108. 12:leaf=-0.0508502
  109. booster[10]:
  110. 0:[f14<0.5] yes=1,no=2,missing=1
  111. 1:leaf=-0.058279
  112. 2:leaf=0.201977
  113. booster[11]:
  114. 0:[f30<0.5] yes=1,no=2,missing=1
  115. 1:leaf=-0.0583675
  116. 2:leaf=0.178016

不要忘记我们是6分类问题,训练轮数 num_round 设置成了2。在这12棵树中,第一轮有6棵树,对应  booster0-booster5,第二轮有6棵树,对应  booster6-booster11。第二轮的第 n 棵树 在 第一轮的第 n 棵树基础上再学习,然后两棵树的结果加在一起,再经softmax函数,就得到了预测为第 n 类的概率。那么到这里应该知道xgboost是怎么训练多分类问题的了吧,其实在每一轮,都要训练6棵树。有不明白softmax的可以参考 多分类问题的softmax函数以及损失函数推导

至于xgboost多分类问题的公式推导没有在这篇博客里写,打算专写一篇来讲公式推导~

到此就结束啦,有不对的地方欢迎各位大佬留言~

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/97393
推荐阅读
相关标签
  

闽ICP备14008679号