赞
踩
大四的时候大致的看过一本基于tensorflow的实战Google深度学习框架的书,目前看论文源码也好,修改代码做改进也好,很多基本知识还是源于那个时候。这是远远不够的,为此,我在github上找了一个基于tensorflow的实例管理教程,来再细致的学习一下tensorflow,希望能够增强自己读代码,写代码的能力,对深度学习也有更好的理解。
具体的学习过程,因为有之前的一些基础,为此直接从各种神经网络模型入手,来学习tensorflow框架,并且还可以对模型进一步的进行理解。在模型的搭建训练之前,首先就是训练测试数据的输入是如何实现的。下面结合代码,分块讲解。
数据集获取。建议先下载到本地,通过函数去下载相对不太稳定,耗时较长。
from __future__ import print_function import gzip import os import urllib import numpy SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' def maybe_download(filename, work_directory): """如果数据集不存在,从yann's的网站下载所需的数据集.""" if not os.path.exists(work_directory): os.mkdir(work_directory) filepath = os.path.join(work_directory, filename) if not os.path.exists(filepath): filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath) statinfo = os.stat(filepath) print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') return filepath
定义了所获取到的字节流的存储,采用大尾端的方式,返回值就是字节流中的前四个:magic,num_images,rows,cols。
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)
对图片进行提取并将其转化为一个四维的numpy数组。
def extract_images(filename):
"""将输入的图片转化为一个uint8类型四维的numpy数组 [index, y, x, depth]."""
print('Extracting', filename
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。