当前位置:   article > 正文

机器学习代码实现:一元线性回归(最小二乘法)_最小二乘法可以求得闭式解

最小二乘法可以求得闭式解

机器学习三要素:

1、建模思想:试图学得一个一元线性模型 z = f ( x ) = w x + b ( 一 条 直 线 ) z = f(x) = wx+b(一条直线) z=f(x)=wx+b(线),使得其输出的预测值 w x + b wx+b wx+b 与样例的真实标记 y y y 尽可能接近。

数据集如图:在这里插入图片描述
样本只有一个属性描述,左边是样本x,右边是真实标记y


2、策略:得到损失函数,即均方误差的表达式

  • 用均方误差表示损失函数,相当于得到样本 到一元线性模型的 欧氏距离的平方和
    在这里插入图片描述

3、求出使得损失函数最小化的参数:w,b

  • 使用最小二乘参数估计(损失函数分别对w和b求偏导),求使得损失函数最小的w和b。
  • 最小二乘法可以得到最优闭式解。
  • 在这里插入图片描述在这里插入图片描述

把求得的w和b带入损失函数公式,就可以得到:

  • 损失函数,计算出损失(均方误差)
  • 一元线性回归模型: z = f ( x ) = w x + b ( 一 条 直 线 ) z = f(x) = wx+b(一条直线) z=f(x)=wx+b(线)。画出该拟合直线。可散点图对比。

在这里插入图片描述


#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2022/4/3 20:34
# @Author : cc
# @File : Least square method.py
# @Software: PyCharm


import numpy as np
import matplotlib.pyplot as plt
from numpy import array

""" 使用最小二乘法,拟合出一元线性回归模型:z = wx + b。
一元的意思是样本x通过一个属性描述,原本可能是矢量x_i = (x_i1, x_i2...,x_id)被例如颜色,大小...
属性描述,现在只有一个x_i1描述,则直接把矢量x_i看成标量,w也是标量

计算出使得损失最小的w和b,
画出拟合直线和原始的散点图
点距离拟合直线越远,代表误差越大
"""


# 画出样例的真实分布,输入样本x和真实标记y
def plot_origin(points: array) -> None:
    """
    :param points: array类型的二维数组
    :return:
    """
    arr_x = points[:, 0]  # return list: 所有元素(子数组)中的第一个元素:x
    arr_y = points[:, 1]  # return list: 所有元素(子数组)中的第二个元素:y

    # 画出散点图:照理说学的时间越长,考试分数越高
    plt.scatter(arr_x, arr_y)
    plt.show()


# 2. 策略:求均方误差(损失函数)
def compute_cost(w: float, b: float, points: array) -> float:
    """ 计算均方误差(损失函数):预测输出和真实标记之间的差距: y - (wx + b)
                z = wx + b 为我们要拟合的线性模型
    :param w: 线性模型参数
    :param b: 线性模型参数
    :param points: array类型的二维数组:所有样例
    :return:  输出E(w,b) 均方误差值
    """
    total_cost = 0
    m = len(points)  # 样本个数m

    # 计算均方误差
    for i in range(m):
        x_i = points[i, 0]  # 第i个样例的第一个元素:x
        y_i = points[i, 1]  # 第i个样例的第二个元素:y
        total_cost += (y_i - w * x_i - b) ** 2
    return total_cost / m  # 均方误差


"""3. 算法:拟合:学得z = wx+b 近似于真实标记y
使用基于均方误差最小化的 最小二乘参数估计
求能使得均方误差最小的w和b:损失函数分别对w,b求偏导=0
以下代码都基于公式推导出来的w,b的表示方法
"""


# 求列表内元素平均值
def avg(lst):
    l = len(lst)
    return sum(lst[i] for i in range(l)) / l


def fit(points: array) -> tuple:
    x_avg = avg(points[:, 0])  # 样本均值
    m = len(points)
    # 求w
    numerator, denominator = 0, -m * x_avg ** 2  # w公式的分子,分母
    for i in range(m):
        x_i, y_i = points[i, 0], points[i, 1]
        numerator += y_i * (x_i - x_avg)
        denominator += x_i ** 2
    w = numerator / denominator

    # 求b
    sum_y = 0
    for i in range(m):
        x_i, y_i = points[i, 0], points[i, 1]
        sum_y += y_i
    b = (sum_y - w * x_avg * m) / m
    return w, b


# 画出拟合函数:一元线性回归模型
def plot_fit(arr_x: array, arr_y: array) -> None:
    plt.scatter(arr_x, arr_y)  # 画散 点 图
    # array类型可以直接对每个元素乘上一个常数,不用for循环慢慢一个个乘
    predict_y = w * arr_x + b  # 拟合的线性模型:预测标记y
    plt.plot(x, predict_y, c='r')  # 画 经过x,y的曲线/直线
    plt.show()


if __name__ == '__main__':
    points = np.genfromtxt('data.csv', delimiter=',')  # array
    x = points[:, 0]  # return array: 所有元素(子数组)中的第一个元素:x
    y = points[:, 1]  # return array: 所有元素(子数组)中的第二个元素:y

    w, b = fit(points)
    print('w, b分别为', w, b)

    print('损失为:', compute_cost(w, b, points))

    plot_fit(x, y)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110

数据集:链接: https://pan.baidu.com/s/1p1GuA9aV2BtwOCIk_YxHbw?pwd=tj5d 提取码: tj5d
–来自百度网盘超级会员v4的分享

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号