当前位置:   article > 正文

(三)基于Tensorflow设计VGGNet网络训练CIFAR-10图像分类_cifar-10-batches-bin

cifar-10-batches-bin

1 CIFAR-10数据集

1.0 数据集分类

CIFAR-10数据集的图像属性:

序号参数描述
1width32
2height32
3channels3

共分了十类物品,如下表:

序号种类源数据
1airplane
2automobile
3bird
4cat
5deer
6dog
7frog
8horse
9ship
10truck

1.2 获取图标

闲来无事做,爬了100张分类图像,如上图所示.

import requests, json, urllib
import os

def get_data():
	'''获取各分类链接'''
	i = 0
	classify = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
	urls = []
	for i in range(10):
		for j in range(10):
			# print(i)
			url = "http://www.cs.toronto.edu/~kriz/cifar-10-sample/{}{}.png".format(classify[i], j+1)
			# response = requests.get(url)
			urls.append(url)

	print("data urls: {}".format(urls))
	return urls

def download_images(urls):
	'''下载链接数据,保存到images文件夹下'''
	for i, url in enumerate(urls):
		image_name = url.split('/')[-1]
		print("No.{} images is downloading".format(i))
		urllib.request.urlretrieve(url, "images/"+image_name)

if __name__ == "__main__":
	download_images(get_data())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

1.3 获取CIFAR-10数据集

1.3.1 源码

直接爬取CIFAR-10数据集.

import time
import cifar10_input
import tarfile
from six.moves import urllib
import os
import sys

FLAGS = tf.app.flags.FLAGS
# 模型参数
tf.app.flags.DEFINE_string('data_dir', 'cifa10_data',
							"""Path to the CIFAR-10 data directory.""")

def data_stream():
	'''获取cifar-10数据并提取'''
	data_url = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
	data_directory = FLAGS.data_dir
	print("Data directory: {}".format(data_directory))
	# 新建数据文件夹:cifar10_data
	if not os.path.exists(data_directory):
		os.mkdir(data_directory)
	# 获取文件名:cifar-10-binary.tar.gz
	filename = data_url.split('/')[-1]
	filepath = os.path.join(data_directory, filename)
	print("File path: {}".format(filepath))
	# 写入文件:cifar-10-binary.tar.gz
	if not os.path.exists(filepath):
		def _progress(count, block_size, total_size):
			sys.stdout.write('\r>>Downloading {} {}'.format(filename, float(count * block_size) / float(total_size) * 100.0))
			sys.stdout.flush()
			# 请求数据
		filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
		statinfo = os.stat(filepath)
		print("Successfully downloaded", filename, statinfo.st_size, 'bytes.')
	extracted_dir_path = os.path.join(data_directory, 'cifar-10-batches-bin')
	# 提取数据
	if not os.path.exists(extracted_dir_path):
		tarfile.open(filepath, 'r:gz').extractall(data_directory)
if __name__ == "__main__":
	data_stream()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39

1.3.2 目录结构

|-- cifa10_data
|   |-- cifar-10-batches-bin
|   |   |-- batches.meta.txt
|   |   |-- data_batch_1.bin
|   |   |-- data_batch_2.bin
|   |   |-- data_batch_3.bin
|   |   |-- data_batch_4.bin
|   |   |-- data_batch_5.bin
|   |   |-- readme.html
|   |   `-- test_batch.bin
|   `-- cifar-10-binary.tar.gz
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

其中,cifar-10-binary.tar.gz为源数据,cifar-10-batches-bin为提取的数据目录,内容如上所示.

1.3.3 数据集数据结构

数据集的数据包括图片数据和图片标签两部分,图像数据处理后的shape为(24, 24, 3),标签数据为整数,分10类,取值范围[0, 9],验证如下:

  • Demo
import tensorflow as tf
import cifar10_input

# 小批量数据大小  
batch_size = 512
# 数据所在路径
data_dir = "./cifa10_data/cifar-10-batches-bin"

