赞
踩
#注释部分为测试
import glob
import os.path
import random
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
MODEL_DIR = 'model/' # inception-v3模型的文件夹
MODEL_FILE = 'tensorflow_inception_graph.pb' # inception-v3模型文件名
CACHE_DIR = 'data/tmp/bottleneck' # 图像的特征向量保存地址
INPUT_DATA = 'data/flower_photos' # 图片数据文件夹
VALIDATION_PERCENTAGE = 10 # 验证数据的百分比
TEST_PERCENTAGE = 10 # 测试数据的百分比
# inception-v3模型参数
BOTTLENECK_TENSOR_SIZE = 2048 # inception-v3模型瓶颈层的节点个数
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0' # inception-v3模型中代表瓶颈层结果的张量名称
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0' # 图像输入张量对应的名称
# 神经网络的训练参数
LEARNING_RATE = 0.01
STEPS = 1000
BATCH = 100
CHECKPOINT_EVERY = 100
NUM_CHECKPOINTS = 5
# 从数据文件夹中读取所有的图片列表并按训练、验证、测试分开
def create_image_lists(validation_percentage, test_percentage):
result = {} # 保存所有图像。key为类别名称。value也是字典,存储了所有的图片名称
sub_dirs = [x[0] for x in os.walk(INPUT_DATA)] # 获取所有子目录
is_root_dir = True # 第一个目录为当前目录,需要忽略
# 分别对每个子目录进行操作
for sub_dir in sub_dirs:
if is_root_dir:
is_root_dir = False
continue
# 获取当前目录下的所有有效图片
extensions = {'jpg', 'jpeg', 'JPG', 'JPEG'}
file_list = [] # 存储所有图像
dir_name = os.path.basename(sub_dir) # 获取路径的最后一个目录名字
for extension in extensions:
file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
# file_list.extend(alob.glob(file_glob))
file_list.extend(glob.glob(file_glob))
if not file_list:
continue
# 将当前类别的图片随机分为训练数据集、测试数据集、验证数据集
label_name = dir_name.lower() # 通过目录名获取类别的名称
training_images = []
testing_images = []
validation_images = []
for file_name in file_list:
base_name = os.path.basename(file_name) # 获取该图片的名称
chance = np.random.randint(100) # 随机产生100个数代表百分比
if chance < validation_percentage:
validation_images.append(base_name)
elif chance < (validation_percentage + test_percentage):
testing_images.append(base_name)
else:
training_images.append(base_name)
# 将当前类别的数据集放入结果字典
result[label_name] = {
'dir': dir_name,
'training': training_images,
'testing': testing_images,
'validation': validation_images
}
# 返回整理好的所有数据
return result
# 通过类别名称、所属数据集、图片编号获取一张图片的地址
def get_image_path(image_lists, image_dir, label_name, index, category):
label_lists = image_lists[label_name] # 获取给定类别中的所有图片
category_list = label_lists[category] # 根据所属数据集的名称获取该集合中的全部图片
mod_index = index % len(category_list) # 规范图片的索引
base_name = category_list[mod_index] # 获取图片的文件名
sub_dir = label_lists['dir'] # 获取当前类别的目录名
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。