[Machine Learning] 01. Random Forest Regression in Python: Correlation Analysis and Feature Importance Analysis


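Background: I had a regression task, but to protect the client's data, the Iris dataset is used here as a stand-in to implement part of the functionality with a random forest. Note that the Iris target is a class label (0, 1 or 2); here it is simply treated as a numeric regression target.

The complete code is given at the end.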

1. Load the dataset

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib

# Load the example dataset
iris = load_iris()
# Hold out 20% of the samples as a test set; random_state makes the split reproducible
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
print(iris.DESCR)
```

Printing `iris.DESCR` shows a description of the dataset; the relevant part is:

```text
:Summary Statistics:

============== ==== ==== ======= ===== ====================
                Min  Max   Mean    SD   Class Correlation
============== ==== ==== ======= ===== ====================
sepal length:   4.3  7.9   5.84   0.83    0.7826
sepal width:    2.0  4.4   3.05   0.43   -0.4194
petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
============== ==== ==== ======= ===== ====================

:Missing Attribute Values: None
:Class Distribution: 33.3% for each of 3 classes.
:Creator: R.A. Fisher
:Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
:Date: July, 1988
```
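The "Class Correlation" column in this table can be read as the Pearson correlation between each feature and the numeric class label (0, 1, 2). As a minimal sketch that is not part of the original article, it can be recomputed with `np.corrcoef`; the computed values may not match the table to every decimal place, so treat the table as a reference rather than an exact target. The dictionary name `class_corr` is just an illustrative choice.

```python
# Sketch: Pearson correlation of each iris feature with the numeric class label,
# the quantity summarized in the "Class Correlation" column of iris.DESCR.
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target

# np.corrcoef on two 1-D arrays returns a 2x2 matrix; entry [0, 1] is r(feature, target)
class_corr = {
    name: np.corrcoef(X[:, j], y)[0, 1]
    for j, name in enumerate(iris.feature_names)
}

for name, r in class_corr.items():
    print(f'{name:20s} {r:+.4f}')
```

The two petal features should come out strongly positive and sepal width negative, which already hints at the feature importance ranking in section 5.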

2. Compute and plot the correlation matrix of the features

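```python
# Feature names, used as axis labels for the heatmap
feature_names = iris.feature_names
# Correlation matrix between features (rowvar=False: each column is a variable)
correlation_matrix = np.corrcoef(X_train, rowvar=False)
# Visualize the correlation matrix as a heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', xticklabels=feature_names, yticklabels=feature_names)
plt.title('Correlation Matrix of Iris Features')
plt.show()
```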
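The heatmap is easy to read for four features, but with more columns it helps to list strongly correlated pairs directly. The sketch below is an illustration, not the article's method: the helper `strongly_correlated_pairs` and the 0.8 threshold are hypothetical choices, and it uses the full dataset rather than `X_train` so it runs on its own. Highly correlated features matter for the next steps, because impurity-based importances can end up split between them.

```python
# Sketch: list feature pairs whose absolute Pearson correlation exceeds a threshold.
# The 0.8 threshold is an illustrative assumption, not a rule from the article.
import itertools

import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
feature_names = iris.feature_names

# rowvar=False -> each column (feature) is a variable, each row is an observation
corr = np.corrcoef(iris.data, rowvar=False)


def strongly_correlated_pairs(corr, names, threshold=0.8):
    """Return (name_a, name_b, r) for each feature pair with |r| > threshold."""
    pairs = []
    # Only visit the upper triangle: the matrix is symmetric and its diagonal is 1
    for i, j in itertools.combinations(range(len(names)), 2):
        if abs(corr[i, j]) > threshold:
            pairs.append((names[i], names[j], corr[i, j]))
    return pairs


for a, b, r in strongly_correlated_pairs(corr, feature_names):
    print(f'{a} <-> {b}: {r:.3f}')
```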

3. Train the model and save it as a joblib file

```python
# Create the random forest model (100 trees, fixed seed for reproducibility)
rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
# Train the model
rf_model.fit(X_train, y_train)
# Save the fitted model to disk
joblib.dump(rf_model, 'random_forest_model.joblib')
```
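A quick way to confirm that saving worked is to reload the file and check that the reloaded forest gives exactly the same predictions as the in-memory one. This is a minimal sketch, assuming a temporary directory is acceptable as the save location (the article writes to the working directory instead).

```python
# Sketch: check that a model saved with joblib reloads to an equivalent model.
# A temporary directory is used here (an assumption) instead of the working directory.
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42)

rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
rf_model.fit(X_train, y_train)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'random_forest_model.joblib')
    joblib.dump(rf_model, path)       # serialize the fitted forest to disk
    loaded_model = joblib.load(path)  # read it back into memory

# The reloaded forest should reproduce the original predictions exactly
same = np.array_equal(rf_model.predict(X_test), loaded_model.predict(X_test))
print('Predictions identical after reload:', same)
```

joblib files are pickle-based, so they should only be loaded from trusted sources, and loading them with a different scikit-learn version than the one used to save them is not guaranteed to work.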

4. Load the model, predict, and report the MSE and R² metrics

```python
# Load the saved model
loaded_model = joblib.load('random_forest_model.joblib')
# Predict on the test set with the loaded model
y_pred = loaded_model.predict(X_test)
# Evaluate model performance
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f'Mean Squared Error: {mse}')
print(f'R-squared: {r2}')
```

```text
Mean Squared Error: 0.0013833333333333336
R-squared: 0.9980206677265501
```
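Both metrics are simple to compute by hand: MSE is the mean of the squared residuals, and R² = 1 − SS_res / SS_tot, where SS_res is the residual sum of squares and SS_tot is the sum of squared deviations of the true values from their mean. An MSE near 0 and an R² near 1, as in the output above, mean the predictions track the labels closely. The sketch below checks the hand-written versions against scikit-learn on small made-up arrays (not the article's data).

```python
# Sketch: compute MSE and R^2 by hand and compare with scikit-learn's metrics.
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score


def mse_manual(y_true, y_pred):
    # Mean of the squared residuals
    return np.mean((y_true - y_pred) ** 2)


def r2_manual(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)           # residual sum of squares
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)  # total sum of squares
    return 1.0 - ss_res / ss_tot


# Small illustrative arrays, made up for this example
y_true = np.array([0.0, 1.0, 2.0, 1.0, 0.0])
y_pred = np.array([0.1, 0.9, 1.8, 1.0, 0.0])

print(mse_manual(y_true, y_pred), mean_squared_error(y_true, y_pred))
print(r2_manual(y_true, y_pred), r2_score(y_true, y_pred))
```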

5. Feature importance analysis

```python
# Print the feature importances; Feature 1..4 follow the order of iris.feature_names
# (sepal length, sepal width, petal length, petal width)
feature_importances = loaded_model.feature_importances_
print('Feature Importances:')
for i, importance in enumerate(feature_importances):
    print(f'Feature {i+1}: {importance}')
# Visualize the feature importances, from most to least important
plt.figure(figsize=(10, 6))
sorted_idx = np.argsort(feature_importances)[::-1]  # reverse the ascending order to get descending
plt.bar(list(range(len(feature_importances))), feature_importances[sorted_idx], align='center')
plt.xticks(list(range(len(feature_importances))), np.array(feature_names)[sorted_idx], rotation=0)
plt.xlabel('Feature')
plt.ylabel('Importance Score')
plt.title('Feature Importance Scores')
plt.show()
```

```text
Feature Importances:
Feature 1: 0.007247638926907056
Feature 2: 0.01241623468021743
Feature 3: 0.4956256973314748
Feature 4: 0.48471042906140077
```
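`feature_importances_` in scikit-learn is impurity-based (mean decrease in impurity), computed from the training data and normalized to sum to 1. Because petal length and petal width are strongly correlated, the forest can split its credit between them, which fits the near-even 0.50 / 0.48 split above. As a cross-check, the sketch below uses permutation importance on the test split; this is a different technique from the one in the article, and `n_repeats=10` is an arbitrary illustrative choice.

```python
# Sketch: permutation importance on the held-out test set, as a cross-check of
# the impurity-based feature_importances_ used in the article.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42)
rf_model = RandomForestRegressor(n_estimators=100, random_state=42).fit(X_train, y_train)

# Impurity-based importances are normalized to sum to 1
print('Sum of feature_importances_:', rf_model.feature_importances_.sum())

# Shuffle one feature at a time and measure the drop in the model's score (R^2 by default)
result = permutation_importance(rf_model, X_test, y_test, n_repeats=10, random_state=42)

# Print the features from most to least important, with their names
for idx in np.argsort(result.importances_mean)[::-1]:
    print(f'{iris.feature_names[idx]:20s} '
          f'{result.importances_mean[idx]:.4f} +/- {result.importances_std[idx]:.4f}')
```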

6. Complete code

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib

# Load the example dataset
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
# print(iris.DESCR)

# Create the random forest model
rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
# Train the model
rf_model.fit(X_train, y_train)
# Save the model
joblib.dump(rf_model, 'random_forest_model.joblib')

# Load the model
loaded_model = joblib.load('random_forest_model.joblib')
# Predict with the loaded model
y_pred = loaded_model.predict(X_test)
# Evaluate model performance
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f'Mean Squared Error: {mse}')
print(f'R-squared: {r2}')

feature_names = iris.feature_names
# Correlation matrix between features
correlation_matrix = np.corrcoef(X_train, rowvar=False)
# Visualize the correlation matrix as a heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', xticklabels=feature_names, yticklabels=feature_names)
plt.title('Correlation Matrix of Iris Features')
plt.show()

# Print the feature importances
feature_importances = loaded_model.feature_importances_
print('Feature Importances:')
for i, importance in enumerate(feature_importances):
    print(f'Feature {i+1}: {importance}')
# Visualize the feature importances
plt.figure(figsize=(10, 6))
sorted_idx = np.argsort(feature_importances)[::-1]  # descending order
plt.bar(list(range(len(feature_importances))), feature_importances[sorted_idx], align='center')
plt.xticks(list(range(len(feature_importances))), np.array(feature_names)[sorted_idx], rotation=0)
plt.xlabel('Feature')
plt.ylabel('Importance Score')
plt.title('Feature Importance Scores')
plt.show()
```

Hyperparameter tuning logic, such as grid search and cross-validation, could be added later.
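As a starting point for that tuning, the sketch below first cross-validates the baseline model from this article, then runs a grid search with 5-fold cross-validation on the training split only, keeping the test split for a final check. The parameter grid is a small illustrative assumption chosen for speed, not a tuned recommendation.

```python
# Sketch: cross-validation of the baseline model, then a small grid search.
# The parameter values are illustrative assumptions, not tuned recommendations.
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42)

# 5-fold cross-validated R^2 of the article's baseline configuration
baseline = RandomForestRegressor(n_estimators=100, random_state=42)
baseline_scores = cross_val_score(baseline, X_train, y_train, cv=5, scoring='r2')
print(f'Baseline CV R^2: {baseline_scores.mean():.4f} +/- {baseline_scores.std():.4f}')

param_grid = {
    'n_estimators': [50, 100],
    'max_depth': [None, 3],
}

search = GridSearchCV(
    RandomForestRegressor(random_state=42),
    param_grid,
    cv=5,          # 5-fold cross-validation on the training split only
    scoring='r2',
)
search.fit(X_train, y_train)  # refits the best configuration on all of X_train

print('Best parameters:', search.best_params_)
print('Best mean CV R^2:', search.best_score_)
# Final check of the refitted best model on the untouched test split
print('Test R^2:', search.best_estimator_.score(X_test, y_test))
```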

Source: https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/97621