def get_data():
	with tf.Session() as sess:
		# 初始化
		init_op = tf.global_variables_initializer()
		sess.run(init_op)
		# 提取源数据
		train_images, train_labels = cifar10_input.distorted_inputs(batch_size=batch_size, data_dir=data_dir)
		# 协程模式
		coord = tf.train.Coordinator()
		# 多线程
		threads = tf.train.start_queue_runners(coord=coord)
		# tensorflow处理数据
		batch_images, batch_labels = sess.run([train_images, train_labels])
		# 批量数据维度
		print("Images shape: {}, labels shape: {}".format(batch_images.shape, batch_labels.shape))
		# 单独一组数据维度
		print("Images shape: {}, labels shape: {}".format(batch_images[0].shape, batch_labels[0].shape))
		# 单独一组数据
		print("Images: {}, labels: {}".format(batch_images[0], batch_labels[0]))
		# 批量标签长度
		print("Length of labels: {}".format(len(batch_labels)))
		# 标签数据
		print("Labels: {}".format(batch_labels))
		# 线程阻塞
		coord.request_stop()
		# 等待子线程执行完毕
		coord.join(threads)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • Result
Images shape: (512, 24, 24, 3), labels shape: (512,)
Images shape: (24, 24, 3), labels shape: ()
Images: [[[-0.18781795 -0.14232759 -0.01706854]
  [ 0.41756073  0.42177528  0.49199972]
  [ 1.6420764   1.6050153   1.6477226 ]
  ...
  [-0.14654216 -0.15608616 -0.08586162]
  [-0.18781795 -0.19736204 -0.12713741]
  [-0.25661105 -0.26615503 -0.18217184]]

 [[-0.5730589  -0.5413272  -0.44358534]
  [ 0.77528447  0.7519817   0.8084478 ]
  [ 2.0961106   2.0177736   2.0329638 ]
  ...
  [-0.49050733 -0.5000513  -0.42982668]
  [-0.38043845 -0.38998246 -0.3197579 ]
  [-0.31164536 -0.32118946 -0.25096482]]

 [[-0.49050733 -0.4725342  -0.40230957]
  [ 1.2981114   1.2610501   1.2899989 ]
  [ 2.3024895   2.1691184   2.129274  ]
  ...
  [-0.73816216 -0.73394763 -0.69124025]
  [-0.69688636 -0.69267184 -0.6499644 ]
  [-0.61433476 -0.6238787  -0.58117145]]

 ...

 [[-0.6556105  -0.65139604 -0.66372305]
  [-0.6556105  -0.65139604 -0.66372305]
  [-0.6556105  -0.6376374  -0.66372305]
  ...
  [-0.75192076 -0.74770623 -0.74627465]
  [-0.7656794  -0.76146483 -0.76003325]
  [-0.7656794  -0.76146483 -0.76003325]]

 [[-0.6556105  -0.65139604 -0.6499644 ]
  [-0.66936916 -0.66515464 -0.66372305]
  [-0.66936916 -0.66515464 -0.66372305]
  ...
  [-0.7656794  -0.76146483 -0.76003325]
  [-0.7656794  -0.76146483 -0.76003325]
  [-0.77943796 -0.77522343 -0.77379185]]

 [[-0.66936916 -0.66515464 -0.66372305]
  [-0.66936916 -0.66515464 -0.66372305]
  [-0.66936916 -0.66515464 -0.66372305]
  ...
  [-0.77943796 -0.77522343 -0.77379185]
  [-0.77943796 -0.77522343 -0.77379185]
  [-0.77943796 -0.77522343 -0.77379185]]], labels: 5
