
XGBoost Multi-Class Classification (Six-Bracket Age Prediction)


1. Importing the required packages

# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd  # needed for read_csv and the distribution check below
import xgboost as xgb
from xgboost import plot_importance
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn import metrics
from matplotlib import pyplot as plt

2. Data preprocessing
2.1 Loading the data

# The data file must not contain a header row
data = pd.read_csv('../data/data.csv', header=None, sep=',')
data.columns = all_names  # assign the column names
X = data.loc[:, feature_names]  # stored as int in Hive, so it is read back as int; no conversion needed
Y = data.loc[:, 'monthly_income'] - 1  # multi-class labels must start from 0
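Note that all_names and feature_names are not defined in the snippet above; they are the Hive column names. A purely hypothetical sketch of their shape (replace with the real column names):

# Hypothetical column definitions -- replace with the real Hive column names.
feature_names = ['feat_1', 'feat_2', 'feat_3']   # feature columns, in file order
all_names = feature_names + ['monthly_income']   # the label column comes last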

2.2 Processing the data

# 1. Load the data (comma-separated, last column is the label)
data = np.loadtxt('/data/zz/age_predict/data.txt', delimiter=',')
data_num, feature_num = data.shape
print("data_num:   ", data_num)
print("feature_num:   ", feature_num)
# 2. Shuffle the data
# data = data.sample(frac=1, random_state=1024)  # pandas equivalent
rng = np.random.RandomState(2021)
index = list(range(data_num))
rng.shuffle(index)
data = data[index]
# 3. Split into features/label, then into train/test sets
X, Y = data[:, 0:feature_num-1], data[:, feature_num-1]
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=0)
# 4. Convert to XGBoost's DMatrix format
xg_train = xgb.DMatrix(X_train, label=y_train)
xg_test = xgb.DMatrix(X_test, label=y_test)

After splitting the dataset, check whether the label distribution of the test set has changed (it usually does not); if it has, see the stratified-split sketch below.

# X_test / y_test are NumPy arrays here, so wrap them before concatenating
test_data_new = pd.concat([pd.DataFrame(X_test), pd.Series(y_test)], axis=1, ignore_index=True)
test_data_new.columns = feature_names + ["label"]
print(test_data_new.loc[:, 'label'].value_counts())
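If the distribution does shift, scikit-learn's train_test_split can preserve the class proportions through its stratify argument. A minimal sketch reusing the variables above (an alternative to the plain split in section 2.2, not part of the original code):

# Stratified split: each class keeps roughly the same proportion in train and test.
X_train, X_test, y_train, y_test = train_test_split(
    X, Y, test_size=0.3, random_state=0, stratify=Y)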

3. Model training and prediction

params = {
    'booster': 'gbtree',
    'objective': 'multi:softmax',  # predict() returns the class index directly
    'num_class': 6,
    'eta': 0.05,                   # learning rate ('learning_rate' is an alias of 'eta'; set only one)
    'gamma': 0.1,
    'max_depth': 8,
    'lambda': 2,                   # L2 regularization
    'subsample': 0.85,
    'colsample_bytree': 0.85,
    'min_child_weight': 3,
    'verbosity': 0,                # 'silent' was removed in XGBoost >= 1.0
    'seed': 1000,
    'nthread': 4,
}

num_round = 20
watchlist = [(xg_train, 'train'), (xg_test, 'test')]
bst = xgb.train(params, xg_train, num_round, watchlist)
pred = bst.predict(xg_test)
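With objective 'multi:softmax', predict() returns the hard class index. If per-class probabilities are wanted instead (for example to rank predictions by confidence), 'multi:softprob' can be used; a minimal sketch reusing params and the DMatrix objects above, as an optional variant rather than part of the original pipeline:

# multi:softprob returns an (n_samples, num_class) probability matrix.
params_prob = dict(params, objective='multi:softprob')
bst_prob = xgb.train(params_prob, xg_train, num_round, watchlist)
pred_prob = bst_prob.predict(xg_test)        # shape: (n_test, 6)
pred_label = np.argmax(pred_prob, axis=1)    # recover the hard class prediction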

4. Model evaluation

print('predicting, classification error=%f'
       % (sum(int(pred[i]) != y_test[i] for i in range(len(y_test))) / float(len(y_test))))

print('Accuracy: %.4f' % metrics.accuracy_score(y_test, pred))
print(metrics.confusion_matrix(y_test, pred))
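Per-class precision, recall and F1 across the six age brackets can also be printed with scikit-learn's classification_report; a small optional addition to the evaluation above:

# Per-class precision, recall and F1 for the six age brackets.
print(metrics.classification_report(y_test, pred, digits=4))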

5. Printing feature importance

# Print feature importance by gain
# plot_importance(bst)
# plt.show()
importance = bst.get_score(importance_type='gain')
sorted_importance = sorted(importance.items(), key=lambda x: x[1], reverse=True)
print('feature importances[gain]: ', sorted_importance)
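By default get_score() reports features as f0, f1, ... because the DMatrix was built from a bare NumPy array. If readable names are wanted, feature_names can be passed when constructing the DMatrix; this assumes the feature_names list from section 2.1 matches the column order, and it is an optional tweak rather than part of the original code:

# Build the DMatrix with explicit feature names so get_score() returns them.
xg_train = xgb.DMatrix(X_train, label=y_train, feature_names=feature_names)
xg_test = xgb.DMatrix(X_test, label=y_test, feature_names=feature_names)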