当前位置:   article > 正文

李沐深度学习-softmax从零开始_深度学习中softmax函数的使用

深度学习中softmax函数的使用
import torch
import torchvision
import numpy as np
import sys

sys.path.append("路径")
import d2lzh_pytorch as d2l

'''
1. 获取和读取数据
2. 初始化参数和模型
3. 定义softmax运算
4. 定义模型
5. 定义损失函数:交叉熵损失函数
6. 定义分类准确率
7. 训练模型
8. 预测
'''

'''
----------------- !!分类计算中,每个样本都要进行标签中类别数个预测,来判断该样本属于那种分类的概率大!!!!
'''

'''
-----------------------------------------------------------获取和读取数据
'''
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

'''
----------------------------------------------------------初始化模型参数
'''
num_inputs = 784  # 一个图像是28x28大小,输入特征个数就是784个  w分10类,一类也需要784个值,b只需要10个分类偏置
num_outputs = 10  # 分类个数
w = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_outputs)), dtype=torch.float, )  # 归一化生成w:784x10
b = torch.zeros(num_outputs, dtype=torch.float)
w.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)

'''
----------------------------------------------------------定义softmax运算
'''
# 如何对多维Tensor按维度进行操作
# 在以下的操作中和对其中同一列(dim=0)或同一行(dim=1)的元素进行求和,并在结果中保留行和列这两个维度(keepdim=True)
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(X.sum(dim=0, keepdim=True))
print(X.sum(dim=1, keepdim=True))


def softmax(X):  # 这里X不是样本,而是经过线性运算之后的结果,行数代表样本数,列数代表种类数/输出个数
    X_exp = X.exp()  # 先对X中的每个元素幂指数化
    partition = X_exp.sum(dim=1, keepdim=True)  # 每一行进行幂指数求和,得到分母
    return X_exp / partition  # 这里运用了广播机制


X = torch.rand(2, 5)
result = softmax(X)
print(result)  # 经过softmax处理后,得到了预测输出在每个类别上的预测概率分布

'''
-------------------------------------------------------------------------定义模型
'''


def net(X):
    return softmax(torch.mm(X.view(-1, num_inputs), w) + b)  # 这里是进行线性矢量计算,可以是单个样本也可以是批量样本计算


'''
------------------------------------------------------------------------定义损失函数:交叉熵损失函数
'''
# 为了得到标签的预测概率,可以使用gather函数,下面y_hat是2个样本在3个类别的预测概率,y是这2个样本的标签类别
# 通过使用gather函数,可以得到2个样本的标签的预测概率。 在代码中,标签类别的离散值是从0开始逐一递增的
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])  # 这是标签的类别
y_hat.gather(1, y.view(-1, 1))  # 应该根据y中标签类别确定y_hat中标签类别的对应位置然后取出这个类别的概率,这里应该取出


# y_hat中的 (0,0),(1,2)位置处的值,因为这里是和y中类别0,2所对应的概率位置


def cross_entroy(y_hat, y):
    return -torch.log(y_hat.gather(1, y.view(-1, 1)))  # 交叉熵只关心对正确类别的预测概率,通过使用gather函数,可以得到2个样本的标签的预测概率。


tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]]).t()
output = tensor_0.gather(1, index)
print(output)
'''

-------------------------------------------------------------------计算分类准确率
'''


# y_hat是一个预测概率分布,把分布中预测概率最大的作为输出类别,如果与真实类别y一直,则表示预测准确
# 分类预测率=正确预测个数与总预测数量之比


def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item(), (y_hat.argmax(dim=1) == y).float()
    # y_hat.argmax(dim=1)返回矩阵y_hat每行中最大的元素索引,且返回结果与变量y形状相同
    # 上述判断式是一个类型为ByteTensor的Tensor,使用float()将其转换为值为0(相等为假)或1(相等为真)的浮点型Tensor


print(accuracy(y_hat, y))  # 50%准确率,第一个预测错误,第二个预测正确
print(d2l.evaluate_accuracy(test_iter, net))  # 预测了mnist的test数据集softmax运算精度

'''
-------------------------------------------------------------------------------------训练模型
'''
num_epochs, lr = 5, 0.1
result = d2l.train_ch3(net, train_iter, test_iter, cross_entroy, num_epochs, batch_size, [w, b], lr)

'''
-------------------------------------------------------------------------------------预测
'''
X, y = next(iter(test_iter))  # test_iter返回的是一个迭代器对象,需要使用next()函数进行调用
true_labels = d2l.get_fashion_mnist_labels(y.numpy())  # 获取了test数据集中的真实标签并进行转义
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())  # 将预测函数net返回的y_hat取其每行最大值的下标索引
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]  # 列表中使用for循环

# 一次next()只访问了一个批量元组,X就是一个列表
d2l.show_fashion_mnist(X[0:9], titles[0:9])

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/黑客灵魂/article/detail/939104
推荐阅读
相关标签
  

闽ICP备14008679号