Length of labels: 512
Labels: [5 0 3 7 8 2 1 3 6 8 8 0 2 0 7 1 2 0 4 3 9 6 7 3 6 9 0 7 0 3 6 3 9 2 3 5 0
 2 2 8 5 9 3 3 0 3 8 4 5 8 0 7 7 6 8 0 6 5 7 4 1 8 9 2 5 5 8 5 3 5 9 5 8 2
 2 5 3 8 8 6 2 9 8 2 5 3 9 3 2 9 5 4 4 7 4 7 5 8 0 4 9 5 7 3 9 6 0 2 5 6 9
 5 1 2 5 9 4 8 6 7 7 7 3 1 2 9 9 5 5 2 3 9 8 1 8 7 4 6 9 3 8 4 6 6 8 9 1 6
 9 7 2 1 6 4 6 2 4 0 6 7 6 6 8 3 2 8 4 4 5 3 9 2 6 6 8 4 6 3 5 4 4 2 9 0 9
 8 1 3 6 3 5 4 5 2 9 2 3 5 3 2 4 9 4 5 7 9 5 8 2 7 5 4 2 3 0 8 4 9 5 6 7 0
 8 2 4 7 8 3 2 9 0 5 8 6 5 0 9 1 3 6 7 6 8 9 2 7 0 4 1 6 8 9 2 9 7 4 1 1 2
 7 2 2 8 1 4 8 5 4 7 9 0 7 0 2 1 7 2 6 7 6 1 0 7 2 2 6 9 8 3 0 8 1 4 9 5 6
 3 0 1 3 0 6 2 7 9 1 5 7 4 5 6 3 3 4 2 7 7 1 1 6 6 2 2 9 2 7 7 2 8 2 0 6 9
 3 7 6 8 1 3 2 4 3 4 6 8 8 5 6 9 1 0 4 0 5 3 1 1 7 8 5 4 5 0 6 1 3 6 7 2 2
 8 8 7 3 1 7 0 5 0 8 7 1 9 2 1 7 8 0 5 0 3 3 9 2 0 3 0 4 9 0 4 4 6 4 3 9 0
 5 2 8 1 4 1 7 3 2 5 0 0 5 2 1 2 7 7 8 4 2 4 8 3 7 5 7 7 2 8 4 0 2 1 2 8 7
 6 0 2 5 9 1 0 7 2 8 2 3 2 3 1 4 4 6 9 5 4 4 0 5 8 4 4 6 6 9 9 1 3 5 7 6 3
 0 7 9 6 3 0 5 2 4 2 9 1 0 7 8 8 0 3 7 5 5 7 3 4 4 2 0 0 7 4 5]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • Analysis
    (1) 批量数据每组大小为512,因此批量数据维度即占位数据维度为:图像数据(512, 24, 24, 3), 标签维度(512, );
    (2) 从批量数据中取一组数据,图像维度(24, 24, 3)表示一张图片,标签维度(),标签为单个数据;
    (3) 图像数据格式为float,标签值为10;
    (4) 批量数据标签数量为512,数据为0至9,对应图像分类;

2 基于Tensorflow搭建VGGNet训练网络

2.1 VGGNet模型

VGGNet.py,其中cifar10_input.py通过博客:
(二)VGGNet训练CIFAR10数据集之数据预处理获取.

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import time
import cifar10_input
import tarfile
from six.moves import urllib
import os
import sys

FLAGS = tf.app.flags.FLAGS
LOG_DIR = "./logs/testlogs"

# 图路径
LOG_DIR = "./logs/cifar"
# 小批量数据大小  
batch_size = 512
# 每轮训练数据的组数,每组为一batchsize  
s_times = 20
# 学习率
learning_rate = 0.0001
# 数据所在路径
data_dir = "./cifa10_data/cifar-10-batches-bin"

# Xavier初始化方法 
# 卷积权重(核)初始化 
def init_conv_weights(shape, name): 
	weights = tf.get_variable(name=name, shape=shape, dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer_conv2d()) 
	return weights 

# 全连接权重初始化 
def init_fc_weights(shape, name):
	weights = tf.get_variable(name=name, shape=shape, dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) 
	return weights

# 偏置 
def init_biases(shape, name): 
	biases = tf.Variable(tf.random_normal(shape),name=name, dtype=tf.float32) 
	return biases

# 卷积 
# 参数:输入张量,卷积核,偏置,卷积核在高和宽维度上移动的步长 
# 卷积核:不使用全0填充,padding='VALID' 
def conv2d(input_tensor, weights, biases, s_h, s_w):
	conv = tf.nn.conv2d(input_tensor, weights, [1, s_h, s_w, 1], padding='VALID') 
	return tf.nn.relu(conv + biases)

# 池化 
# 参数:输入张量,池化核高和宽,池化核在高,宽维度上移动步长 
# 池化窗口:不使用全0填充,padding='VALID'
def max_pool(input_tensor, k_h, k_w, s_h, s_w):
	return tf.nn.max_pool(input_tensor, ksize=[1, k_h, k_w, 1], strides=[1, s_h, s_w, 1], padding='VALID') 
	
