当前位置:   article > 正文

机器学习算法--python实现随机森林回归_随机森林 fit函数

随机森林 fit函数

随机森林算法是一种组合了多个决策树的技术。 由于随机性,随机森林通常比单个决策树具有更好的泛化性能,这有助于减少模型的方差。随机森林的其他优点还包括它对数据集中的离群值不敏感, 而且也不需要太多的参数优化。回归随机森林算法是用MSE准则来培育每棵决策树,并用决策树平均预测值来计算预测的目标变量。

from sklearn.ensemble import RandomForestRegressor
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error

df = pd.read_csv('xxx\\housing.data.txt',
                 header=None,
                 sep='\s+')

df.columns = ['CRIM', 'ZN', 'INDUS', 'CHAS',
              'NOX', 'RM', 'AGE', 'DIS', 'RAD',
              'TAX', 'PTRATIO', 'B', 'LSTAT', 'MEDV']
print(df.head())

X = df.iloc[:, :-1].values
y = df['MEDV'].values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=1)

# criterion :
# 回归树衡量分枝质量的指标,支持的标准有三种:
# 1)输入"mse"使用均方误差mean squared error(MSE),父节点和叶子节点之间的均方误差的差额将被用来作为特征选择的标准,
# 这种方法通过使用叶子节点的均值来最小化L2损失
# 2)输入“friedman_mse”使用费尔德曼均方误差,这种指标使用弗里德曼针对潜在分枝中的问题改进后的均方误差
# 3)输入"mae"使用绝对平均误差MAE(mean absolute error),这种指标使用叶节点的中值来最小化L1损失
forest = RandomForestRegressor(n_estimators=1000,
                               criterion='mse',
                               random_state=1,
                               n_jobs=-1)
forest.fit(X_train, y_train)
y_train_pred = forest.predict(X_train)
y_test_pred = forest.predict(X_test)

print('MSE train: %.3f, test: %.3f' % (
        mean_squared_error(y_train, y_train_pred),
        mean_squared_error(y_test, y_test_pred)))
print('R^2 train: %.3f, test: %.3f' % (
        r2_score(y_train, y_train_pred),
        r2_score(y_test, y_test_pred)))

# 绘制残差图
# 残差分布似乎并不是围绕零中心点完全随机,这表明该模型不能捕捉所有的探索性信息
plt.scatter(y_train_pred,
            y_train_pred - y_train,
            c='steelblue',
            edgecolor='white',
            marker='o',
            s=35,
            alpha=0.9,
            label='training data')
plt.scatter(y_test_pred,
            y_test_pred - y_test,
            c='limegreen',
            edgecolor='white',
            marker='s',
            s=35,
            alpha=0.9,
            label='test data')

plt.xlabel('Predicted values')
plt.ylabel('Residuals')
plt.legend(loc='upper left')
plt.hlines(y=0, xmin=-10, xmax=50, lw=2, color='black')
plt.xlim([-10, 50])
plt.tight_layout()

# plt.savefig('images/10_14.png', dpi=300)
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72

运行结果:
CRIM ZN INDUS CHAS NOX … TAX PTRATIO B LSTAT MEDV
0 0.00632 18.0 2.31 0 0.538 … 296.0 15.3 396.90 4.98 24.0
1 0.02731 0.0 7.07 0 0.469 … 242.0 17.8 396.90 9.14 21.6
2 0.02729 0.0 7.07 0 0.469 … 242.0 17.8 392.83 4.03 34.7
3 0.03237 0.0 2.18 0 0.458 … 222.0 18.7 394.63 2.94 33.4
4 0.06905 0.0 2.18 0 0.458 … 222.0 18.7 396.90 5.33 36.2

[5 rows x 14 columns]
MSE train: 1.644, test: 11.085
R^2 train: 0.979, test: 0.877

运行结果图:
在这里插入图片描述

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号