当前位置:   article > 正文

全卷积神经网络FCN-TensorFlow代码精析_fcn 网络ssh模块 tensorflow代码python

fcn 网络ssh模块 tensorflow代码python

FCN-TensorFlow完整代码Github:https://github.com/EternityZY/FCN-TensorFlow.git

这里解析所有代码 并加入详细注释

注意事项:

请按照代码中的要求,提前手动下载好VGG-19模型和训练集;程序运行时自动下载的速度很慢。

MODEL_URL =  'http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat'

DATA_URL =  'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'

代码经过修改可以运行在TensorFlow1.4上面

训练模型只需执行python FCN.py

请将学习率调到1e-5甚至更小(下面代码中的默认值为1e-6),否则loss会一直在3左右浮动。

debug标志可以在训练期间设置,以添加关于激活函数,梯度,变量等的信息。

FCN.py

  1. # coding=utf-8
  2. from __future__ import print_function
  3. import tensorflow as tf
  4. import numpy as np
  5. import TensorflowUtils as utils
  6. import read_MITSceneParsingData as scene_parsing
  7. import datetime
  8. import BatchDatsetReader as dataset
  9. from six.moves import xrange
  # Command-line flags and module-level hyper-parameters.
  # Defaults are given in their native types (int / float / bool) instead of
  # strings so they do not depend on the flag parser coercing them.
  FLAGS = tf.flags.FLAGS
  tf.flags.DEFINE_integer("batch_size", 2, "batch size for training")
  tf.flags.DEFINE_string("logs_dir", "logs/", "path to logs directory")
  tf.flags.DEFINE_string("data_dir", "Data_zoo/MIT_SceneParsing/", "path to dataset")
  tf.flags.DEFINE_float("learning_rate", 1e-6, "Learning rate for Adam Optimizer")
  tf.flags.DEFINE_string("model_dir", "Model_zoo/", "Path to vgg model mat")
  tf.flags.DEFINE_bool('debug', True, "Debug mode: True/ False")
  tf.flags.DEFINE_string('mode', "train", "Mode train/ test/ visualize")

  # Location of the pretrained VGG-19 weights (MatConvNet .mat file).
  MODEL_URL = 'http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat'

  MAX_ITERATION = 20000  # number of training iterations
  NUM_OF_CLASSESS = 151  # number of segmentation classes
  IMAGE_SIZE = 224  # images are resized to IMAGE_SIZE x IMAGE_SIZE
  fine_tuning = False  # when True, main() resumes from the latest checkpoint in logs_dir
  def vgg_net(weights, image):
      """
      Build the convolutional part of VGG-19 from pretrained weights.

      :param weights: per-layer weight records taken from the MatConvNet model
          file, indexed in the same order as the layer names below
      :param image: input tensor fed to the first layer
      :return: dict mapping every layer name to that layer's output tensor
      """
      # Layer sequence of the five VGG conv stages (no pool5 here; the caller
      # applies its own pooling after the last stage).
      layer_names = (
          'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',

          'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',

          'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
          'relu3_3', 'conv3_4', 'relu3_4', 'pool3',

          'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
          'relu4_3', 'conv4_4', 'relu4_4', 'pool4',

          'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
          'relu5_3', 'conv5_4', 'relu5_4'
      )

      outputs = {}
      tensor = image
      for index, layer_name in enumerate(layer_names):
          if layer_name.startswith('conv'):
              raw_kernels, raw_bias = weights[index][0][0][0][0]
              # matconvnet: weights are [width, height, in_channels, out_channels]
              # tensorflow: weights are [height, width, in_channels, out_channels]
              kernel_var = utils.get_variable(
                  np.transpose(raw_kernels, (1, 0, 2, 3)), name=layer_name + "_w")
              bias_var = utils.get_variable(raw_bias.reshape(-1), name=layer_name + "_b")
              tensor = utils.conv2d_basic(tensor, kernel_var, bias_var)
          elif layer_name.startswith('relu'):
              tensor = tf.nn.relu(tensor, name=layer_name)
              if FLAGS.debug:
                  # Record an activation summary in debug mode.
                  utils.add_activation_summary(tensor)
          elif layer_name.startswith('pool'):
              # pool1..pool4 use the 2x2 average-pooling helper.
              tensor = utils.avg_pool_2x2(tensor)
          # Keep every intermediate result so callers can tap skip connections.
          outputs[layer_name] = tensor

      return outputs
  def inference(image, keep_prob):
      """
      Semantic segmentation network definition (FCN on top of pretrained VGG-19).

      :param image: input image batch. Should have values in range 0-255
      :param keep_prob: keep probability used by the two dropout layers
      :return: tuple ``(annotation_pred, logits)``; ``annotation_pred`` is the
          per-pixel argmax over the class axis with a trailing axis of size 1,
          ``logits`` is the final transposed-convolution output with
          NUM_OF_CLASSESS channels
      """
      print("setting up vgg initialized conv layers ...")
      # Load (downloading if necessary) the pretrained VGG-19 .mat file.
      model_data = utils.get_model_data(FLAGS.model_dir, MODEL_URL)

      mean = model_data['normalization'][0][0][0]
      mean_pixel = np.mean(mean, axis=(0, 1))  # per-channel mean

      # Drop singleton dimensions so layers can be indexed directly.
      weights = np.squeeze(model_data['layers'])

      # Mean-pixel preprocessing of the input.
      processed_image = utils.process_image(image, mean_pixel)

      with tf.variable_scope("inference"):
          image_net = vgg_net(weights, processed_image)
          conv_final_layer = image_net["conv5_3"]

          # Fifth pooling stage is applied here with max pooling.
          pool5 = utils.max_pool_2x2(conv_final_layer)

          # Layers 6-8: the VGG fully connected layers expressed as convolutions.
          W6 = utils.weight_variable([7, 7, 512, 4096], name="W6")
          b6 = utils.bias_variable([4096], name="b6")
          conv6 = utils.conv2d_basic(pool5, W6, b6)
          relu6 = tf.nn.relu(conv6, name="relu6")
          if FLAGS.debug:
              utils.add_activation_summary(relu6)
          relu_dropout6 = tf.nn.dropout(relu6, keep_prob=keep_prob)

          W7 = utils.weight_variable([1, 1, 4096, 4096], name="W7")
          b7 = utils.bias_variable([4096], name="b7")
          conv7 = utils.conv2d_basic(relu_dropout6, W7, b7)
          relu7 = tf.nn.relu(conv7, name="relu7")
          if FLAGS.debug:
              utils.add_activation_summary(relu7)
          relu_dropout7 = tf.nn.dropout(relu7, keep_prob=keep_prob)

          # 1x1 convolution producing one score map per class.
          W8 = utils.weight_variable([1, 1, 4096, NUM_OF_CLASSESS], name="W8")
          b8 = utils.bias_variable([NUM_OF_CLASSESS], name="b8")
          conv8 = utils.conv2d_basic(relu_dropout7, W8, b8)
          # annotation_pred1 = tf.argmax(conv8, dimension=3, name="prediction1")

          # now to upscale to actual image size
          # First upsampling: conv8 -> shape of pool4, then fuse with pool4.
          # Transposed-conv filter layout: [height, width, out_channels, in_channels].
          # NOTE(review): the helper's default stride is presumably 2 - confirm.
          deconv_shape1 = image_net["pool4"].get_shape()
          W_t1 = utils.weight_variable([4, 4, deconv_shape1[3].value, NUM_OF_CLASSESS], name="W_t1")
          b_t1 = utils.bias_variable([deconv_shape1[3].value], name="b_t1")
          conv_t1 = utils.conv2d_transpose_strided(conv8, W_t1, b_t1, output_shape=tf.shape(image_net["pool4"]))
          fuse_1 = tf.add(conv_t1, image_net["pool4"], name="fuse_1")  # element-wise sum

          # Second upsampling: fuse_1 -> shape of pool3, then fuse with pool3.
          deconv_shape2 = image_net["pool3"].get_shape()
          W_t2 = utils.weight_variable([4, 4, deconv_shape2[3].value, deconv_shape1[3].value], name="W_t2")
          b_t2 = utils.bias_variable([deconv_shape2[3].value], name="b_t2")
          conv_t2 = utils.conv2d_transpose_strided(fuse_1, W_t2, b_t2, output_shape=tf.shape(image_net["pool3"]))
          fuse_2 = tf.add(conv_t2, image_net["pool3"], name="fuse_2")

          # Final upsampling (stride 8, 16x16 kernel) back to the input
          # resolution: output shape is [batch, height, width, NUM_OF_CLASSESS].
          shape = tf.shape(image)
          deconv_shape3 = tf.stack([shape[0], shape[1], shape[2], NUM_OF_CLASSESS])
          W_t3 = utils.weight_variable([16, 16, NUM_OF_CLASSESS, deconv_shape2[3].value], name="W_t3")
          b_t3 = utils.bias_variable([NUM_OF_CLASSESS], name="b_t3")
          conv_t3 = utils.conv2d_transpose_strided(fuse_2, W_t3, b_t3, output_shape=deconv_shape3, stride=8)

          # Each pixel carries NUM_OF_CLASSESS scores; the index of the largest
          # one is the predicted class. `axis` replaces the deprecated
          # `dimension` keyword.
          annotation_pred = tf.argmax(conv_t3, axis=3, name="prediction")

      # Re-add a trailing axis of size 1 (`axis` replaces the deprecated `dim`).
      return tf.expand_dims(annotation_pred, axis=3), conv_t3
  def train(loss_val, var_list):
      """
      Create the Adam training op for the given loss.

      :param loss_val: loss tensor to minimise
      :param var_list: variables to optimise
      :return: op that applies the computed gradients
      """
      adam = tf.train.AdamOptimizer(FLAGS.learning_rate)
      grads_and_vars = adam.compute_gradients(loss_val, var_list=var_list)
      if FLAGS.debug:
          # Record a gradient summary for every (gradient, variable) pair.
          for gradient, variable in grads_and_vars:
              utils.add_gradient_summary(gradient, variable)
      return adam.apply_gradients(grads_and_vars)
  def main(argv=None):
      """
      Build the graph, then train or visualize depending on ``FLAGS.mode``.

      In "train" mode the network is optimised for MAX_ITERATION steps, logging
      a summary every 10 steps and validating + checkpointing every 500 steps.
      In "visualize" mode a random validation batch is predicted and the input,
      ground truth and prediction images are saved to ``FLAGS.logs_dir``.

      :param argv: unused; accepted because ``tf.app.run`` passes it
      """
      import os  # local import: the script's import header does not include os

      # Placeholders: dropout keep probability, input images, label maps.
      keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
      image = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name="input_image")
      annotation = tf.placeholder(tf.int32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 1], name="annotation")

      # pred_annotation: [b, h, w, 1] class indices; logits: [b, h, w, NUM_OF_CLASSESS].
      pred_annotation, logits = inference(image, keep_probability)
      tf.summary.image("input_image", image, max_outputs=2)
      tf.summary.image("ground_truth", tf.cast(annotation, tf.uint8), max_outputs=2)
      tf.summary.image("pred_annotation", tf.cast(pred_annotation, tf.uint8), max_outputs=2)

      # Per-pixel cross entropy between logits and labels, averaged over the
      # batch. `axis` replaces the deprecated `squeeze_dims` keyword.
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits,
          labels=tf.squeeze(annotation, axis=[3]),
          name="entropy"))
      tf.summary.scalar("entropy", loss)

      trainable_var = tf.trainable_variables()
      if FLAGS.debug:
          for var in trainable_var:
              utils.add_to_regularization_and_summary(var)
      train_op = train(loss, trainable_var)

      print("Setting up summary op...")
      summary_op = tf.summary.merge_all()

      print("Setting up image reader...")
      # Each record is {'image': path, 'annotation': path, 'filename': name}.
      train_records, valid_records = scene_parsing.read_dataset(FLAGS.data_dir)
      print(len(train_records))
      print(len(valid_records))

      print("Setting up dataset reader")
      image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
      if FLAGS.mode == 'train':
          train_dataset_reader = dataset.BatchDatset(train_records, image_options)
      validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)

      sess = tf.Session()

      print("Setting up Saver...")
      saver = tf.train.Saver()
      summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

      sess.run(tf.global_variables_initializer())

      # Restore trained weights when resuming training (fine_tuning) and always
      # outside of training: visualising a freshly initialised network would
      # only produce meaningless predictions.
      if fine_tuning or FLAGS.mode != "train":
          ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
          if ckpt and ckpt.model_checkpoint_path:
              saver.restore(sess, ckpt.model_checkpoint_path)
              print("Model restored...")
          else:
              print("No checkpoint found in %s - using freshly initialized weights" % FLAGS.logs_dir)

      if FLAGS.mode == "train":
          # os.path.join keeps the path valid even if logs_dir has no trailing slash.
          checkpoint_prefix = os.path.join(FLAGS.logs_dir, "model.ckpt")
          for itr in range(MAX_ITERATION):
              train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
              feed_dict = {image: train_images, annotation: train_annotations, keep_probability: 0.85}

              sess.run(train_op, feed_dict=feed_dict)

              if itr % 10 == 0:
                  # Report the training loss and write summaries every 10 steps.
                  train_loss, summary_str = sess.run([loss, summary_op], feed_dict=feed_dict)
                  print("Step: %d, Train_loss:%g" % (itr, train_loss))
                  summary_writer.add_summary(summary_str, itr)

              if itr % 500 == 0:
                  # Validate on one batch (dropout disabled) and checkpoint every 500 steps.
                  valid_images, valid_annotations = validation_dataset_reader.next_batch(FLAGS.batch_size)
                  valid_loss = sess.run(loss, feed_dict={image: valid_images, annotation: valid_annotations,
                                                         keep_probability: 1.0})
                  print("%s ---> Validation_loss: %g" % (datetime.datetime.now(), valid_loss))
                  saver.save(sess, checkpoint_prefix, itr)

      elif FLAGS.mode == "visualize":
          valid_images, valid_annotations = validation_dataset_reader.get_random_batch(FLAGS.batch_size)
          pred = sess.run(pred_annotation, feed_dict={image: valid_images, annotation: valid_annotations,
                                                      keep_probability: 1.0})
          # Drop the trailing channel axis so the maps can be saved as images.
          valid_annotations = np.squeeze(valid_annotations, axis=3)
          pred = np.squeeze(pred, axis=3)

          for itr in range(FLAGS.batch_size):
              utils.save_image(valid_images[itr].astype(np.uint8), FLAGS.logs_dir, name="inp_" + str(5+itr))
              utils.save_image(valid_annotations[itr].astype(np.uint8), FLAGS.logs_dir, name="gt_" + str(5+itr))
              utils.save_image(pred[itr].astype(np.uint8), FLAGS.logs_dir, name="pred_" + str(5+itr))
              print("Saved image: %d" % itr)


  if __name__ == "__main__":
      tf.app.run()

