当前位置:   article > 正文

Pythorch笔记(1)读取数据集,建立模型,训练模型_创建矩阵模型 训练 保存 读取

创建矩阵模型 训练 保存 读取

导入库函数

import torch//导入pythorch库

from torch import nn//快速创建复杂模型

import pandas as pd //Pandas库可以用来进行数据清洗和格式转换、数据分析和统计、数据可视化、数据读取和存储、数据合并和拼接等等,即一个数据分析工具

import numpy as np //将numpy视作np,为了简便,以np代表numpy。该库可用来存储和处理大型矩阵

import matplotlib.pyplot as plt //该库是绘图库,也提供了一些用于图像处理的函数和工具,例如读取图像文件、调整图像大小、旋转图像、改变图像颜色等

读取数据

data=pd.read_数据集文件格式(‘数据集路径’)
例如,data=pd.read_csv(‘dataset/income1.csv’)//文件格式为csv,文件路径为dataset/income1.csv
plt.scatter(x,y) //绘制散点图。x代表x轴,y代表y轴,
plt.xlabel(‘A’) //将X轴的名称写为A,同理得y轴
** plt.show()** //将散点图可视化

实际操作

在这里插入图片描述
需要注意几个点,
1.在anaconda导入pandas,numpy等库;
2.read_excel()这个函数,首先需要在你对应的文件夹里设置好文件目录。其次,如果pandas版本过低,大概率只能用excel或者csv这两个格式,本文这里我是用的wps的表格,但是还是可以用excel格式读出来。

建立模型

创建模型分为以下几步:
1.预处理数据:需要把数据reshape成pythorch能够识别的数据,原来的
X=data.edu
包含的是一个数据,但是我们需要多个数据分开,因此变为
X=data.edu.values.reshape(-1,1).astype(np.float32)
//astype()转换数据类型,目的是希望我们的数据的类型都是统一的,这里将数据类型都转换为32位float类型;reshape(-1,1)的意思是-1代表自动计算,1是数据长度
在这里插入图片描述
在这里插入图片描述
2.将数据变成Tensor流(pytorch能识别):torch.from_numpy() //from numpy to torch

在这里插入图片描述
3.创建模型(本文指代线性模型)model=nn.Linear(输入数据长度,输出数据长度)
例如model=nn.Linear(1,1)//创建了一个输入输出长度都为1的线性模型,Linear不是linear
// output=w*input+b等价于model(input)
4.设置损失函数loss_fn=nn.MSELoss() //得到损失函数,MSELoss是均方损失,和上面的Linear都是nn里自带的算法
5. 进行梯度下降opt=torch.optim.SGD(model.parameters(),lr=0.0001) //nn里的随机优化算法;lr是学习率

训练模型

for epoch in range(5000): //把所有数据遍历一遍叫epoch,这里是训练5000次epoch
for x,y in zip(X,Y): //对输入输出数据进行迭代
y_pred=model(x) //我们预测的y值
loss=loss_fn(y,y_pred) //计算预测值与真实值的均方误差
opt.zero_grad() //清零梯度,不清零会导致梯度累加,则永远也找不到极值点
loss.backward() //反向传播函数nn里自带的,用来求解梯度,即参数应该怎么变
opt.step() //模型优化,寻找最佳参数
model.weight //显示最优参数w
model.bias //显示最优参数b

模型结果

plt.scatter(data.edu,data.income)
plt.plot(X.numpy(),model(X).data.numpy(),c=‘r’)

//scatter是散点图,plot是折线图。因为我们一开始把数据转成了pytorch识别的Tensor,现在将Tensor数据转换成numpy类型的,目的是能够让numpy识别之后进行绘图,model(X)得到的只是模型的式子,并不是数据,因此需要.data操作,c=‘r’代表把折线的颜色变红

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/天景科技苑/article/detail/751192
推荐阅读
相关标签
  

闽ICP备14008679号