当前位置:   article > 正文

一文读懂机器学习分类全流程_平衡数据集

平衡数据集

608ab26ec97840f2a4fa8dd46da22706.png

目录

 

前言

提出问题

一、介绍

1.分类简介

2.imblearn的安装

二、数据加载及预处理

1.加载并查看数据

①导入Python第三方库   

②调用并查看数据

2.查看数据分布

①各国样本分布直方图

②各国样本划分

3.各国最受欢迎食材可视化

4.平衡数据集

①样本插值采样

三、分类器选择

四、逻辑回归模型构建

1.数据导入及查看

2.可训练特征与预测标签选择及数据集划分

3.构建逻辑回归模型及精度评价

4.模型测试

五、更多分类模型

1.第三方库导入

2.测试不同分类器

①线性SVC分类器

②K-近邻分类器

③SVM分类器

④集成分类器

六、模型发布为Web应用

1.模型打包

2.配置Flask应用

②app.py

3.应用运行及测试

结论


 

前言

        在本文中,你将学到:

0 复习数据预处理及可视化

1 了解分类的基本概念

2 使用多种分类器来对比模型精度

3 掌握使用分类器列表的方式来批处理不同模型

4 将机器学习分类模型部署为Web应用

提出问题

        本文我们所用的数据集是亚洲美食数据集,其包括了亚洲5个国家的美食食谱与所属的国家。我们构建模型的目的是解决:

如何根据美食所用食材判断其所属国家

        即以美食食材为可训练特征,所属国家为预测标签构建机器学习分类模型。

 

一、介绍

1.分类简介

        分类是经典机器学习的基本重点,也是监督学习的一种形式,与回归技术有很多共同之处。其通常分为两类:二元分类和多元分类。本文中,我将使用亚洲美食数据集贯穿本次学习。

还记得我在之前文章中提到的:

0 线性回归可帮助我们预测变量之间的关系,并准确预测新数据点相对于该线的位置。因此,例如,预测南瓜在9月与12月的价格。

1 Logistic回归帮助我们发现“二元类别”:在这个价格点上,这是橙子还是非橙子?

        分类也是机器学习人员和数据科学家的基本工作之一。从二分类(判断邮件是否是垃圾邮件),到使用计算机视觉的复杂分类和分割,其在很多领域都有着很大的作用。

        以更科学的方式陈述该过程:

       我们所使用的分类方法创建了一个预测模型,这个模型使我们能够将输入变量之间的关系映射到输出变量。

670725bb7b904a1badb0eb98183251f8.png

        分类使用各种算法来确定数据点的标签或类别。我们以亚洲美食数据集为例,看看通过输入一组特征样本,我们是否可以确定其菜肴所属的国家。

        以下是经典机器学习常用的分类方法。

0 逻辑回归

1 决策树分类(ID3、C4.5、CART)

2 基于规则分类

3 K-近邻算法(K-NN)

4 贝叶斯分类

5 支持向量机(SVM)

6 随机森林(Random Forest)

2.imblearn的安装

        在开始本文学习之前,我们第一个任务是清理和调整数据集以获得更好的分析结果。我们需要安装的是imblearn库。这是一个基于Python的Scikit-learn软件包,它可以让我们更好地平衡数据。

        在命令行中输入以下代码,使用阿里的镜像来安装 imblearn

pip install -i https://mirrors.aliyun.com/pypi/simple/ imblearn

        看到如下图所示代表安装成功。

9f8e99e936714211aba6d1b4194972e9.png

 

二、数据加载及预处理

1.加载并查看数据

①导入Python第三方库   

  1. import pandas as pd
  2. import matplotlib.pyplot as plt
  3. import matplotlib as mpl
  4. import numpy as np
  5. from imblearn.over_sampling import SMOTE

②调用并查看数据

  1. df = pd.read_csv('cuisines.csv')
  2. df.head()

        前5行数据如下图所示:

4db9405c915d4e618a7335e356fb9579.png

        查看数据结构:

df.info()

aad8d10e36654b628a69c1a7ab67602b.png

        通过查看数据与其组织结构,我们可以发现数据有2448行,385列,其中有大量的无效数据。

2.查看数据分布

        我们可以对数据进行可视化来发现数据集中的数据分布。

①各国样本分布直方图

        通过调用barh()函数将数据绘制为柱状图。

df.cuisine.value_counts().plot.barh() #根据不同国家对数据集进行划分

         结果如下:

