当前位置:   article > 正文

第二章 pytorch回归问题

第二章 pytorch回归问题

一、梯度下降算法

  • 用于反向传播时求解权重矩阵的最优解

二、线性回归问题与逻辑回归问题

  • 主要是使用使用激活函数将函数值限定在有限范围内(即:[0,1] or [-1,1] .etc)

三、线性回归实战

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: pytorch_learning - regression_demo.py
@author: yonghao
@Description: 实现线性回归
@since 2021/01/24 21:29
'''
import numpy as np


def compute_error_for_line_given_points(w, b, points) -> float:
'''
计算样本的平均损失
:paramw: 权重值
:paramb: 误差值
:parampoints: 样本点
:return: 返回样本平均损失值
'''
total_error = 0
for i in range(0, len(points)):
x = points[i, 0]
y = points[i, 1]
total_error += (y - (w * x + b)) ** 2
return total_error / float(len(points))


def step_gradient(b_current, w_current, points, learning_rate):
'''
计算当前样本的梯度值
:paramb_current:当前的拟合线性方程的常数值
:paramw_current:当前的拟合线性方程的斜率
:parampoints: 样本
:paramlearning_rate:学习率
:return: 返回梯度下降后的参数
'''
b_gradient = 0
w_gradient = 0
N = float(len(points))
for i in range(0, len(points)):
x = points[i, 0]
y = points[i, 1]
b_gradient += -(2 / N) * (w_current * x + b_current - y)
w_gradient += -(2 / N) * (w_current * x + b_current - y) * x
    new_w = w_current - learning_rate * w_gradient
    new_b = b_current - learning_rate * b_gradient
 return new_w, new_b


def gradient_descent_runner(points, starting_b, starting_w, learning_rate, num_iterations):
'''
梯度下降算法
:parampoints:样本点集
:paramstarting_b: 起始的b
:paramstarting_w: 起始的w
:paramlearning_rate: 学习率
:paramnum_iterations: 迭代次数
:return: 计算得到的权重系数与误差系数
'''
w = starting_w
    b = starting_b
 for i in range(num_iterations):
w, b = step_gradient(b, w, points, learning_rate)
return w, b


def run():
points = np.random.uniform(0, 5, (100, 2))
learning_rate = 0.0001
initial_b = 0
initial_w = 0
num_iterations = 1000
print('Starting gradient descent at b = {},w = {},error={}'
.format(initial_b, initial_w,
compute_error_for_line_given_points(initial_w,
initial_b,
points)))
w, b = gradient_descent_runner(points, initial_b, initial_w,
learning_rate, num_iterations)
print('After gradient descent at b = {},w = {},error={}'
.format(b, w,
compute_error_for_line_given_points(w,
b,
points)))


if __name__ == '__main__':
run()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88

四、分类问题

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述


五、手写数字识别体验

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: pytorch_learning - classification.py
@author: yonghao
@Description: 实现手写数字的分类时间
@since 2021/01/24 22:38
'''
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from matplotlib import pyplot as plt
from realwork.work2_handwrite_classification.utils import plot_curve, plot_image, one_hot

batch_size = 512

# step1 .load dataset
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=False)
# 测试显示数据集的函数
# x, y = next(iter(train_loader))
# print(x.shape, y.shape)
# print(test_loader)
# plot_image(x, y, 'image sample')
print("train num = {}".format(len(train_loader)))
print("test num = {}".format(len(test_loader)))


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# xw+b
self.fc1 = nn.Linear(28 * 28, 256)
self.fc2 = nn.Linear(256, 64)
self.fc3 = nn.Linear(64, 10) # 由于是10分类所以最后层输出一定是10

def forward(self, x):
# x:[b,1,28,28]
# h1 = relu(xw1+b1)
x = F.relu(self.fc1(x))
# h2 = relu(h1w2+b2)
x = F.relu(self.fc2(x))
# h3 = h2w3+b3
x = self.fc3(x)
return x


# 创建网络
net = Net()
# 定制的梯度下架算法计算器:[w1,b1,w2,b2,w3,b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# 保存梯度计算过程中的损失值
loss_values = []
for epoch in range(3):
for batch_idx, (x, y) in enumerate(train_loader):
# x:[b,1,28,28],y:[512]
# x:[b,1,28,28] -> [b,784]
x = x.view(x.size(0), x.size(2) * x.size(3))
# => [b,10]
out = net(x)
# [b,10]
y_onehot = one_hot(y)
# loss = mse(out,y_onehot)
loss = F.mse_loss(out, y_onehot)

optimizer.zero_grad()
loss.backward()
# w' = w - lr*grad
optimizer.step()
loss_values.append(loss.item())
# 每10个batch显示一次loss值
# if batch_idx % 10 == 0:
#     print(epoch, batch_idx, loss.item())

# we get optimal [w1,b1,w2,b2,w3,b3]
# 显示loss的变化情况
# plot_curve(loss_values)

# 由测试集显示其准确度
total_correct = 0
for x, y in test_loader:
x = x.view(x.size(0), 28 * 28)
out = net(x)
# out: [b,10] => pred: [b]
# 将soft_one_hot值转换为hard_one_hot值,使其与真实标签一致
# 标注最大值的位置(将高维空间表示为一维中的位置从0开始)
pred = out.argmax(dim=1)
# 判断正确的总数
correct = pred.eq(y).sum().float().item()
total_correct += correct

total_dataset = len(test_loader.dataset)
acc = total_correct / total_dataset
print("test acc:{}%".format(acc * 100))

next(iter(test_loader))
x, _ = next(iter(test_loader))
out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号