当前位置:   article > 正文

tensorflow框架精细讲解(一)_transflow教程

transflow教程

前言

大四的时候大致的看过一本基于tensorflow的实战Google深度学习框架的书,目前看论文源码也好,修改代码做改进也好,很多基本知识还是源于那个时候。这是远远不够的,为此,我在github上找了一个基于tensorflow的实例管理教程,来再细致的学习一下tensorflow,希望能够增强自己读代码,写代码的能力,对深度学习也有更好的理解。

一.数据准备

具体的学习过程,因为有之前的一些基础,为此直接从各种神经网络模型入手,来学习tensorflow框架,并且还可以对模型进一步的进行理解。在模型的搭建训练之前,首先就是训练测试数据的输入是如何实现的。下面结合代码,分块讲解。

数据集获取。建议先下载到本地,通过函数去下载相对不太稳定,耗时较长。

from __future__ import print_function
import gzip
import os
import urllib
import numpy
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
  """如果数据集不存在,从yann's的网站下载所需的数据集."""
  if not os.path.exists(work_directory):
    os.mkdir(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not os.path.exists(filepath):
    filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
    statinfo = os.stat(filepath)
    print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
  return filepath
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

定义了所获取到的字节流的存储,采用大尾端的方式,返回值就是字节流中的前四个:magic,num_images,rows,cols。

def _read32(bytestream):
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)
  • 1
  • 2
  • 3

对图片进行提取并将其转化为一个四维的numpy数组。

def extract_images(filename):
  """将输入的图片转化为一个uint8类型四维的numpy数组 [index, y, x, depth]."""
  print('Extracting', filename
  • 1
  • 2
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号