f5a7565988fc451bb171128b5e1800e0.png

         我们可以看到,数据集中以韩国料理样本最多,泰国样本最少。美食的数量有限,但数据的分布是不均匀的。我们可以解决这个问题!在此之前,请进一步探索。

②各国样本划分

        了解各国美食有多少可用数据并将其打印输出。输入以下代码,从结果中我们可以看到不同国家美食的可用数据:

  1. thai_df = df[(df.cuisine == "thai")]#提取泰国美食
  2. japanese_df = df[(df.cuisine == "japanese")]#提取日本美食
  3. chinese_df = df[(df.cuisine == "chinese")]#提取中国美食
  4. indian_df = df[(df.cuisine == "indian")]#提取印度美食
  5. korean_df = df[(df.cuisine == "korean")]#提取韩国美食
  6. print(f'thai df: {thai_df.shape}')#输出数据结构
  7. print(f'japanese df: {japanese_df.shape}')#输出数据结构
  8. print(f'chinese df: {chinese_df.shape}')#输出数据结构
  9. print(f'indian df: {indian_df.shape}')#输出数据结构
  10. print(f'korean df: {korean_df.shape}')#输出数据结构

37a7a5b8f2d54363a3e1fdc109ba7543.png

3.各国最受欢迎食材可视化

        现在,我们可以更深入的挖掘数据,并了解每种菜肴的成分。在此之前,我们应该对数据进行预处理,删去重复值。

        在python中创建一个函数来删除无用的列,然后按成分数量进行排序。

  1. def create_ingredient_df(df):
  2. ingredient_df = df.T.drop(['cuisine','Unnamed: 0']).sum(axis=1).to_frame('value')#从原始数据中删除无效列并统计axis=1的总和赋值给value列
  3. ingredient_df = ingredient_df[(ingredient_df.T != 0).any()]#提取所有非0值的样本
  4. ingredient_df = ingredient_df.sort_values(by='value', ascending=False,
  5. inplace=False)#按照数值的大小进行排序
  6. return ingredient_df#返回处理并排序后的结果

        现在,我们调用create_ingredient_df()函数来了解泰国美食中,最受欢迎的十大食材。输入以下代码:

  1. thai_ingredient_df = create_ingredient_df(thai_df)
  2. thai_ingredient_df.head(10).plot.barh()

        查看泰国美食中十大最受欢迎的食材:

016773948d374a64b73075572a82738c.png

        我们可以看到,第一名的食材是garlic(大蒜),第十名则是chicken(鸡肉)。

        对中国数据执行同样的操作:

  1. chinese_ingredient_df = create_ingredient_df(chinese_df)
  2. chinese_ingredient_df.head(10).plot.barh()

        结果如下,中国美食食材中,最受欢迎的是soy_sauce(酱油),第十名是cayenne(红辣椒):

4d606c93c827479787918d483fd9aca9.png

        同理,我们也可以输出其他国家的美食食材,这里就不水文了。大家可自行尝试。

        现在,我们需要使用drop()函数删除不同美食间造成混淆的最常见成分以突出各国食材的特色,每个国家的人都喜欢米饭、大蒜和生姜。(可自行可视化查看),我们输入以下代码将其删去。

  1. feature_df= df.drop(['cuisine','Unnamed: 0','rice','garlic','ginger'], axis=1)#删去最常见的这几列以平衡不同国家之间的混淆
  2. labels_df = df.cuisine

4.平衡数据集

        现在我们已经清理了数据,因为不同国家的样本数量差异较大,我们需要使用SMOTE(“合成少数过度采样技术”)来平衡它。

SMOTE介绍

①样本插值采样

        调用SMOTE对象的 fit_resample() 函数来插值重采样生成新样本。输入以下代码:

  1. oversample = SMOTE()
  2. transformed_feature_df, transformed_label_df = oversample.fit_resample(feature_df, labels_df)
  3. print(f'new label count: {transformed_label_df.value_counts()}')
  4. print(f'old label count: {df.cuisine.value_counts()}')

        查看新样本与旧样本的数据量差异:

53fd56e6e4514e8db52bce861f874eeb.png

        我们可以看到,新样本不同类别标签下的样本数量都在799,它是以旧样本中最大样本量标签为基础构建的,而旧样本则参差不齐,分布不均匀。

        数据质量很好,干净、平衡!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/429399
推荐阅读
相关标签