当前位置:   article > 正文

天池机器学习训练营(一) —— 基于逻辑回归的分类预测_天池分类预测

天池分类预测

引言

本节主要介绍了逻辑回归方法以及其的应用,并且介绍了分析一个项目的过程与思路, 并以可视化的方式展现,重点在于理解逻辑回归以及其推广到多分类的过程,并学习可视化方式

文本链接

机器学习算法(一): 基于逻辑回归的分类预测

逻辑回归表达式

逻辑回归,分为logit方程和回归表达式两部分,其为二分类方法,其中logit方程表达式如下,也称为Sigmoid函数:
l o g i t ( x ) = 1 / 1 + e − z logit(x) = 1/ 1+e^{-z} logit(x)=1/1+ez
函数图像如下:
在这里插入图片描述
而回归表达式如下:
z ( x ) = Σ w i x i z(x) = Σw_ix_i z(x)=Σwixi
其中wi的数量取决于特征xi的数量
标准逻辑回归是二分类方程,将其推广到多分类任务时,scikit-learn用的方法是 One vs Rest (OVR) 方法,即每次讲一个类作为正例,其余类作为反例来训练N个分类器,在测试时若某一个分类器预测为正类,则对应的类别标记为最终分类结果。

可视化方式

本例用了许多可视化呈现方式,主要用到的库为 Matplotlib 和 Seaborn库
如下:

  1. 利用Seaborn画出两两特征影响下各个数据的分布情况
    在这里插入图片描述
  2. 画出各个特征的箱型分布图
    在这里插入图片描述
  3. 取三个特征画出三维分布图
    在这里插入图片描述
  4. 利用Seaborn算出混淆矩阵,并画出
    在这里插入图片描述
整体代码
'''
机器学习训练营
Task01: 基于逻辑回归的分类预测
Date: 2022-09-18
'''
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.linear_model import LogisticRegression

def demo1():
    ## 构造数据集
    x_fearures = np.array([[-1, -2], [-2, -1], [-3, -2], [1, 3], [2, 1], [3, 2]])
    y_label = np.array([0, 0, 0, 1, 1, 1])

    # 调用逻辑回归模型
    lr_cls = LogisticRegression()

    lr_cls = lr_cls.fit(x_fearures,y_label) #其拟合方程为 y=w0+w1*x1+w2*x2

    ## 查看其对应模型的w
    print('the weight of Logistic Regression:', lr_cls.coef_)

    ## 查看其对应模型的w0
    print('the intercept(w0) of Logistic Regression:', lr_cls.intercept_)

    # 可视化构造的数据样本点
    plt.figure()
    plt.scatter(x_fearures[:,0],x_fearures[:,1],c=y_label,s=50,cmap='viridis')
    plt.title('Dataset')

    # 可视化决策边界
    nx, ny = 200, 100
    xmin, xmax = plt.xlim()
    ymin, ymax = plt.ylim()
    x_grid, y_grid = np.meshgrid(np.linspace(xmin,xmax,nx),np.linspace(ymin,ymax,ny))
    print(np.array(x_grid).shape)
    print(np.array(y_grid).shape)
    z_probe = lr_cls.predict_proba(np.c_[x_grid.ravel(),y_grid.ravel()])
    z_probe = z_probe[:,1].reshape(x_grid.shape)
    plt.contour(x_grid,y_grid,z_probe,[0.5],linewidths=2,colors='blue')

    # 可视化预测新样本
    x_fearures_new1 = np.array([[0, -1]])
    plt.scatter(x_fearures_new1[:, 0], x_fearures_new1[:, 1], s=50, cmap='viridis')

    x_fearures_new2 = np.array([[1, 2]])
    plt.scatter(x_fearures_new2[:, 0], x_fearures_new2[:, 1], s=50, cmap='viridis')

    # 在训练集和测试集上分别利用训练好的模型进行预测
    y_label_new1 = lr_cls.predict(x_fearures_new1)
    y_label_new2 = lr_cls.predict(x_fearures_new2)
    print('The New point 1 predict class:\n', y_label_new1)
    print('The New point 2 predict class:\n', y_label_new2)

    y_proba_new1 = lr_cls.predict_proba(x_fearures_new1)
    y_proba_new2 = lr_cls.predict_proba(x_fearures_new2)
    print('The New point 1 predict Probability of each class:\n', y_proba_new1)
    print('The New point 2 predict Probability of each class:\n', y_proba_new2)

    plt.show()

