当前位置:   article > 正文

pytorch学习笔记(二)——pytorch手写数字识别_plt.figure() 手写

plt.figure() 手写

手写数字识别原理 

每张照片用长28宽28个像元的灰度信息表示

将28*28[28,28]的矩阵打平(flat)成784个像素[784],则可以忽略二维位置相关性,再插入一个维度变成[1,784]

使用三个线性函数y=wx+b的嵌套来解决手写数字识别问题

H1 = XW1 + b1    W1: [d1, dx],b1: [d1]

H2 = H1W2 + b2   W2: [d2, d1],b2: [d2]

H3 = H2W3 + b3   W3: [d3, d2],b3: [d3]

H1,H2,H3的维度分别为[1,d1], [1,d2], [1,d3]

H3中1表示照片数量,d3表示0-9数字

使用one-hot编码方式Y: [0/1/…/9],避免label之间有大小属性的关系,如果分10类就是一个十维的向量

预测值

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/369107
推荐阅读
相关标签