read_MITSceneParsingData.py

  1. # coding=utf-8
  2. __author__ = 'charlie'
  3. import numpy as np
  4. import os
  5. import random
  6. from six.moves import cPickle as pickle
  7. from tensorflow.python.platform import gfile
  8. import glob
  9. import TensorflowUtils as utils
  # Previous location: 'http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016.zip'
  DATA_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'


  def read_dataset(data_dir):
      """
      Return the training and validation record lists for the dataset.

      The lists are cached in ``<data_dir>/MITSceneParsing.pickle``. When the
      cache is missing the dataset archive is downloaded/extracted if needed,
      the image folders are scanned and the result is pickled.

      :param data_dir: dataset root, e.g. Data_zoo/MIT_SceneParsing/
      :return: tuple ``(training_records, validation_records)``; each record is
          {'image': path, 'annotation': path, 'filename': name}
      :raises IOError: if the extracted dataset folder cannot be found
      """
      pickle_filename = "MITSceneParsing.pickle"
      pickle_filepath = os.path.join(data_dir, pickle_filename)
      if not os.path.exists(pickle_filepath):
          utils.maybe_download_and_extract(data_dir, DATA_URL, is_zipfile=True)
          # Archive name without extension, i.e. "ADEChallengeData2016".
          scene_parsing_folder = os.path.splitext(DATA_URL.split("/")[-1])[0]
          dataset_dir = os.path.join(data_dir, scene_parsing_folder)
          result = create_image_lists(dataset_dir)
          if result is None:
              # Fail before writing the cache: pickling None would make every
              # later run load a useless cache file and crash.
              raise IOError("Dataset directory '%s' not found; cannot build image lists" % dataset_dir)
          print("Pickling ...")
          with open(pickle_filepath, 'wb') as f:
              pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)
      else:
          print("Found pickle file!")
          # NOTE: pickle.load executes arbitrary code from the file; only use
          # cache files created by this function.
          with open(pickle_filepath, 'rb') as f:
              result = pickle.load(f)

      training_records = result['training']
      validation_records = result['validation']
      del result
      return training_records, validation_records
  def create_image_lists(image_dir):
      """
      Scan the dataset folder and pair every image with its annotation file.

      Looks for ``images/<split>/*.jpg`` and the matching
      ``annotations/<split>/<name>.png`` for the splits 'training' and
      'validation'. Images without an annotation are skipped with a message.
      Each split's record list is shuffled.

      :param image_dir: extracted dataset root, e.g.
          Data_zoo/MIT_SceneParsing/ADEChallengeData2016
      :return: dict {'training': [...], 'validation': [...]} whose items are
          {'image': path, 'annotation': path, 'filename': name}, or None when
          ``image_dir`` does not exist
      """
      # os.path.exists matches the local-filesystem calls (os.path.join, glob)
      # used for everything else in this function.
      if not os.path.exists(image_dir):
          print("Image directory '" + image_dir + "' not found.")
          return None

      directories = ['training', 'validation']
      image_list = {}

      for directory in directories:
          image_list[directory] = []
          file_glob = os.path.join(image_dir, "images", directory, '*.' + 'jpg')
          file_list = glob.glob(file_glob)

          if not file_list:
              print('No files found')
          else:
              for f in file_list:
                  # os.path.basename handles the platform's path separator;
                  # splitting on "/" fails for Windows-style paths.
                  filename = os.path.splitext(os.path.basename(f))[0]
                  annotation_file = os.path.join(image_dir, "annotations", directory, filename + '.png')
                  if os.path.exists(annotation_file):
                      record = {'image': f, 'annotation': annotation_file, 'filename': filename}
                      image_list[directory].append(record)
                  else:
                      print("Annotation file not found for %s - Skipping" % filename)

          random.shuffle(image_list[directory])
          no_of_images = len(image_list[directory])
          print('No. of %s files: %d' % (directory, no_of_images))

      return image_list