def demo2():
    import pandas as pd
    from sklearn.datasets import load_iris
    data = load_iris()
    iris_label = data.target
    iris_features = pd.DataFrame(data=data.data, columns=data.feature_names)
    print(iris_features.head())
    print(iris_features.info())
    # 利用value_counts函数查看每个类别数量
    print(pd.Series(iris_label).value_counts())
    print(iris_features.describe())
    # 合并标签和特征信息
    iris_all = iris_features.copy()
    iris_all['target'] = iris_label
    # 特征与标签组合的散点可视化
    sns.pairplot(data = iris_all, diag_kind = 'hist', hue = 'target')   # search
    plt.show()
    for col in iris_features.columns:
        sns.boxplot(x='target', y=col, saturation=0.5, palette='pastel', data=iris_all) # search
        plt.title(col)
        plt.show()
    # 选取其前三个特征绘制三维散点图
    from mpl_toolkits.mplot3d import Axes3D     # search
    fig = plt.figure(figsize=(10,8))
    ax = fig.add_subplot(111, projection='3d')

    iris_all_class0 = iris_all[iris_all['target'] == 0].values
    iris_all_class1 = iris_all[iris_all['target'] == 1].values
    iris_all_class2 = iris_all[iris_all['target'] == 2].values
    print(iris_all_class0)
    # 提取前三个特征
    ax.scatter(iris_all_class0[:,0], iris_all_class0[:,1], iris_all_class0[:,2], label='setosa')
    ax.scatter(iris_all_class1[:, 0], iris_all_class1[:, 1], iris_all_class1[:, 2], label='versicolor')
    ax.scatter(iris_all_class2[:, 0], iris_all_class2[:, 1], iris_all_class2[:, 2], label='virginica')

    plt.legend()
    plt.show()

    # 开始搭建模型、训练
    from sklearn.model_selection import train_test_split

    iris_feature_part = iris_features.iloc[:100]
    iris_label_part = iris_label[:100]

    x_train,x_test, y_train,y_test = train_test_split(iris_feature_part, iris_label_part, test_size=0.2, random_state=2022)

    # 其拟合方程为 y= w0 + w1*x1 + w2*x2 + w3*x3 + w4*x4
    clf = LogisticRegression(random_state=0, solver='lbfgs')    # search

    # train
    clf.fit(x_train,y_train)

    # 查看其对应的w
    print('the weight of Logistic Regression:', clf.coef_)
    # 查看其对应的w0
    print('the intercept(w0) of Logistic Regression:', clf.intercept_)

    # predict
    train_predict = clf.predict(x_train)
    test_predict = clf.predict(x_test)

    # metrics
    from sklearn import metrics
    # 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
    print('The accuracy of the Logistic Regression is:', metrics.accuracy_score(y_train, train_predict))
    print('The accuracy of the Logistic Regression is:', metrics.accuracy_score(y_test, test_predict))
    # 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵)
    confusion_matrix_result = metrics.confusion_matrix(test_predict, y_test)    # search
    print('The confusion matrix result:\n', confusion_matrix_result)

    # 利用热力图对于结果进行可视化
    plt.figure(figsize=(8,6))
    sns.heatmap(confusion_matrix_result,annot=True, cmap='Blues')       # search
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.show()

    # 三分类训练和预测
    x_train, x_test, y_train, y_test = train_test_split(iris_features, iris_label, test_size=0.2, random_state=2022)
    cls = LogisticRegression(random_state=0, solver='lbfgs')
    # train
    cls.fit(x_train, y_train)
    # 查看其对应的w
    print('the weight of Logistic Regression:\n', cls.coef_)
    # 查看其对应的w0
    print('the intercept(w0) of Logistic Regression:\n', cls.intercept_)
    # 由于这个是3分类,所有我们这里得到了三个逻辑回归模型的参数,其三个逻辑回归组合起来即可实现三分类。 # 为什么是三组  方程是怎样的?

    train_predict = cls.predict(x_train)
    test_predict = cls.predict(x_test)

    print('The accuracy of the Three Class Logistic Regression is:', metrics.accuracy_score(y_train, train_predict))
    print('The accuracy of the Three Class Logistic Regression is:', metrics.accuracy_score(y_test, test_predict))

    # 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵)
    confusion_matrix_result = metrics.confusion_matrix(y_test,test_predict)
    print('The confusion matrix result:\n', confusion_matrix_result)

    # heatmap
    fig = plt.figure(figsize=(9,6))
    sns.heatmap(confusion_matrix_result, annot = True, cmap='Blues')
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.show()

if __name__ == "__main__":
    # demo1()
    demo2()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/199143
推荐阅读
相关标签
  

闽ICP备14008679号