赞
踩
参考:
主要参考:课时5 简单回归问题-2_哔哩哔哩_bilibili。
本文系统地回顾一下 PyTorch。
目录
1: 简单回归问题
2: 回归问题实战
一 简单回归问题(Linear Regression)
根据预测值(即标签值)的类型不同,回归问题可以分为不同的类别;本节讨论标签为连续实数的情形:
线性回归问题
损失函数
参数学习:
梯度下降原理(泰勒公式展开)
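设参数为 $\theta$,学习率为 $\eta > 0$。对损失函数 $L$ 在 $\theta$ 处做一阶泰勒展开:$L(\theta + \Delta\theta) \approx L(\theta) + \nabla L(\theta)^{T} \Delta\theta$。取 $\Delta\theta = -\eta \nabla L(\theta)$,则 $L(\theta + \Delta\theta) \approx L(\theta) - \eta \lVert \nabla L(\theta) \rVert^{2} \le L(\theta)$,即沿负梯度方向更新参数可以使损失下降。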
二 回归问题实战
数据集
模型
参数学习
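设预测值 $\hat{y}_i = w^{T} x_i + b$,损失函数 $L(w, b) = \frac{1}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i)^2$
梯度:$\frac{\partial L}{\partial w} = \frac{2}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i)\, x_i$,$\frac{\partial L}{\partial b} = \frac{2}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i)$
参数更新:$w \leftarrow w - \eta \frac{\partial L}{\partial w}$,$b \leftarrow b - \eta \frac{\partial L}{\partial b}$,其中 $\eta$ 为学习率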
三 例
3.1 训练部分
- # -*- coding: utf-8 -*-
- """
- Created on Thu Nov 10 21:33:37 2022
- @author: cxf
- """
-
- import torch
- import numpy as np
- from torch.utils.data import Dataset, DataLoader
- from draw import draw_loss
# Custom datasets must subclass torch.utils.data.Dataset.
class MyDataset(Dataset):
    """Pair a feature container with a target container for a DataLoader.

    Both containers are indexed with the same integer, so each is expected
    to hold one entry per sample along its first axis.
    """

    def __init__(self, data, target):
        """Keep references to the inputs and cache the sample count."""
        self.x, self.y = data, target
        # Number of samples, read from the first axis of the features.
        self.len = data.shape[0]

    def __len__(self):
        """Return the number of samples."""
        return self.len

    def __getitem__(self, index):
        """Return the ``(feature, target)`` pair stored at ``index``."""
        return self.x[index], self.y[index]
-
# Linear regression trained with mini-batch gradient descent.
class LR:
    """Linear model ``y = w^T x + b`` fitted by mini-batch gradient descent.

    Attributes:
        w: learned weights; ``0`` until ``train`` has run, afterwards the
            (n, 1) tensor produced by the last update.
        b: learned bias; ``0`` until ``train`` has run, afterwards the
            (1, 1) tensor produced by the last update.
        m: number of samples in the most recently loaded/trained data.
        n: number of features per sample.
        batch: mini-batch size requested for training.
        maxIter: number of passes over the data made by ``train``.
        learnRate: gradient-descent step size.
    """

    def __init__(self):
        self.w = 0  # weights
        self.b = 0  # bias
        self.m = 0  # number of samples
        self.n = 0  # feature dimension
        self.batch = 20  # samples per mini-batch
        self.maxIter = 1000  # number of epochs
        self.learnRate = 0.01  # learning rate

    def predict(self, w, b, x):
        """Return the model output ``w^T x + b``.

        Args:
            w: weight tensor of shape (n, 1).
            b: bias; broadcast against the matrix product.
            x: inputs of shape (n, k), one sample per column.

        Returns:
            Tensor of shape (1, k) holding one prediction per column of x.
        """
        return torch.mm(w.T, x) + b

    def step_gradient(self, w_cur, b_cur, trainX, trainY):
        """Apply one gradient-descent step of the mean squared error.

        Args:
            w_cur: current weights, shape (n, 1).
            b_cur: current bias.
            trainX: batch of samples, shape (m, n), one sample per row.
            trainY: batch of labels, m values.

        Returns:
            ``(new_w, new_b)``. An empty batch returns the inputs unchanged.
        """
        m = trainX.shape[0]
        if m == 0:
            # Nothing to learn from; also avoids dividing by zero below.
            return w_cur, b_cur

        # Transposing puts one sample per column, so a single matrix
        # product scores the whole batch instead of a per-sample loop.
        residual = self.predict(w_cur, b_cur, trainX.T) - trainY.reshape(1, -1)

        # d/dw and d/db of (1/m) * sum(residual ** 2).
        scale = 2.0 / m
        w_gradient = scale * torch.mm(trainX.T, residual.T)
        b_gradient = scale * residual.sum()

        new_b = b_cur - self.learnRate * b_gradient
        new_w = w_cur - self.learnRate * w_gradient
        return new_w, new_b

    def train(self, dataX, dataY):
        """Fit the model and return the loss recorded after every epoch.

        The final parameters are stored on ``self.w`` and ``self.b``.

        Args:
            dataX: samples, shape (m, n).
            dataY: labels, m values.

        Returns:
            List with one total squared error (float) per epoch.

        Raises:
            ValueError: if ``dataX`` holds no samples.
        """
        # Take the dimensions from the data actually passed in, so that
        # training does not depend on loadData having been called first.
        self.m, self.n = dataX.shape[0], dataX.shape[1]
        if self.m == 0:
            raise ValueError("cannot train on an empty dataset")

        b_cur = torch.zeros([1, 1], dtype=torch.float)
        w_cur = torch.rand((self.n, 1), dtype=torch.float)
        train_loader = DataLoader(
            dataset=MyDataset(dataX, dataY),
            # With drop_last=True a batch larger than the dataset would
            # yield no batches at all, so cap it at the dataset size.
            batch_size=min(self.batch, self.m),
            shuffle=True,
            drop_last=True,
        )

        y_train_loss = []
        for _ in range(self.maxIter):
            for batch_x, batch_y in train_loader:
                w_cur, b_cur = self.step_gradient(w_cur, b_cur, batch_x, batch_y)
            y_train_loss.append(self.compute_error(w_cur, b_cur, dataX, dataY))

        # Keep the fitted parameters; otherwise the trained model is lost.
        self.w, self.b = w_cur, b_cur
        return y_train_loss

    def compute_error(self, w, b, dataX, dataY):
        """Return the total squared error of the model over a dataset.

        Args:
            w: weights, shape (n, 1).
            b: bias.
            dataX: samples, shape (m, n).
            dataY: labels, m values.

        Returns:
            Sum of squared residuals as a Python float (0.0 for no samples).
        """
        residual = self.predict(w, b, dataX.T) - dataY.reshape(1, -1)
        return float((residual ** 2).sum())

    def loadData(self, path="data.csv"):
        """Load samples from a comma-separated file.

        Every row is one sample: all columns but the last are features and
        the last column is the label. ``self.m`` and ``self.n`` are updated.

        Args:
            path: CSV file to read.

        Returns:
            ``(x, y)`` float tensors with the features and the labels.
        """
        # genfromtxt returns a 1-D array for a single-row file; promote it
        # so the column slicing below works for that case too.
        data = np.atleast_2d(np.genfromtxt(path, delimiter=","))

        x = torch.tensor(data[:, :-1], dtype=torch.float)
        y = torch.tensor(data[:, -1], dtype=torch.float)
        self.m, self.n = x.shape[0], x.shape[1]

        print("\n m ", self.m, "\t n", self.n)
        return x, y
-
-
-
if __name__ == "__main__":
    # Load the CSV samples, fit the model, then plot the recorded losses.
    model = LR()
    features, labels = model.loadData()
    draw_loss(model.train(features, labels))
-
3.2 绘图部分
- # -*- coding: utf-8 -*-
- """
- Created on Mon Nov 14 20:14:28 2022
- @author: cxf
- """
-
- import numpy as np
- import matplotlib.pyplot as plt
-
-
-
-
-
def draw_loss(y_train_loss):
    """Plot a sequence of training-loss values as a line chart and show it.

    Args:
        y_train_loss: sequence of loss values; each value is drawn at its
            position in the sequence.
    """
    fig, ax = plt.subplots()

    # Hide the top and right borders of the plot area.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Axis labels. The curve is the loss, so the y axis must say so
    # (it was previously mislabelled as accuracy).
    ax.set_xlabel('iters')
    ax.set_ylabel('loss')

    ax.plot(
        range(len(y_train_loss)),
        y_train_loss,
        linewidth=1,
        linestyle="solid",
        label="train loss",
    )
    ax.legend()
    ax.set_title('train loss')
    plt.show()
3.3 数据部分
- # -*- coding: utf-8 -*-
- """
- Created on Fri Nov 11 22:17:07 2022
- @author: cxf
- """
-
- import numpy as np
- import csv
-
-
def makeData(num_samples=200):
    """Generate synthetic linear-regression samples.

    Features are drawn uniformly from [0, 1) and the label of each sample
    is ``1.0 * x0 + 2.0 * x1 + 2.0 * x2 + 0.5``.

    Args:
        num_samples: number of rows to generate.

    Returns:
        List of rows ``[x0, x1, x2, y]`` holding plain Python floats.
    """
    wT = np.array([[1.0, 2.0, 2.0]])
    b = 0.5

    features = np.random.random((num_samples, wT.shape[1]))
    # One matrix product labels every sample at once; the result has
    # shape (num_samples, 1) so it can be appended as the last column.
    labels = features @ wT.T + b

    # tolist() converts NumPy scalars to built-in floats, so the rows
    # serialise predictably when written out.
    return np.hstack([features, labels]).tolist()
-
-
def save(data, path="data.csv"):
    """Write rows to a CSV file, one row per line.

    Args:
        data: iterable of rows, each row an iterable of values.
        path: destination file; it is overwritten if it already exists.
    """
    # The context manager closes the file even if writing a row fails.
    # newline='' lets the csv module control the line endings itself.
    with open(path, 'w', newline='') as csvFile:
        csv.writer(csvFile).writerows(data)
-
-
-
if __name__ == "__main__":
    # Generate the synthetic samples and write them to the CSV file.
    # Nothing runs at import time: data is only produced when this file
    # is executed as a script.
    data = makeData()
    save(data)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。