当前位置:   article > 正文

AI模型实战_使用raw文件训练ai模型

使用raw文件训练ai模型

背景

通过示例梳理AI模型训练流程,示例比较简单,方便演示。

流程

机器学习实战步骤

  • 定义问题
  • 数据收集和预处理
  • 选择算法并建立模型
  • 训练模型
  • 模型评估和优化

示例

定义问题

根据公开数据集预测加州房价分布

数据收集

import pandas as pd #导入Pandas,用于数据读取和处理
# 读入房价数据,示例代码中的文件地址为internet链接,读者也可以下载该文件到本机进行读取
# 如,当数据集和代码文件位于相同本地目录,路径名应为"./house.csv",或直接放"house.csv"亦可
df_housing = pd.read_csv("https://raw.githubusercontent.com/huangjia2019/house/master/house.csv") 
df_housing.head #显示加州房价数据
  • 1
  • 2
  • 3
  • 4
  • 5

运行

<bound method NDFrame.head of        longitude  latitude  housing_median_age  total_rooms  total_bedrooms  \
0        -114.31     34.19                15.0       5612.0          1283.0   
1        -114.47     34.40                19.0       7650.0          1901.0   
2        -114.56     33.69                17.0        720.0           174.0   
3        -114.57     33.64                14.0       1501.0           337.0   
4        -114.57     33.57                20.0       1454.0           326.0   
...          ...       ...                 ...          ...             ...   
16995    -124.26     40.58                52.0       2217.0           394.0   
16996    -124.27     40.69                36.0       2349.0           528.0   
16997    -124.30     41.84                17.0       2677.0           531.0   
16998    -124.30     41.80                19.0       2672.0           552.0   
16999    -124.35     40.54                52.0       1820.0           300.0   
       population  households  median_income  median_house_value  
0          1015.0       472.0         1.4936             66900.0  
1          1129.0       463.0         1.8200             80100.0  
2           333.0       117.0         1.6509             85700.0  
3           515.0       226.0         3.1917             73400.0  
4           624.0       262.0         1.9250             65500.0  
...           ...         ...            ...                 ...  
16995       907.0       369.0         2.3571            111400.0  
16996      1194.0       465.0         2.5179             79000.0  
16997      1244.0       456.0         3.0313            103600.0  
16998      1298.0       478.0         1.9797             85800.0  
16999       806.0       270.0         3.0147             94600.0  
[17000 rows x 9 columns]>
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

构建特征

X = df_housing.drop("median_house_value",axis = 1) #构建特征集X
y = df_housing.median_house_value #构建标签集y
  • 1
  • 2

选择算法并建立模型

from sklearn.model_selection import train_test_split #导入数据集拆分工具
X_train, X_test, y_train, y_test = train_test_split(X, y, 
         test_size=0.2, random_state=0) #以80%/20%的比例进行数据集的拆分
  • 1
  • 2
  • 3

训练模型

线性回归算法是最简单、最基础的机器学习算法,它其实就是给每一个特征变量找参数的过程。
fit 的核心就是减少损失,使函数对特征到标签的模拟越来越贴切
通过梯度下降,逐步优化模型的参数,使训练集误差值达到最小

from sklearn.linear_model import LinearRegression #导入线性回归算法模型
model = LinearRegression() #使用线性回归算法
model.fit(X_train, y_train) #用训练集数据,训练机器,拟合函数,确定参数

y_pred = model.predict(X_test) #预测测试集的Y值
print ('房价的真值(测试集)',y_test)
print ('预测的房价(测试集)',y_pred)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

运行

房价的真值(测试集) 3873     171400.0
3625     189600.0
3028     500001.0
13814    229400.0
15398    163400.0
           ...   
1363     212500.0
7947     210500.0
14574    142900.0
10009    128300.0
9149      84700.0
Name: median_house_value, Length: 3400, dtype: float64
预测的房价(测试集) [211157.06335417 218581.64298574 465317.31295563 ... 201751.23969631
 160873.51846958 138847.26913352]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

模型评估

print("给预测评分:", model.score(X_test, y_test)) #评估预测结果
  • 1

运行

给预测评分: 0.632101417157948
  • 1

显示

import matplotlib.pyplot as plt #导入matplotlib画图库
#用散点图显示家庭收入中位数和房价中位数的分布
plt.scatter(X_test.median_income, y_test,  color='brown')
#画出回归函数(从特征到预测标签)
plt.plot(X_test.median_income, y_pred, color='green', linewidth=1)
plt.xlabel('Median Income') #X轴-家庭收入中位数
plt.ylabel('Median House Value') #Y轴-房价中位数
plt.show() #显示房价分布和机器习得的函数图形
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

运行

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/315059
推荐阅读
相关标签
  

闽ICP备14008679号