当前位置:   article > 正文

利用逻辑回归进行鸢尾花分类_鸢尾花逻辑回归

鸢尾花逻辑回归

利用逻辑回归进行鸢尾花分类

  1. 数据集处理
  2. 数据可视化
  3. 模型训练

首先导入我们所需要用到的库

import numpy as np 
import pandas as pd

## 绘图函数库
import matplotlib.pyplot as plt
import seaborn as sns
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

1.数据集导入

要训练模型,首先要处理数据集,我们使用的数据集是sklearn中的鸢尾花数据集,该数据集一共有四个特征变量,一个目标分类变量,共有150个样本。

变量描述
sepal length花萼长度(cm)
sepal width花萼宽度(cm)
petal length花瓣长度(cm)
petal width花瓣宽度(cm)
target鸢尾的三个亚属类别,‘setosa’(0), ‘versicolor’(1), ‘virginica’(2)
#从sklearn中导入鸢尾花数据集
from sklearn.datasets import load_iris

data = load_iris()
iris_label = data.target
iris_features = pd.DataFrame(data=data.data, columns=data.feature_names)

#查看数据大致情况
iris_features.info()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

请添加图片描述
当我们拿到数据集后首先要看这个数据集的大致情况,都有什么特征,以及有没有缺失值,如果有缺失值就要对缺失值进行处理。从上面我们可以看到,这个数据集并没有缺失值,因此不需要进行缺失值处理。

查看数据集的前五行

iris_features.head()
  • 1

在这里插入图片描述
查看标签,0,1,2分布代表不同类别的鸢尾花

## 其对应的类别标签为,其中0,1,2分别代表'setosa', 'versicolor', 'virginica'三种不同花的类别。
iris_label
  • 1
  • 2

在这里插入图片描述

2.数据可视化

在训练数据之前,我们先在图中查看一下该数据集

在这里为了避免我们在可视化的过程中不小心修改到原数据,这里我们对原数据进行一下拷贝

iris_demo = iris_features.copy()
iris_demo['target'] = iris_label
#查看一下用于可视化的数据集
iris_demo.head()
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

## 特征与标签组合的散点可视化
sns.pairplot(data=iris_demo, diag_kind='hist', hue = 'target')
plt.show()
  • 1
  • 2
  • 3

箱型图绘制

for col in iris_features.columns:
    sns.boxplot(x = 'target', y = col, saturation=0.5, palette='pastel', data=iris_demo)
    plt.title(col)
    plt.show()
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述
三维散点图绘制

#三维散点图绘制
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(111, projection='3d')

iris_demo_label1 = iris_demo[iris_demo['target']==0].values
iris_demo_label2 = iris_demo[iris_demo['target']==1].values
iris_demo_label3 = iris_demo[iris_demo['target']==2].values

ax.scatter(iris_demo_label1[:,0], iris_demo_label1[:,1], iris_demo_label1[:,2],label='setosa')
ax.scatter(iris_demo_label2[:,0], iris_demo_label2[:,1], iris_demo_label2[:,2],label='versicolor')
ax.scatter(iris_demo_label3[:,0], iris_demo_label3[:,1], iris_demo_label3[:,2],label='virginica')
plt.legend()

plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

在这里插入图片描述

3.利用逻辑回归模型,在二分类任务上进行训练和预测

from sklearn.model_selection import train_test_split

#选择类别为0和1的样本
iris_features_part = iris_features.iloc[:100]
iris_label_part = pd.Series(iris_label).iloc[:100].values

#划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(iris_features_part,iris_label_part,test_size = 0.2,
                                                    random_state=2020)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

这里我们做的是二分类,因此我们在样本数据集中选择两种鸢尾花的数据进行训练

from sklearn.linear_model import LogisticRegression

clf = LogisticRegression(random_state=0, solver='lbfgs')
  • 1
  • 2
  • 3
clf.fit(x_train,y_train)
  • 1
## 查看其对应的w
print('the weight of Logistic Regression:',clf.coef_)

## 查看其对应的w0
print('the intercept(w0) of Logistic Regression:',clf.intercept_)
  • 1
  • 2
  • 3
  • 4
  • 5

在这里插入图片描述
在训练集和测试集上分别利用训练好的模型进行预测

train_predict = clf.predict(x_train)
print(train_predict,y_train)
  • 1
  • 2

在这里插入图片描述

test_predict = clf.predict(x_test)
print(test_predict,y_test)
  • 1
  • 2

在这里插入图片描述
从上面的结果我们可以看到,无论是在训练集还是测试集,我们的预测值和实际值是一样的。接下来我们用利用accuracy(准确度)评估模型效果

精确度指的是 预测正确的样本数目占总预测样本数目的比例

accuracy也有缺点,就是在样本不平衡的情况下并不能作为很好的指标来衡量结果

from sklearn import metrics

## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict))
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))
  • 1
  • 2
  • 3
  • 4
  • 5

在这里插入图片描述
这里accuracy的值是1,我们在上面测试的结果是一致的。

## 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵)
confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
print('The confusion matrix result:\n',confusion_matrix_result)
  • 1
  • 2
  • 3

在这里插入图片描述

# 利用热力图对于结果进行可视化
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/337688
推荐阅读
相关标签
  

闽ICP备14008679号