当前位置:   article > 正文

基于逻辑回归的分类预测

基于逻辑回归的分类预测

1.逻辑回归简介

​ 逻辑回归(Logistic regression,简称LR)虽然其中带有"回归"两个字,但逻辑回归其实是一个分类模型,并且广泛应用于各个领域之中。

​ 将线性回归模型得到的结果通过一个非线性的sigmoid函数,得到[0,1]之间取值范围的值同时设置阈值为0.5,通过与阈值的比较达到二分类的效果,即为逻辑回归模型。

2.Sigmoid函数简介

2.1 逻辑斯特函数的由来

(1) 假设一事件发生的概率为P,则不发生的概率为1- P,我们把发生概率/不发生概率称之为发生的概率比,数学公式表示为:
o d d s ( 概 率 比 ) = p 1 − p odds(概率比) = \frac{p}{1-p} odds()=1pp
(2) 定义logit函数,即概率比的对数函数(log-odds)
l o g i t ( p ) = l o g ( p 1 − p ) logit(p) = log(\frac{p}{1-p}) logit(p)=log(1pp)
(3) logit函数输入值范围介于[0,1]之间,它能将输入值转为到整个实数范围内。

(4) 对logit函数求解反函数得到的函数为logistic函数
l o g i s t i c ( z ) = s i g m o d    ( 1 1 + e z ) logistic(z) = sig\mod (\frac{1}{1+e^z}) logistic(z)=sigmod(1+ez1)
(5) 给定测试数据为X,要学习的参数为θ
测 试 数 据 X ( x 0 , x 1 , x 2 . . . x n ) 学 习 参 数 θ ( θ 0 , θ 1 , θ 2 . . . x n ) 测试数据X(x_0,x_1,x_2...x_n) 学习参数θ(θ_0,θ_1,θ_2...x_n) X(x0,x1,x2...xn)θ(θ0,θ1,θ2...xn)
模型的线性表示为(样本特征与权重的线性组合):
z = θ 0 x 0 + θ 1 x 1 + θ 2 x 2 + . . . + θ n x n z = θ_0x_0+θ_1x_1+θ_2x_2+...+θ_nx_n z=θ0x0+θ1x1+θ2x2+...+θnxn
向量表示为:
z = θ 0 x 0 + θ 1 x 1 + θ 2 x 2 + . . . + θ n x n = θ T x n z = θ_0x_0+θ_1x_1+θ_2x_2+...+θ_nx_n = θ^Tx_n z=θ0x0+θ1x1+θ2x2+...+θnxn=θTxn
(6) 在处理而知数据时, 引入Sigmoid函数时曲线平滑化,该函数图像为:
g ( z ) = 1 1 + e − z g(z) = \frac{1}{1+e^{-z}} g(z)=1+ez1
在这里插入图片描述
Sigmoid函数以实数值作为输入并将其反射到[0,1]区间,拐点在y=0.5地方。

2.2 Sigmoid函数的特点

函数公式如下:
f ( x ) = s i g m o i d = 1 1 + e − z f(x) = sigmoid = \frac{1}{1+e^-z} f(x)=sigmoid=1+ez1
对f(x)求导:
∂ f ( x ) ∂ x = [ 1 − f ( x ) ] ⋅ f ( x ) \frac{\partial{f(x)}}{\partial{x}} = [1- f(x)]\cdot f(x) xf(x)=[1f(x)]f(x)
对于后面求解逻辑回归损失函数时,会简化求解过程。

2.3 Sigmoid函数绘图实战

# 绘制[-7,7]的sigmod函数图像
import matplotlib.pyplot as plt
import numpy as np


def sigmod(z):
    return 1.0 / (1.0 + np.exp(-z))


z = np.arange(-7, 7, 0.1)
phi_z = sigmod(z)
plt.plot(z, phi_z)
plt.axvline(0.0, color='k')
plt.axhspan(0.0, 1.0, facecolor='1.0', alpha=1.0, ls="dotted")
plt.yticks([0.0, 0.5, 1.0])
plt.ylim(-0.1, 1.1)
plt.xlabel('z')
plt.ylabel('$\phi(z)$')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

在这里插入图片描述

3.逻辑回归使用sigmoid函数的原因

对于一般的线性回归模型,我们知道:自变量X和因变量Y都是连续的数值,通过X的输入就可以很好的预测Y值。在实际生活中,离散的数据类型也是比较的常见的,比如好和坏,男和女等等。那么问题来了:在线性回归模型的基础上,是否可以实现预测一个因变量为离散数据类型的模型呢?