# 全链接 
# 参数:输入张量,全连接权重,偏置 
def fullc(input_tensor, weights, biases): 
	return tf.nn.relu_layer(input_tensor, weights, biases)

# 使用tensorboard对网络结构进行可视化的效果较好 
# 输入占位节点
with tf.name_scope("source_data"):
	# images 输入图像 
	images = tf.placeholder(tf.float32, [batch_size, 24 ,24 ,3])
	# 图像标签
	labels = tf.placeholder(tf.int32, [batch_size]) 
	# 正则 
	keep_prob = tf.placeholder(tf.float32)
	tf.summary.image("input", images, 8)

# 第一组卷积 
# input shape (batch_size, 24, 24, 3)
# output shape (batch_size, 22, 22, 16) 
with tf.name_scope('conv_gp_1'): 
	# conv3-16
	cw_1 = init_conv_weights([3, 3, 3, 16], name='conv_w1') 
	cb_1 = init_biases([16], name='conv_b1') 
	conv_1 = conv2d(images, cw_1, cb_1, 1, 1)
	reshape_cgp_1 = tf.reshape(conv_1[0], [16, 22, 22, 1])
	tf.summary.image('conv_gp_1', reshape_cgp_1, 8)

# 第二组卷积  
# input shape (batch_size, 22, 22, 16)
# output shape (batch_size, 20, 20, 32)
with tf.name_scope('conv_gp2'):
	# conv3-32 
	cw_2 = init_conv_weights([3, 3, 16, 32], name='conv_w2') 
	cb_2 = init_biases([32], name='conv_b2') 
	conv_2 = conv2d(conv_1, cw_2, cb_2, 1, 1)
	reshape_cgp_2 = tf.reshape(conv_2[0], [32, 20, 20, 1])
	tf.summary.image('conv_gp_2', reshape_cgp_2, 8)

# 第三组卷积   
# input shape (batch_size, 20, 20, 32)
# output shape (batch_size, 16, 16, 64)
with tf.name_scope('conv_gp_3'): 
	# conv3-64
	cw_3 = init_conv_weights([3, 3, 32, 64], name='conv_w3') 
	cb_3 = init_biases([64], name='conv_b3') 
	conv_3 = conv2d(conv_2, cw_3, cb_3, 1, 1)
	# conv3-64 
	cw_4 = init_conv_weights([3, 3, 64, 64], name='conv_w4') 
	cb_4 = init_biases([64], name='conv_b4') 
	conv_4 = conv2d(conv_3, cw_4, cb_4, 1, 1)
	reshape_cgp_3 = tf.reshape(conv_4[0], [64, 16, 16, 1])
	tf.summary.image('conv_gp_3', reshape_cgp_3, 8)
	
# 第四组卷积  
# input shape (batch_size, 16, 16, 64)
# output shape (batch_size, 12, 12, 128)
with tf.name_scope('conv_gp_4'): 
	# conv3-128
	cw_5 = init_conv_weights([3, 3, 64, 128], name='conv_w5') 
	cb_5 = init_biases([128], name='conv_b5') 
	conv_5 = conv2d(conv_4, cw_5, cb_5, 1, 1) 
	# conv3-128 
	cw_6 = init_conv_weights([3, 3, 128, 128], name='conv_w6') 
	cb_6 = init_biases([128], name='conv_b6') 
	conv_6 = conv2d(conv_5, cw_6, cb_6, 1, 1)
	reshape_cgp_4 = tf.reshape(conv_6[0], [128, 12, 12, 1])
	tf.summary.image('conv_gp_4', reshape_cgp_4, 8)
	
# 最大池化 窗口尺寸2x2,步长2
# input (batch_size, 12, 12, 128)
# output (batch_size, 6, 6, 128)
pool_4 = max_pool(conv_6, 2, 2, 2, 2)
reshape_pool_4 = tf.reshape(pool_4[0], [128, 6, 6, 1])
tf.summary.image('pool_4', reshape_pool_4, 8)

