赞
踩
1、代码实现 # 计算赤峰面积和房价之间的关系 import numpy as np import matplotlib.pyplot as plt # 构建数据集 data = [] # (70-90,90-100,100-110,110-130) for i in range(300): # 面积(训练集) area = np.random.uniform(60, 100) # 房价 eps2 = np.random.uniform(60, 62) # bias eps3 = np.random.uniform(200., 700.) # 总房价(标签) price = eps2 * area + eps3 # 随机生成一个线性方程,大小为(500,1) data.append([area, price]) data = np.array(data) # 数据集创建完毕 2维数组 [面积,房价] # 将参数分出来方便之后的使用 area = data[:, 0] price = data[:, 1] # 绘制原始数据 plt.title("Area-Price") # 标题名 plt.scatter(area, price, s=10) # 设置为散点图 plt.xlabel("area") # x轴的标题 plt.ylabel("price") # y轴的标题 plt.show() # 绘制出来 # 创建一个loss值的list loss_list = [] def mse(b, w, data): # 根据当前的 w,b,参数计算均方差损失 TotalError = 0 # 记录总误差 for i in range(0, len(data)): x = data[i, 0] y = data[i, 1] TotalError += (y - (w * x + b)) ** 2 return TotalError / float(len(data)) def gradient_update(b, w, data, lr): b_gradient = 0 w_gradient = 0 size = float(len(data)) for i in range(0, len(data)): x = data[i, 0] y = data[i, 1] # 计算梯度 b_gradient += (2 / size) * ((w * x + b) - y) w_gradient += (2 / size) * x * ((w * x + b) - y) # 根据梯度更新权重和偏置 b -= lr * b_gradient w -= lr * w_gradient return [b, w] # 梯度下降法 def gradient_descent(data, b, w, lr, num_iterations): # 因为没有batch,所以num_iterations即为epoch for num in range(num_iterations): # 更新参数 b, w = gradient_update(b, w, data, lr) # 计算损失值并添加到损失列表 loss = mse(b, w, data) loss_list.append(loss) print('iteration:[%s] | loss:[%s] | w:[%s] | b:[%s]' % (num, loss, w, b)) return [b, w] def main(): lr = 0.00001 initial_b = np.random.randn(1) initial_w = np.random.randn(1) num_iterations = 100 # 因为没有batch,所以num_iterations即为epoch [b, w] = gradient_descent(data, initial_b, initial_w, lr, num_iterations) loss = mse(b, w, data) print('Final loss:[%s] | w:[%s] | b:[%s]' % (loss, w, b)) # 损失函数 plt.title("Loss Function") # 标题名 plt.plot(np.arange(0, 100), loss_list) plt.xlabel('Interation') plt.ylabel('Loss Value') plt.show() # 绘制 y2 = w * area + b print(w * 100 + b) plt.title("Fit the line graph") # 标题名 plt.scatter(area, price, label='Original Data', s=10) # 设置为散点图 plt.plot(area, y2, color='Red', label='Fitting Line', linewidth=2) plt.xlabel('m_j') plt.ylabel('j_g') plt.legend() plt.show() main() 2.代码结果
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。