TensorflowUtils.py

  1. # coding=utf-8
  2. __author__ = 'Charlie'
  3. # Utils used with tensorflow implemetation
  4. import tensorflow as tf
  5. import numpy as np
  6. import scipy.misc as misc
  7. import os, sys
  8. from six.moves import urllib
  9. import tarfile
  10. import zipfile
  11. import scipy.io
  def get_model_data(dir_path, model_url):
      """
      Load the pretrained model .mat file, downloading it first if needed.

      :param dir_path: directory holding the model file, e.g. Model_zoo/
      :param model_url: download URL; its last path component is the file name
      :return: the contents of the .mat file as returned by scipy.io.loadmat
      :raises IOError: if the model file is still missing after the download step
      """
      # Creates dir_path and downloads the file when it is not present yet.
      maybe_download_and_extract(dir_path, model_url)
      mat_path = os.path.join(dir_path, model_url.split("/")[-1])
      if os.path.exists(mat_path):
          return scipy.io.loadmat(mat_path)
      raise IOError("VGG Model not found!")
  def maybe_download_and_extract(dir_path, url_name, is_tarfile=False, is_zipfile=False):
      """
      Download ``url_name`` into ``dir_path`` unless the file already exists.

      The directory is created when missing. After a fresh download the file is
      optionally extracted into ``dir_path``. Nothing is downloaded or
      extracted when the target file is already present.

      :param dir_path: destination directory
      :param url_name: URL to fetch; its last path component is the file name
      :param is_tarfile: extract the download as a gzip-compressed tar archive
      :param is_zipfile: extract the download as a zip archive (only checked
          when ``is_tarfile`` is false)
      """
      if not os.path.exists(dir_path):
          os.makedirs(dir_path)
      filename = url_name.split('/')[-1]
      filepath = os.path.join(dir_path, filename)
      if os.path.exists(filepath):
          return

      def _progress(count, block_size, total_size):
          # total_size is not positive when the server does not report a
          # length; avoid dividing by it in that case.
          if total_size > 0:
              percent = min(100.0, float(count * block_size) / float(total_size) * 100.0)
              sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
          else:
              sys.stdout.write('\r>> Downloading %s %d bytes' % (filename, count * block_size))
          sys.stdout.flush()

      # Download to a temporary name and rename on success, so an interrupted
      # download never leaves a truncated file that later runs would accept.
      partial_path = filepath + '.part'
      try:
          urllib.request.urlretrieve(url_name, partial_path, reporthook=_progress)
          os.rename(partial_path, filepath)
      finally:
          if os.path.exists(partial_path):
              os.remove(partial_path)
      print()
      statinfo = os.stat(filepath)
      print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')

      # NOTE(review): extractall trusts member paths inside the archive; only
      # use this with archives from trusted sources.
      if is_tarfile:
          with tarfile.open(filepath, 'r:gz') as archive:
              archive.extractall(dir_path)
      elif is_zipfile:
          with zipfile.ZipFile(filepath) as zf:
              zf.extractall(dir_path)