# 第五组卷积
# input (batch_size, 6, 6, 128)
# output (batch_size, 2, 2, 128)
with tf.name_scope('conv_gp_5'): 
	# conv3-256
	cw_7 = init_conv_weights([3, 3, 128, 128], name='conv_w7') 
	cb_7 = init_biases([128], name='conv_b7') 
	conv_7 = conv2d(pool_4, cw_7, cb_7, 1, 1) 
	# conv3-256
	cw_8 = init_conv_weights([3, 3, 128, 128], name='conv_w8') 
	cb_8 = init_biases([128], name='conv_b8') 
	conv_8 = conv2d(conv_7, cw_8, cb_8, 1, 1)
	reshape_cgp_5 = tf.reshape(conv_8[0], [128, 2, 2, 1])
	tf.summary.image('conv_gp_5', reshape_cgp_5, 8)


# 转换数据shape
# input shape (batch_size, 2, 2, 128)
# reshape_conv8 (batch_size, 512)
reshape_conv8 = tf.reshape(conv_8, [batch_size, -1])
# n_in = 512
n_in = reshape_conv8.get_shape()[-1].value 

# 第一个全连接层
with tf.name_scope('fullc_1'):
	# (n_in, 256) = (512, 256) 
	fw9 = init_fc_weights([n_in, 256], name='fullc_w9') 
	# (256, )
	fb9 = init_biases([256], name='fullc_b9')
	# (512, 256)
	activation1 = fullc(reshape_conv8, fw9, fb9) 
# dropout正则 
drop_act1 = tf.nn.dropout(activation1, keep_prob) 
# 第二个全连接层
with tf.name_scope('fullc_2'): 
	# (256, 256)
	fw10 = init_fc_weights([256, 256], name='fullc_w10') 
	# (256, )
	fb10 = init_biases([256], name='fullc_b10') 
	# (512, 256)
	activation2 = fullc(drop_act1, fw10, fb10) 
# dropout正则 
drop_act2 = tf.nn.dropout(activation2, keep_prob) 

# 第三个全连接层
with tf.name_scope('fullc_3'):
	# (256, 10) 
	fw11 = init_fc_weights([256, 10], name='fullc_w11') 
	# (10, )
	fb11 = init_biases([10], name='full_b11') 
	# (512, 10)
	logits = tf.add(tf.matmul(drop_act2, fw11), fb11) 
	output = tf.nn.softmax(logits)

with tf.name_scope("cross_entropy"):
	cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) 
	cost = tf.reduce_mean(cross_entropy,name='Train_Cost') 
	tf.summary.scalar("cross_entropy", cost)

with tf.name_scope("train"):
	optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

# with tf.name_scope("accuracy"):

# # 用来评估测试数据的准确率 # 数据labels没有使用one-hot编码格式,labels是int32 
# 	def accuracy(labels, output): 
# 		labels = tf.to_int64(labels) 
# 		pred_result = tf.equal(labels, tf.argmax(output, 1)) 
# 		accu = tf.reduce_mean(tf.cast(pred_result, tf.float32))
# 		tf.summary.scalar('accuracy', accu) 
# 		# return accu


merged = tf.summary.merge_all()

# 加载训练batch_size大小的数据,经过增强处理,剪裁,反转,等等
train_images, train_labels = cifar10_input.distorted_inputs(batch_size= batch_size, data_dir= data_dir)

# 加载测试数据,batch_size大小,不进行增强处理
test_images, test_labels = cifar10_input.inputs(batch_size= batch_size, data_dir= data_dir,eval_data= True)