答案当然是可以的。我们可能会想想到阶跃函数:
ϕ ( z ) = { 0 , i f   z < 0 0.5 , i f   z = 0 1 , i f   z > 0 \phi(z) = {0,if z<00.5if z=01if z>0

ϕ(z)=0,0.51if z<0if z=0if z>0
但是用在这里是不合适的,正如我们神经网络激活函数不选择阶跃函数的原因一样,因为它不连续不可微。而能满足分类效果,且是连续的函数,sigmoid函数是再好不过的选择了。因此逻辑回归模型是在线性回归模型的基础上,套一个sigmoid函数,得到一个处于[0,1]之间的数值,同时设置一个阈值,通过与阈值的比较来实现分类的效果

4.逻辑回归模型的假设

1)假设数据服从伯努利分布(0-1分布)

2)假设模型的输出值是样本正例的概率

5.逻辑回归的损失函数以及求解

LR模型推导过程
  1. 一个事件的几率(odds)
    o d d s = 事 件 发 生 的 概 率 事 件 不 发 生 的 概 率 = p 1 − p odds=\frac {事件发生的概率}{事件不发生的概率}=\frac{p}{1-p} odds==1pp

  2. 该事件的对数几率(log odds)或logit函数:
    l o g i t ( p ) = l o g p 1 − p logit(p)=log^{\frac{p}{1-p}} logit(p)=log1pp

  3. 输出y^(i)=1的多数几率是由输入x的线性函数表示的模型,即逻辑斯谛回归模型。

    (逻辑斯谛回归模型可以将线性函数w·x转换为概率)

p ( y ( i ) = 1 ∣ x ) = 1 1 + e − w T x p(y^{(i)}=1|x)=\frac{1}{1+e^{-w^Tx}} p(y(i)=1x)=1+ewTx1
参 数 输 入 参 数 w = [ w ( 1 ) , w ( 2 ) , . . . w ( n ) , b ] T , 输 入 x = [ x ( 1 ) , x ( 2 ) , . . . x ( n ) , 1 ] T 参数输入参数w=[w^{(1)},w^{(2)},...w^{(n)},b]^T,输入x=[x^{(1)},x^{(2)},...x^{(n)},1]^T w=[w(1),w(2),...w(n),b]T,x=[x(1),x(2),...x(n),1]T

