当前位置:   article > 正文

CatBoost模型Python代码——用CatBoost模型实现机器学习

CatBoost模型Python代码——用CatBoost模型实现机器学习

一、CatBoost模型简介

1.1适用范围

CatBoost(Categorical Boosting)是一种基于梯度提升的机器学习算法,特别适用于处理具有类别特征的数据集。它可以用于分类、回归和排序任务,并且在处理具有大量类别特征的数据时表现优异。典型应用包括但不限于:

  • 电子商务中的推荐系统
  • 客户行为分析
  • 财务风险评估
  • 医疗数据分析
1.2原理

CatBoost使用梯度提升决策树(GBDT)作为其核心算法。其主要特点包括:

  1. 处理类别特征:CatBoost原生支持类别特征,并在内部使用目标编码(target encoding)来处理它们,从而减少了类别变量处理的复杂性。
  2. 顺序增强(Ordered Boosting):在构建每棵树时,CatBoost通过引入一种新的顺序提升方法来避免传统梯度提升中的预测偏差问题。
  3. 随机分片:为了进一步减少过拟合,CatBoost在每次树构建时随机分割数据集。
1.3优点
  • 高效处理类别特征:无需复杂的预处理步骤。
  • 减少过拟合:通过顺序增强和随机分片技术。
  • 易于使用:内置了许多默认的优化参数,适合初学者和快速原型开发。
  • 高性能:在许多实际应用中表现优于其他GBDT算法(如XGBoost和LightGBM)。
1.4缺点
  • 模型训练时间较长:尽管有许多优化,训练时间可能比其他简单模型更长。
  • 内存占用较高:在处理大规模数据时,内存需求较大。

二、实现CatBoost模型的Python代码

下面是一个使用CatBoost进行分类任务的完整Python代码示例,包含详细注释。

2.1导入必要的包和测试数据
  1. import pandas as pd
  2. from catboost import CatBoostClassifier, Pool
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
  5. import matplotlib.pyplot as plt
  6. import seaborn as sns
  7. # 加载Titanic数据集
  8. url = 'https://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv'
  9. data = pd.read_csv(url)
  10. # 查看数据集的列名
  11. print("Columns in the dataset:", data.columns)
2.2简单的数据预处理
  1. # 简单的数据预处理
  2. # 填充缺失值
  3. # data['Age'].fillna(data['Age'].median(), inplace=True)
  4. # data['Embarked'].fillna(data['Embarked'].mode()[0], inplace=True)
  5. # 将Sex和Embarked转换为类别型特征
  6. data['Sex'] = data['Sex'].astype('category')
  7. # data['Pclass'] = data['Pclass'].astype('Pclass')
  8. # 选择特征和目标
  9. features = ['Pclass', 'Sex', 'Age', 'Siblings/Spouses Aboard', 'Parents/Children Aboard', 'Fare']
  10. target = 'Survived'
  11. X = data[features]
  12. y = data[target]
2.3构建CatBoost模型
  1. # 分割数据集为训练集和测试集
  2. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  3. # 创建CatBoost数据池
  4. categorical_features = ['Sex', 'Pclass']
  5. train_pool = Pool(X_train, y_train, cat_features=categorical_features)
  6. test_pool = Pool(X_test, y_test, cat_features=categorical_features)
  7. # 初始化并训练CatBoost分类器
  8. model = CatBoostClassifier(
  9. iterations=1000,
  10. learning_rate=0.1,
  11. depth=6,
  12. loss_function='Logloss', # 二分类任务使用'Logloss'
  13. verbose=100 # 每100次迭代打印一次信息
  14. )
  15. # 训练模型
  16. model.fit(train_pool)
  17. # 在测试集上进行预测
  18. y_pred = model.predict(test_pool)
  19. y_pred_proba = model.predict_proba(test_pool)[:, 1]
2.4模型评估
  1. # 评估模型
  2. accuracy = accuracy_score(y_test, y_pred)
  3. print(f'Accuracy: {accuracy}')
  4. print(classification_report(y_test, y_pred))

模型评估输出结果如下 :

  1. 0: learn: 0.6538633 total: 159ms remaining: 2m 39s
  2. 100: learn: 0.2814504 total: 891ms remaining: 7.93s
  3. 200: learn: 0.2007734 total: 1.68s remaining: 6.68s
  4. 300: learn: 0.1536222 total: 2.45s remaining: 5.69s
  5. 400: learn: 0.1220845 total: 3.19s remaining: 4.77s
  6. 500: learn: 0.0961718 total: 3.95s remaining: 3.93s
  7. 600: learn: 0.0810769 total: 4.7s remaining: 3.12s
  8. 700: learn: 0.0694396 total: 5.45s remaining: 2.33s
  9. 800: learn: 0.0598153 total: 6.2s remaining: 1.54s
  10. 900: learn: 0.0527771 total: 6.93s remaining: 761ms
  11. 999: learn: 0.0474017 total: 7.67s remaining: 0us
  12. Accuracy: 0.8033707865168539
  13. precision recall f1-score support
  14. 0 0.84 0.85 0.84 111
  15. 1 0.74 0.73 0.74 67
  16. accuracy 0.80 178
  17. macro avg 0.79 0.79 0.79 178
  18. weighted avg 0.80 0.80 0.80 178
  19. Feature: Pclass, Importance: 16.480181005946406
  20. Feature: Sex, Importance: 24.322199798316337
  21. Feature: Age, Importance: 27.28642174968946
  22. Feature: Siblings/Spouses Aboard, Importance: 5.125530737270014
  23. Feature: Parents/Children Aboard, Importance: 3.006729091175773
  24. Feature: Fare, Importance: 23.77893761760206
2.5可视化特征重要性(可选)
  1. # 可视化特征重要性(可选)
  2. plt.figure(figsize=(10, 6))
  3. plt.barh(X.columns, feature_importances)
  4. plt.xlabel('Feature Importance')
  5. plt.title('CatBoost Feature Importances')
  6. plt.show()

特征重要性输出结果如下:

 2.6绘制混淆矩阵
  1. # 绘制混淆矩阵
  2. conf_matrix = confusion_matrix(y_test, y_pred)
  3. plt.figure(figsize=(8, 6))
  4. sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
  5. plt.xlabel('Predicted')
  6. plt.ylabel('Actual')
  7. plt.title('Confusion Matrix')
  8. plt.show()

绘制混淆矩阵输出结果如下:

2.7绘制ROC曲线
  1. # 绘制ROC曲线
  2. fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
  3. roc_auc = auc(fpr, tpr)
  4. plt.figure(figsize=(8, 6))
  5. plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
  6. plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
  7. plt.xlim([0.0, 1.0])
  8. plt.ylim([0.0, 1.05])
  9. plt.xlabel('False Positive Rate')
  10. plt.ylabel('True Positive Rate')
  11. plt.title('Receiver Operating Characteristic (ROC) Curve')
  12. plt.legend(loc='lower right')
  13. plt.show()

绘制ROC曲线输出结果如下:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Li_阴宅/article/detail/871121
推荐阅读
相关标签
  

闽ICP备14008679号