# 训练
def training(max_steps, s_times, keeprob, display):
	with tf.Session() as sess:
		init_op = tf.global_variables_initializer()
		sess.run(init_op)
		coord = tf.train.Coordinator()
		threads = tf.train.start_queue_runners(coord=coord)
		summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
		# writer.close() 
		for i in range(max_steps): 
			for j in range(s_times): 
				start = time.time() 
				batch_images, batch_labels = sess.run([train_images, train_labels]) 
				opt = sess.run(optimizer, feed_dict={images:batch_images, labels:batch_labels, keep_prob:keeprob}) 
				every_batch_time = time.time() - start 
			summary, c = sess.run([merged, cost], feed_dict={images:batch_images, labels:batch_labels, keep_prob:keeprob}) 
			# 保存训练模型路径
			ckpt_dir = './vgg_models/vggmodel.ckpt'
			# 保存训练模型
			saver = tf.train.Saver()
			saver.save(sess,save_path=ckpt_dir,global_step=i)

			if i % display == 0: 
				samples_per_sec = float(batch_size) / every_batch_time 
				
				print("Epoch {}: {} samples/sec, {} sec/batch, Cost : {}".format(i+display, samples_per_sec, every_batch_time, c)) 
			summary_writer.add_summary(summary, i)
		# 线程阻塞
		coord.request_stop()
		# 等待子线程执行完毕
		coord.join(threads)
	summary_writer.close()

if __name__ == "__main__":
	training(5000, 5, 0.7, 10)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244

2.2 网络解析

2.2.1 源数据

通过爬取的数据,进行训练,爬取数据之后,需要使用Tensorflow进行预处理,即对图片进行:

  • 剪切tf.random_crop
  • 旋转tf.image.random_flip_left_right
  • 调节亮度random_brightness

使用Tensorflow提供的cifar10_input.py文件进行处理,将图片剪切成(24x24x3)的图像,并进行shuffle,改变训练数据的组合方式,提高训练识别的准确率.源码参见博客(可直接引用):
(二)VGGNet训练CIFAR10数据集之数据预处理

2.2.2 网络结构

由于CIFAR-10数据集的原始图像尺寸较小(32x32),并且经过预处理后为(24x24),因此,设计的网络是VGGNet的简洁版,结构如下图:


VGG

图2.1 `简洁`版VGGNet

2.2.3 网络分析

(1) 该网络使用了5个卷积组,2个池化层和三个全连接层,卷积计算不使用全0填充,因此通过卷积核进行卷积计算时,会改变图像尺寸,结合第一篇文章,使用全0填充VGGNet神经网络简介及Tensorflow搭建可视化网络,卷积计算时不改变图像尺寸.池化层即改变图形尺寸.
(2) 卷积神经网络计算参见博客:卷积神经网络CNN详解;
(3) 本文通过Tensorboard可视化网路结构及训练结果;

3 训练结果及分析

3.1 结果

  • 源数据图像结果

    在这里插入图片描述
图3.1 原始数据图像(24x24)
  • 第一组卷积图像结果

    在这里插入图片描述
图3.2 第一组卷积(22x22)
  • 第二组卷积图像结果

    在这里插入图片描述
图3.3 卷积结果(20x20)
  • 第三组卷积图像结果

    在这里插入图片描述
图3.4 卷积结果(16x16)
  • 第四组卷积图像结果

    在这里插入图片描述
图3.5 卷积结果(12x12)
  • 第五组卷积图像结果

    在这里插入图片描述
图3.6 卷积结果(2x2)
  • 损失函数图像结果

    在这里插入图片描述
图3.7 损失函数结果

3.2 分析

  • 源数据经过5组卷积处理,各卷积组分别提取特征,各有差别,随着深度增加,提取的信息也不相同,浅层主要提取图像内容(形状,位置,颜色和文理),深层提取图像组成(物体轮廓),这也是图像风格转换的应用.
  • 由于源数据的尺寸较小,最终输出的图像为 2 × 2 2\times2 2×2,信息量较少,模型精度待验证;
  • 损失值逐渐降低,本文测试了50轮,损失值降到2.1,随着训练轮数的增加,损失值会逐步降低,并稳定在某一范围内,按需训练,提取模型;

4 总结

  • 本文讲解了基于Tensorflow搭建VGGNet网络训练图像分类,可在此基础上改变网络结构,训练不同的分类及不同尺寸的图像;
  • 使用Tensorboard可视化训练过程及训练结果;
  • 不同卷积层提取的图像信息不同,可根据需要截取不同层次的信息,浅层主要提取图像内容,深层主要提取图像轮廓.

[参考文献]
[1]https://blog.csdn.net/Xin_101/article/details/81917134
[2]https://blog.csdn.net/Xin_101/article/details/86348653
[3]https://blog.csdn.net/Xin_101/article/details/86707518


声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号