​ 4.函数模型:
正 例 ( y = 1 ) : P ( y = 1 ∣ x ) = p 反 例 ( y = 0 ) : P ( y = 0 ∣ x ) = 1 − p 函 数 合 二 为 一 : P ( y i ∣ x i ) = p y i ( 1 − p ) 1 − y i ( 当 y i = 1 , 结 果 是 p ; 当 y i = 0 , 结 果 是 1 − p ) 。 正例(y=1):P(y=1|x)=p\\ 反例(y=0):P(y=0|x)=1- p\\ 函数合二为一:P(y_i|x_i)=p^{y_i}(1-p)^{1-y_i}\\ (当y_i=1,结果是p;当y_i=0,结果是1-p)。 (y=1):P(y=1x)=p(y=0):P(y=0x)=1pP(yixi)=pyi(1p)1yiyi=1pyi=01p

  1. 似然函数:
    ∏ i = 1 N p y i ( 1 − p ) 1 − y i \prod_{i=1}^Np^{y_i}(1-p)^{1-y_i} i=1Npyi(1p)1yi
    对数似然函数:
    L ( w ) = l n P 总 = l n ( ∏ i = 1 N p y i ( 1 − p ) 1 − y i ) = ∑ i = 1 N l n ( p y i ( 1 − p ) 1 − y i ) = ∑ i = 1 N ( y i l n p + ( 1 − y i ) l n 1 − p ) 其 中 , p = 1 1 + e − w T x L_{(w)}=ln^{P_总}=ln^{(\prod^{N}_{i=1}p^{y_i}(1-p)^{1-y_i})}\\ =\sum^N_{i=1}ln^{(p^{y_i}(1-p)^{1-y_i})}\\ =\sum^N_{i=1}(y_iln^p+(1-y_i)ln^{1-p})\\ 其中,p=\frac{1}{1+e^{-w^Tx}} L(w)=lnP=ln(i=1Npyi(1p)1yi)=i=1Nln(pyi(1p)1yi)=i=1N(yilnp+(1yi)ln1p)p=1+ewTx1
    对L(w)求极大值,得到w的估计值。这样,问题就变成了以对数似然函数为目标函数的最优化问题。

  2. 对数似然函数求导(sigmoid函数的特点)
    ∂ L ( w ) ∂ w j = ∑ i = 1 N ( y ( i ) 1 p ∂ p ∂ θ ∂ θ ∂ w j ) + ( 1 − y ( i ) ) 1 1 − p ( − 1 ) ∂ p ∂ θ ∂ θ ∂ w j ) < 1 > ∂ p ∂ θ = ∂ ∂ θ 1 [ 1 + e − θ ] = 1 ( 1 + e − θ ) ( 1 − 1 ( 1 + e − θ ) ) = p ( 1 − p ) < 2 > ∂ θ ∂ w j = ∂ ∂ w j [ w T x ] = x j < 3 > \frac{\partial_{L_{(w)}}}{\partial_{w_j}}=\sum^N_{i=1}(y^{(i)}\frac{1}{p}\frac{\partial_p}{\partial_\theta}\frac{\partial_\theta}{\partial_{w_j}})+(1-y^{(i)})\frac{1}{1-p}(-1)\frac{\partial_p}{\partial_\theta}\frac{\partial_\theta}{\partial_{w_j}}) \qquad\quad <1>\\ \frac{\partial_p}{\partial_\theta}=\frac{\partial}{\partial_\theta}\frac{1}{[1+e^{-\theta}]}=\frac{1}{(1+e^{-\theta})}(1-\frac{1}{(1+e^{-\theta})})=p(1-p) \qquad<2>\\ \frac{\partial_\theta}{\partial_{w_j}}=\frac{\partial}{\partial_{w_j}}[w^Tx]=x_j \qquad \qquad\qquad\qquad\qquad\qquad\qquad\qquad\qquad<3>\\ wjL(w)=i=1N(y(i)p1θpwjθ)+(1y(i))1p1(1)θpwjθ)<1>θp=θ[1+eθ]1=(1+eθ)1(1(1+eθ)1)=p(1p)<2>wjθ=wj[wTx]=xj<3>

    将 < 3 > < 2 > 代 入 < 1 > 中 ∂ L ( w ) ∂ w j = ∑ i = 1 N ( y ( i ) 1 p p ( 1 − p ) x j + ( 1 − y ( i ) ) 1 1 − p ( − 1 ) p ( 1 − p ) x j ) = ∑ i = 1 N ( y ( i ) − p ) x j 将<3><2>代入<1>中\\ \frac{\partial_{L_{(w)}}}{\partial_{w_j}}=\sum^N_{i=1}(y^{(i)}\frac{1}{p}p(1-p)x_j+(1-y^{(i)})\frac{1}{1-p}(-1)p(1-p)x_j)\\ =\sum^N_{i=1}(y^{(i)}-p)x_j <3><2><1>wjL(w)=i=1N(y(i)p1p(1p)xj+(1y(i))1p1(1)p(1p)xj)=i=1N(y(i)p)xj

    7.利用梯度下降法求解目标函数的最大值

    给定训练步长和初始值w,迭代收敛

    更新规则
    w j = w j + α ∇ L ( w ) 而 ∇ L ( w ) = ∂ L ( w j ) ∂ w j = ( y ( i ) − p ) x j w_j=w_j+\alpha\nabla L_{(w)}\\ 而\nabla L_{(w)}=\frac{\partial L_{(w_j)}}{\partial_{w_j}}=(y^{(i)}-p)x_j wj=wj+αL(w)L(w)=wjL(wj)=(y(i)p)xj

6.逻辑回归的优缺点

优点:

1)LR能以概率的形式输出结果,而非只是0,1判定。

2)LR的可解释性强,可控度高

3)训练快,feature engineering之后效果赞。

4)因为结果是概率,可以做ranking model。

缺点:

1)容易过拟合

2)分类精度可能不高

7.逻辑回归的应用

1)CTR预估/推荐系统的learning to rank/各种分类场景。
2)某搜索引擎厂的广告CTR预估基线版是LR。
3)某电商搜索排序/广告CTR预估基线版是LR。
4)某电商的购物搭配推荐用了大量LR。
5)某现在一天广告赚1000w+的新闻app排序基线是LR。

8.基于鸢尾花(iris)数据集的逻辑回归分类实践

8.1 导入基本函数库
# 基础函数库
import numpy as np
import pandas as pd

# 绘图函数库
import matplotlib.pyplot as plt
import seaborn as sns
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

鸢尾花数据集,共包含5个变量,其中4个特征变量,1个目标分类变量;共150个样本。目标变量为 花的类别 其都属于鸢尾属下的三个亚属,分别是山鸢尾 (Iris-setosa),变色鸢 尾(Iris-versicolor)和维吉尼亚鸢尾(Iris-virginica)。通过鸢尾花下表的四个特征进行物种的识别。

变量描述
sepal length花萼长度(cm)
sepal width花萼宽度(cm)
petal length花瓣长度(cm)
petal width花瓣宽度(cm)
target鸢尾花的三个亚属类别,‘setosa(0)’,‘versicolor(1)’,‘virginica(2)’
8.2 读取鸢尾花数据集
# 读取sklearn库自带的iris数据集
from sklearn.datasets import load_iris
data = load_iris() # 获取数据特征
print(data)
  • 1
  • 2
  • 3
  • 4

输出结果:(数据集的基本信息)