BatchDatsetReader.py

  1. # coding=utf-8
  2. """
  3. Code ideas from https://github.com/Newmu/dcgan and tensorflow mnist dataset reader
  4. """
  5. import numpy as np
  6. import scipy.misc as misc
  class BatchDatset:
      """
      In-memory reader that loads every image/annotation pair up front and
      serves them in sequential (reshuffled per epoch) or random batches.
      """

      # Class-level defaults; __init__ / _read_images rebind these per instance.
      files = []
      images = []
      annotations = []
      image_options = {}
      batch_offset = 0
      epochs_completed = 0

      def __init__(self, records_list, image_options=None):
          """
          Intialize a generic file reader with batching for list of files

          :param records_list: list of file records to read -
              sample record: {'image': f, 'annotation': annotation_file, 'filename': filename}
          :param image_options: A dictionary of options for modifying the output image
              Available options:
              resize = True/ False
              resize_size = #size of output image - nearest-neighbour resize
          """
          print("Initializing Batch Dataset Reader...")
          # A fresh dict per call avoids sharing one mutable default between readers.
          if image_options is None:
              image_options = {}
          print(image_options)
          self.files = records_list
          self.image_options = image_options
          self._read_images()

      @staticmethod
      def _gray_to_rgb(image):
          """Replicate a single-channel (h, w) image into shape (h, w, 3)."""
          return np.stack([image] * 3, axis=-1)

      def _read_images(self):
          """Load all images and annotations listed in ``self.files`` into arrays."""
          # Images: grayscale inputs are expanded to three channels.
          self.__channels = True
          self.images = np.array([self._transform(filename['image']) for filename in self.files])
          # Annotations: kept single-channel, then given a trailing axis of size 1.
          # axis=-1 is valid for 2-D label maps (axis=3 is out of range there).
          self.__channels = False
          self.annotations = np.array(
              [np.expand_dims(self._transform(filename['annotation']), axis=-1) for filename in self.files])
          print(self.images.shape)
          print(self.annotations.shape)

      def _transform(self, filename):
          """
          Read one image file and apply the configured resize.

          :param filename: path of the image to read
          :return: the (optionally resized) image as a numpy array
          """
          # NOTE(review): scipy.misc.imread / imresize were deprecated in SciPy
          # 1.0 and removed in later releases; this needs an old SciPy (with
          # Pillow) or a port to another image library.
          image = misc.imread(filename)
          if self.__channels and len(image.shape) < 3:  # make sure images are of shape(h,w,3)
              image = self._gray_to_rgb(image)

          if self.image_options.get("resize", False):
              resize_size = int(self.image_options["resize_size"])
              # Nearest-neighbour interpolation: never invents new pixel values,
              # which matters for label maps.
              resize_image = misc.imresize(image,
                                           [resize_size, resize_size], interp='nearest')
          else:
              resize_image = image

          return np.array(resize_image)

      def get_records(self):
          """
          Return all loaded data.

          :return: tuple ``(images, annotations)`` of numpy arrays
          """
          return self.images, self.annotations

      def reset_batch_offset(self, offset=0):
          """
          Set the position from which ``next_batch`` continues.

          :param offset: index of the first sample of the next batch
          """
          self.batch_offset = offset

      def next_batch(self, batch_size):
          """
          Return the next sequential batch, reshuffling when an epoch ends.

          :param batch_size: number of samples to return
          :return: tuple ``(images, annotations)`` for the batch
          """
          start = self.batch_offset
          self.batch_offset += batch_size
          if self.batch_offset > self.images.shape[0]:
              # Not enough samples left: finish the epoch, reshuffle images and
              # annotations with the same permutation and restart from 0.
              self.epochs_completed += 1
              print("****************** Epochs completed: " + str(self.epochs_completed) + "******************")
              perm = np.arange(self.images.shape[0])
              np.random.shuffle(perm)
              self.images = self.images[perm]
              self.annotations = self.annotations[perm]
              start = 0
              self.batch_offset = batch_size

          end = self.batch_offset
          return self.images[start:end], self.annotations[start:end]

      def get_random_batch(self, batch_size):
          """
          Return ``batch_size`` samples drawn at random (with replacement).

          :param batch_size: number of samples to draw
          :return: tuple ``(images, annotations)`` for the drawn indices
          """
          indexes = np.random.randint(0, self.images.shape[0], size=[batch_size]).tolist()
          return self.images[indexes], self.annotations[indexes]


原文:https://blog.csdn.net/qq_16761599/article/details/80069824 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/72329
推荐阅读
相关标签
  

闽ICP备14008679号