{'data': array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2],
       [ 4.6,  3.1,  1.5,  0.2],
       [ 5. ,  3.6,  1.4,  0.2],
       [ 5.4,  3.9,  1.7,  0.4],
       [ 4.6,  3.4,  1.4,  0.3],
       [ 5. ,  3.4,  1.5,  0.2],
       [ 4.4,  2.9,  1.4,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 5.4,  3.7,  1.5,  0.2],
       [ 4.8,  3.4,  1.6,  0.2],
       [ 4.8,  3. ,  1.4,  0.1],
       [ 4.3,  3. ,  1.1,  0.1],
       [ 5.8,  4. ,  1.2,  0.2],
       [ 5.7,  4.4,  1.5,  0.4],
       [ 5.4,  3.9,  1.3,  0.4],
       [ 5.1,  3.5,  1.4,  0.3],
       [ 5.7,  3.8,  1.7,  0.3],
       [ 5.1,  3.8,  1.5,  0.3],
       [ 5.4,  3.4,  1.7,  0.2],
       [ 5.1,  3.7,  1.5,  0.4],
       [ 4.6,  3.6,  1. ,  0.2],
       [ 5.1,  3.3,  1.7,  0.5],
       [ 4.8,  3.4,  1.9,  0.2],
       [ 5. ,  3. ,  1.6,  0.2],
       [ 5. ,  3.4,  1.6,  0.4],
       [ 5.2,  3.5,  1.5,  0.2],
       [ 5.2,  3.4,  1.4,  0.2],
       [ 4.7,  3.2,  1.6,  0.2],
       [ 4.8,  3.1,  1.6,  0.2],
       [ 5.4,  3.4,  1.5,  0.4],
       [ 5.2,  4.1,  1.5,  0.1],
       [ 5.5,  4.2,  1.4,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 5. ,  3.2,  1.2,  0.2],
       [ 5.5,  3.5,  1.3,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 4.4,  3. ,  1.3,  0.2],
       [ 5.1,  3.4,  1.5,  0.2],
       [ 5. ,  3.5,  1.3,  0.3],
       [ 4.5,  2.3,  1.3,  0.3],
       [ 4.4,  3.2,  1.3,  0.2],
       [ 5. ,  3.5,  1.6,  0.6],
       [ 5.1,  3.8,  1.9,  0.4],
       [ 4.8,  3. ,  1.4,  0.3],
       [ 5.1,  3.8,  1.6,  0.2],
       [ 4.6,  3.2,  1.4,  0.2],
       [ 5.3,  3.7,  1.5,  0.2],
       [ 5. ,  3.3,  1.4,  0.2],
       [ 7. ,  3.2,  4.7,  1.4],
       [ 6.4,  3.2,  4.5,  1.5],
       [ 6.9,  3.1,  4.9,  1.5],
       [ 5.5,  2.3,  4. ,  1.3],
       [ 6.5,  2.8,  4.6,  1.5],
       [ 5.7,  2.8,  4.5,  1.3],
       [ 6.3,  3.3,  4.7,  1.6],
       [ 4.9,  2.4,  3.3,  1. ],
       [ 6.6,  2.9,  4.6,  1.3],
       [ 5.2,  2.7,  3.9,  1.4],
       [ 5. ,  2. ,  3.5,  1. ],
       [ 5.9,  3. ,  4.2,  1.5],
       [ 6. ,  2.2,  4. ,  1. ],
       [ 6.1,  2.9,  4.7,  1.4],
       [ 5.6,  2.9,  3.6,  1.3],
       [ 6.7,  3.1,  4.4,  1.4],
       [ 5.6,  3. ,  4.5,  1.5],
       [ 5.8,  2.7,  4.1,  1. ],
       [ 6.2,  2.2,  4.5,  1.5],
       [ 5.6,  2.5,  3.9,  1.1],
       [ 5.9,  3.2,  4.8,  1.8],
       [ 6.1,  2.8,  4. ,  1.3],
       [ 6.3,  2.5,  4.9,  1.5],
       [ 6.1,  2.8,  4.7,  1.2],
       [ 6.4,  2.9,  4.3,  1.3],
       [ 6.6,  3. ,  4.4,  1.4],
       [ 6.8,  2.8,  4.8,  1.4],
       [ 6.7,  3. ,  5. ,  1.7],
       [ 6. ,  2.9,  4.5,  1.5],
       [ 5.7,  2.6,  3.5,  1. ],
       [ 5.5,  2.4,  3.8,  1.1],
       [ 5.5,  2.4,  3.7,  1. ],
       [ 5.8,  2.7,  3.9,  1.2],
       [ 6. ,  2.7,  5.1,  1.6],
       [ 5.4,  3. ,  4.5,  1.5],
       [ 6. ,  3.4,  4.5,  1.6],
       [ 6.7,  3.1,  4.7,  1.5],
       [ 6.3,  2.3,  4.4,  1.3],
       [ 5.6,  3. ,  4.1,  1.3],
       [ 5.5,  2.5,  4. ,  1.3],
       [ 5.5,  2.6,  4.4,  1.2],
       [ 6.1,  3. ,  4.6,  1.4],
       [ 5.8,  2.6,  4. ,  1.2],
       [ 5. ,  2.3,  3.3,  1. ],
       [ 5.6,  2.7,  4.2,  1.3],
       [ 5.7,  3. ,  4.2,  1.2],
       [ 5.7,  2.9,  4.2,  1.3],
       [ 6.2,  2.9,  4.3,  1.3],
       [ 5.1,  2.5,  3. ,  1.1],
       [ 5.7,  2.8,  4.1,  1.3],
       [ 6.3,  3.3,  6. ,  2.5],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 7.1,  3. ,  5.9,  2.1],
       [ 6.3,  2.9,  5.6,  1.8],
       [ 6.5,  3. ,  5.8,  2.2],
       [ 7.6,  3. ,  6.6,  2.1],
       [ 4.9,  2.5,  4.5,  1.7],
       [ 7.3,  2.9,  6.3,  1.8],
       [ 6.7,  2.5,  5.8,  1.8],
       [ 7.2,  3.6,  6.1,  2.5],
       [ 6.5,  3.2,  5.1,  2. ],
       [ 6.4,  2.7,  5.3,  1.9],
       [ 6.8,  3. ,  5.5,  2.1],
       [ 5.7,  2.5,  5. ,  2. ],
       [ 5.8,  2.8,  5.1,  2.4],
       [ 6.4,  3.2,  5.3,  2.3],
       [ 6.5,  3. ,  5.5,  1.8],
       [ 7.7,  3.8,  6.7,  2.2],
       [ 7.7,  2.6,  6.9,  2.3],
       [ 6. ,  2.2,  5. ,  1.5],
       [ 6.9,  3.2,  5.7,  2.3],
       [ 5.6,  2.8,  4.9,  2. ],
       [ 7.7,  2.8,  6.7,  2. ],
       [ 6.3,  2.7,  4.9,  1.8],
       [ 6.7,  3.3,  5.7,  2.1],
       [ 7.2,  3.2,  6. ,  1.8],
       [ 6.2,  2.8,  4.8,  1.8],
       [ 6.1,  3. ,  4.9,  1.8],
       [ 6.4,  2.8,  5.6,  2.1],
       [ 7.2,  3. ,  5.8,  1.6],
       [ 7.4,  2.8,  6.1,  1.9],
       [ 7.9,  3.8,  6.4,  2. ],
       [ 6.4,  2.8,  5.6,  2.2],
       [ 6.3,  2.8,  5.1,  1.5],
       [ 6.1,  2.6,  5.6,  1.4],
       [ 7.7,  3. ,  6.1,  2.3],
       [ 6.3,  3.4,  5.6,  2.4],
       [ 6.4,  3.1,  5.5,  1.8],
       [ 6. ,  3. ,  4.8,  1.8],
       [ 6.9,  3.1,  5.4,  2.1],
       [ 6.7,  3.1,  5.6,  2.4],
       [ 6.9,  3.1,  5.1,  2.3],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 6.8,  3.2,  5.9,  2.3],
       [ 6.7,  3.3,  5.7,  2.5],
       [ 6.7,  3. ,  5.2,  2.3],
       [ 6.3,  2.5,  5. ,  1.9],
       [ 6.5,  3. ,  5.2,  2. ],
       [ 6.2,  3.4,  5.4,  2.3],
       [ 5.9,  3. ,  5.1,  1.8]]), 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), 'target_names': array(['setosa', 'versicolor', 'virginica'],
      dtype='<U10'), 
      'DESCR': 'Iris Plants Database
      Data Set Characteristics:
      Number of Instances: 150 (50 in each of three classes)
      Number of Attributes: 4 numeric, predictive attributes and the class
      Attribute Information:        
      - sepal length in cm        
      - sepal width in cm       
      - petal length in cm       
      - petal width in cm      
      class:               
      - Iris-Setosa              
      - Iris-Versicolour              
      - Iris-Virginica
      Summary Statistics:                
                       Min  Max    Mean    SD   Class Correlation 
      sepal length:   4.3   7.9    5.84   0.83     0.7826    
      sepal width:    2.0   4.4    3.05   0.43    -0.4194   
      petal length:   1.0   6.9    3.76   1.76     0.9490  (high!) 
      petal width:    0.1   2.5    1.20   0.76     0.9565  (high!)
      Missing Attribute Values: None
      Class Distribution: 33.3% for each of 3 classes.
      Creator: R.A. Fisher
      Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
      Date: July, 1988
      This is a copy of UCI ML iris datasets.
 			http://archive.ics.uci.edu/ml/datasets/Iris
      The famous Iris database, first used by Sir R.A Fisher
 			This is perhaps the best known database to be found in the pattern recognition literature.  
 			The paper of Fisher  is a classic in the field and is referenced frequently to this day.  (See Duda & Hart, for example.)  
      The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant.  
      One class is linearly separable from the other 2; the latter are NOT linearly separable from each other.
      References 
      - Fisher,R.A. "The use of multiple measurements in taxonomic problems"    
      Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to   
      Mathematical Statistics" (John Wiley, NY, 1950).
      - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.    
      (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218. 
      - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System     
      Structure and Classification Rule for Recognition in Partially Exposed     
      Environments".  IEEE Transactions on Pattern Analysis and Machine    I
      ntelligence, Vol. PAMI-2, No. 1, 67-71.  
      - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions on Information Theory, May 1972, 431-433.   
      - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
      conceptual clustering system finds 3 classes in the data.
      - Many, many more ...
      ', 'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
# 获取特征数据对应的标签
iris_target = data.target
print(iris_target)

#  输出结果
"""
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
"""
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
# 利用pandas将数据转换为DataFrame格式的数据
iris_features = pd.DataFrame(data=data.data, columns=data.feature_names)
  • 1
  • 2
8.3 数据信息的简单查看
# 利用info()查看数据的整体信息
iris_features.info()
"""
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 4 columns):
sepal length (cm)    150 non-null float64
sepal width (cm)     150 non-null float64
petal length (cm)    150 non-null float64
petal width (cm)     150 non-null float64
dtypes: float64(4)
"""
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
# 查看数据,使用.head() .tail()方法,默认取5行数据
print(iris_features.head())
"""
   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
0                5.1               3.5                1.4               0.2
1                4.9               3.0                1.4               0.2
2                4.7               3.2                1.3               0.2
3                4.6               3.1                1.5               0.2
4                5.0               3.6                1.4               0.2
"""
print(iris_features.tail())
"""
     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
145                6.7               3.0                5.2               2.3
146                6.3               2.5                5.0               1.9
147                6.5               3.0                5.2               2.0
148                6.2               3.4                5.4               2.3
149                5.9               3.0                5.1               1.8
"""
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
#  利用value_counts查看每个类别的数量
print(pd.Series(iris_target).value_counts())
"""
2    50
1    50
0    50
dtype: int64
"""
#  对于特征进行一些简单的统计 统计的结果可以看到不同数数值特征的变化范围
print(iris_features.describe())
"""
        sepal length (cm)  sepal width (cm)  petal length (cm)   petal width (cm)  
count      150.000000        150.000000         150.000000      150.000000
mean         5.843333          3.054000           3.758667        1.198667
std          0.828066          0.433594           1.764420        0.763161
min          4.300000          2.000000           1.000000        0.100000
25%          5.100000          2.800000           1.600000        0.300000
50%          5.800000          3.000000           4.350000        1.300000
75%          6.400000          3.300000           5.100000        1.800000
max          7.900000          4.400000           6.900000        2.500000
"""
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
8.4 数据可视化描述
# 合并标签和特征信息
iris_all = iris_features.copy()  # 进行数据的浅拷贝
print(iris_all)
"""
取出结果前5行
     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
0                  5.1               3.5                1.4               0.2
1                  4.9               3.0                1.4               0.2
2                  4.7               3.2                1.3               0.2
3                  4.6               3.1                1.5               0.2
4                  5.0               3.6                1.4               0.2
"""
iris_all['target'] = iris_target


# 特征与标签组合的散点可视化
sns.pairplot(data=iris_all, diag_kind='hist', hue='target')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

在这里插入图片描述

从上图可以发现,在2D情况下不同的特征组合对于不同类别的花的散点分布,以及大概的区分能力。

for col in iris_features.columns:
  sns.boxplot(x ='target', y = col, saturtaion = 0.5, palette='pastel', data=iris_all)
  plt.title(col)
  plt.show()
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
利用箱型图我们也可以得到不同类别在不同特征上的分布差异情况。

# 选取其前三个特征绘制三维散点图
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(111, projection='3d')
iris_all_class0 = iris_all[iris_all['target']==0].values
iris_all_class1 = iris_all[iris_all['target']==1].values
iris_all_class2 = iris_all[iris_all['target']==2].values
# 'setosa'(0), 'versicolor'(1), 'virginica'(2)
ax.scatter(iris_all_class0[:,0], iris_all_class0[:,1], iris_all_class0[:,2],label='setosa')
ax.scatter(iris_all_class1[:,0], iris_all_class1[:,1], iris_all_class1[:,2],label='versicolor')
ax.scatter(iris_all_class2[:,0], iris_all_class2[:,1], 
iris_all_class2[:,2],label='virginica')
plt.legend()
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

在这里插入图片描述

8.5 利用逻辑回归模型在二分类上进行训练和预测
# 为了正确评估模型性能,将数据集划分为训练集和测试集,并在训练集上训练模型,在测试集上验证模型性能
from sklearn.model_selection import train_test_split

# 选择其类别为0和1的样本,即前100个样本(不包括类别为2的样本)
iris_features_part = iris_features.iloc[:100]
iris_target_part = iris_target[:100]

# 数据集的切分 训练集:测试集=0.8:0.2 random_state保证结果的可重复性
x_tain, x_test, y_train, y_test = train_test_split(iris_features_part, iris_target_part, test_size=0.2,
                                                   random_state=2020)
# 从sklearn库导入LR模型
from sklearn.linear_model import LogisticRegression as LR

# 定义逻辑回归模型
clf = LR(random_state=0, solver='lbfgs')

# 在训练集上训练逻辑回归模型
LR_model = clf.fit(x_tain, y_train)
print(LR_model)

"""
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=0, solver='lbfgs', tol=0.0001,
          verbose=0, warm_start=False)
"""

# 查看其对应的w
print('the weight of Logistic Regression:', clf.coef_)

"""
the weight of Logistic Regression: [[ 0.45244919 -0.81010583  2.14700385  0.90450733]]
"""

# 查看其对应的w0
print('the intercept(w0) of Logistic Regression:', clf.intercept_)

"""
the intercept(w0) of Logistic Regression: [-6.57504448]
"""
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
# 在训练集和测试集上使用训练好的模型进行预测
train_predict = clf.predict(x_train)
test_predict = clf.predict(x_test)

from sklearn import metrics

# 使用accuracy(准确度)[预测正确的样本数目占总预测样本数目的比例]来评估模型的效果
print('The accuracy of Logistic Regression on train_dataset is:', metrics.accuracy_score(y_train, train_predict))
print('The accuracy of Logistic Regression on test_dataset is:', metrics.accuracy_score(y_test, test_predict))

# 查看混淆矩阵(预测值和真实值的各类情况统计矩阵)
confusion_matrix_result = metrics.confusion_matrix(test_predict, y_test)
print('The confusion matrix result:\n', confusion_matrix_result)

# 利用热力图对于结果进行可视化
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
plt.xlabel('Predict labels')
plt.ylabel('True label')
plt.show()

# 输出结果
"""
The accuracy of Logistic Regression on train_dataset is: 1.0
The accuracy of Logistic Regression on test_dataset is: 1.0
The confusion matrix result:
 [[ 9  0]
 [ 0 11]]
"""
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

在这里插入图片描述

8.6 利用逻辑回归模型在三分类(多分类)上进行训练和预测
# 数据集的切分 训练集:测试集=0.8:0.2 random_state保证结果的可重复性
x_train, x_test, y_train, y_test = train_test_split(iris_features, iris_target, test_size=0.2, random_state=2020)

# 定义LR模型
lr_model = LR(random_state=0, solver='lbfgs')

# 在训练集上训练LR模型
lr_model = clf.fit(x_train, y_train)
print(lr_model)
# 查看其对应的w
print('the weight of Logistic Regression:', clf.coef_)

# 查看其对应的w0
print('the intercept(w0) of Logistic Regression:', clf.intercept_)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

输出结果:

LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=0, solver='lbfgs', tol=0.0001,
          verbose=0, warm_start=False)
# 由于这个是3分类,所有我们这里得到了三个逻辑回归模型的参数,其三个逻辑回归组合起来即可实现三分类。
the weight of Logistic Regression:
 [[-0.43538857  0.87888013 -2.19176678 -0.94642091]
 [-0.39434234 -2.6460985   0.76204684 -1.35386989]
 [-0.00806312  0.11304846  2.52974343  2.3509289 ]]
the intercept(w0) of Logistic Regression: [  6.30620875   8.25761672 -16.63629247]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
# 在训练集和测试集上使用训练好的模型进行预测
train_predict = clf.predict(x_train)
test_predict = clf.predict(x_test)

# 由于逻辑回归模型是概率预测模型,所有我们可以利用predict_proba 函数预测其概率
train_predict_proba = clf.predict_proba(x_train)
test_predict_proba = clf.predict_proba(x_test)

# 其中第一列代表预测为0类的概率,第二列代表预测为1类的概率,第三列代表预测为2类的概率。
print('The test predict Probability of each class:\n', test_predict_proba)

# 使用accuracy(准确度)[预测正确的样本数目占总预测样本数目的比例]来评估模型的效果
print('The accuracy of Logistic Regression on train_dataset is:', metrics.accuracy_score(y_train, train_predict))
print('The accuracy of Logistic Regression on test_dataset is:', metrics.accuracy_score(y_test, test_predict))

# 查看混淆矩阵(预测值和真实值的各类情况统计矩阵)
confusion_matrix_result = metrics.confusion_matrix(test_predict, y_test)
print('The confusion matrix result:\n', confusion_matrix_result)

# 利用热力图对于结果进行可视化
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
plt.xlabel('Predict labels')
plt.ylabel('True label')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

输出结果:

The test predict Probability of each class:
 [[  1.32525870e-04   2.41745142e-01   7.58122332e-01]
 [  7.02970475e-01   2.97026349e-01   3.17667822e-06]
 [  3.37367886e-02   7.25313901e-01   2.40949311e-01]
 [  5.66207138e-03   6.53245545e-01   3.41092383e-01]
 [  1.06817066e-02   6.72928600e-01   3.16389693e-01]
 [  8.98402870e-04   6.64470713e-01   3.34630884e-01]
 [  4.06382037e-04   3.86192249e-01   6.13401369e-01]
 [  1.26979439e-01   8.69440588e-01   3.57997319e-03]
 [  8.75544317e-01   1.24437252e-01   1.84312617e-05]
 [  9.11209514e-01   8.87814689e-02   9.01671605e-06]
 [  3.86067682e-04   3.06912689e-01   6.92701243e-01]
 [  6.23261939e-03   7.19220636e-01   2.74546745e-01]
 [  8.90760124e-01   1.09235653e-01   4.22292409e-06]
 [  2.32339490e-03   4.47236837e-01   5.50439768e-01]
 [  8.59945211e-04   4.22804376e-01   5.76335679e-01]
 [  9.24814068e-01   7.51814638e-02   4.46852786e-06]
 [  2.01307999e-02   9.35166320e-01   4.47028801e-02]
 [  1.71215635e-02   5.07246971e-01   4.75631465e-01]
 [  1.83964097e-04   3.17849048e-01   6.81966988e-01]
 [  5.69461042e-01   4.30536566e-01   2.39269631e-06]
 [  8.26025475e-01   1.73971556e-01   2.96936737e-06]
 [  3.05327704e-04   5.15880492e-01   4.83814180e-01]
 [  4.69978972e-03   2.90561777e-01   7.04738434e-01]
 [  8.61077168e-01   1.38915993e-01   6.83858427e-06]
 [  6.99887637e-04   2.48614010e-01   7.50686102e-01]
 [  5.33421842e-02   8.31557126e-01   1.15100690e-01]
 [  2.34973018e-02   3.54915328e-01   6.21587370e-01]
 [  1.63311193e-03   3.48301765e-01   6.50065123e-01]
 [  7.72156866e-01   2.27838662e-01   4.47157219e-06]
 [  9.30816593e-01   6.91640361e-02   1.93708074e-05]]
The accuracy of Logistic Regression on train_dataset is: 0.958333333333
The accuracy of Logistic Regression on test_dataset is: 0.8
The confusion matrix result:
 [[10  0  0]
 [ 0  7  3]
 [ 0  3  7]]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

在这里插入图片描述

9. 逻辑回归的特征离散化

逻辑回归对特征进行离散化的好处:
  • 非线性!非线性!非线性!逻辑回归属于广义线性模型,表达能力受限;单变量离散化为N个后,每个变量有单独的权重,相当于为模型引入了非线性,能够提升模型表达能力,加大拟合; 离散特征的增加和减少都很容易,易于模型的快速迭代;
  • 速度快!速度快!速度快!稀疏向量内积乘法运算速度快,计算结果方便存储,容易扩展;
  • 鲁棒性!鲁棒性!鲁棒性!离散化后的特征对异常数据有很强的鲁棒性:比如一个特征是年龄>30是1,否则0。如果特征没有离散化,一个异常数据“年龄300岁”会给模型造成很大的干扰;
  • 方便交叉与特征组合:离散化后可以进行特征交叉,由M+N个变量变为M*N个变量,进一步引入非线性,提升表达能力;
  • 稳定性:特征离散化后,模型会更稳定,比如如果对用户年龄离散化,20-30作为一个区间,不会因为一个用户年龄长了一岁就变成一个完全不同的人。当然处于区间相邻处的样本会刚好相反,所以怎么划分区间是门学问;
  • 简化模型:特征离散化以后,起到了简化了逻辑回归模型的作用,降低了模型过拟合的风险。

10. 逻辑回归总结

1)逻辑回归即为数据服从伯努利分布,通过极大似然函数的方法,运用梯度下降法求解参数,来达到二分类的目的。

2)逻辑回归是一个分类模型,解决分类问题(类别+概率),可以做ranking model。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/199135?site
推荐阅读
相关标签
  

闽ICP备14008679号