当前位置:   article > 正文

理解fasterRCNN模型的构成,并进行训练和预测_faster-rcnn训练模型

faster-rcnn训练模型

学习目标

  • 了解VOC数据集的应用
  • 理解fasterRCNN模型的构成
  • 能够利用fasterRCNN网络模型进行训练和预测

1. VOC数据集简介

Pascal VOC数据集作为基准数据,在目标检测中常被使用到,很多优秀的计算机视觉模型比如分类,定位,检测,分割,动作识别等模型都是基于PASCAL VOC挑战赛及其数据集上推出的,尤其是一些目标检测模型(比如RCNN系列,以及后面要介绍的YOLO,SSD等)。

1.1 数据情况

常用的版本有2007和2012两个,在这里我们使用VOC2007作为案例实现的数据,该数据集总共有四大类,20个小类,如下图所示:

  • 从2007年开始,PASCAL VOC每年的数据集都是这个层级结构
  • 总共四个大类:vehicle,household,animal,person
  • 总共20个小类,预测的时候是只输出图中黑色粗体的类别

组织结构如下图所示:

  • Annotations 进行 detection 任务时的标签文件,xml 形式,文件名与图片名一一对应
  • ImageSets 包含三个子文件夹 Layout、Main、Segmentation,其中 Main 存放的是分类和检测的数据集分割文件
  • JPEGImages 存放 .jpg 格式的图片文件
  • SegmentationClass 存放按照 class 分割的图片
  • SegmentationObject 存放按照 object 分割的图片

我们使用的就是Annotations和JPEGImages两部分内容,另外我们通过Main文件夹下的文本文件获取对应的训练集及验证集数据,内容如下所示:

  • train.txt 写着用于训练的图片名称, 共 2501 个
  • val.txt 写着用于验证的图片名称,共 2510 个
  • trainval.txt train与val的合集。共 5011 个

1.2 标注信息

数据集的标注有专门的标注团队,并遵从统一的标注标准,标注信息是用 xml 文件组织,如下图所示:

标注信息如下所示:

  1. <annotation>
  2. <!--数据集版本位置-->
  3. <folder>VOC2007</folder>
  4. <!--文件名-->
  5. <filename>000001.jpg</filename>
  6. <!--文件来源-->
  7. <source>
  8. <database>The VOC2007 Database</database>
  9. <annotation>PASCAL VOC2007</annotation>
  10. <image>flickr</image>
  11. <flickrid>341012865</flickrid>
  12. </source>
  13. <!--拥有者-->
  14. <owner>
  15. <flickrid>Fried Camels</flickrid>
  16. <name>Jinky the Fruit Bat</name>
  17. </owner>
  18. <!--图片大小-->
  19. <size>
  20. <width>353</width>
  21. <height>500</height>
  22. <depth>3</depth>
  23. </size>
  24. <!--是否分割-->
  25. <segmented>0</segmented>
  26. <!--一个目标,里面的内容是目标的相关信息-->
  27. <object>
  28. <!--object名称,20个类别-->
  29. <name>dog</name>
  30. <!--拍摄角度:front, rear, left, right。。-->
  31. <pose>Left</pose>
  32. <!--目标是否被截断,或者被遮挡(超过15%)-->
  33. <truncated>1</truncated>
  34. <!--检测难易程度-->
  35. <difficult>0</difficult>
  36. <!--bounding box 的左上角点和右下角点的坐标值-->
  37. <bndbox>
  38. <xmin>48</xmin>
  39. <ymin>240</ymin>
  40. <xmax>195</xmax>
  41. <ymax>371</ymax>
  42. </bndbox>
  43. </object>
  44. <!--一个目标,里面的内容是目标的相关信息-->
  45. <object>
  46. <name>person</name>
  47. <pose>Left</pose>
  48. <!--目标是否被截断,或者被遮挡(超过15%)-->
  49. <truncated>1</truncated>
  50. <difficult>0</difficult>
  51. <bndbox>
  52. <xmin>8</xmin>
  53. <ymin>12</ymin>
  54. <xmax>352</xmax>
  55. <ymax>498</ymax>
  56. </bndbox>
  57. </object>
  58. </annotation>

2 数据集解析

该数据集的解析在fasterRCNN/detection/datasets/pascal_voc.py中:

接下来我们分析整个的实现过程:

2.1 指定数据集

根据指定的数据集,获取对应的文件信息,进行处理,其中main中txt中的内容如下所示:

因此我们根据txt中的内容加载对应的训练和验证集:

  1. def load_labels(self):
  2. # 根据标签信息加载相应的数据
  3. if self.phase == 'train':
  4. txtname = os.path.join(
  5. self.data_path, 'ImageSets', 'Main', 'trainval.txt')
  6. else:
  7. txtname = os.path.join(
  8. self.data_path, 'ImageSets', 'Main', 'val.txt')
  9. # 获取图像的索引
  10. with open(txtname, 'r') as f:
  11. self.image_index = [x.strip() for x in f.readlines()]
  12. self.num_image = len(self.image_index)
  13. # 图像对应的索引放到列表gt_labels中
  14. gt_labels = []
  15. # 遍历每一份图像获取标注信息
  16. for index in self.image_index:
  17. # 获取标注信息,包括objet box坐标信息 以及类别信息
  18. gt_label = self.load_pascal_annotation(index)
  19. # 添加到列表中
  20. gt_labels.append(gt_label)
  21. # 将标注信息赋值给属性:self.gt_labels
  22. self.gt_labels = gt_labels

2.2图像读取

利用OpenCV读取图像数据,并进行通道的转换:

  1. def image_read(self, imname):
  2. # opencv 中默认图片色彩格式为BGR
  3. image = cv2.imread(imname)
  4. # 将图片转成RGB格式
  5. image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB).astype(np.float32)
  6. return image

2.3 标准信息的读取

标注信息的读取主要是根据图像的文件名获取索引后,找到对应的XML文件,读取其中的内容,得到图像的标注信息。

  1. def load_pascal_annotation(self, index):
  2. """
  3. 在PASCAL VOC的XML文件获取边框信息和类别信息
  4. """
  5. # 获取XML文件的地址
  6. filename = os.path.join(self.data_path, 'Annotations', index + '.xml')
  7. # 将XML中的内容获取出来
  8. tree = ET.parse(filename)
  9. # 获取节点图像的size
  10. image_size = tree.find('size')
  11. # 将图像的size信息存放到sizeinfo中
  12. size_info = np.zeros((2,), dtype=np.float32)
  13. # 宽
  14. size_info[0] = float(image_size.find('width').text)
  15. # 高
  16. size_info[1] = float(image_size.find('height').text)
  17. # 找到所有的object节点
  18. objs = tree.findall('object')
  19. # object的数量
  20. num_objs = len(objs)
  21. # boxes 坐标 (num_objs,4)
  22. boxes = np.zeros((num_objs, 4), dtype=np.float32)
  23. # class 的数量num_objs个,每个目标一个类别
  24. gt_classes = np.zeros((num_objs), dtype=np.int32)
  25. # 遍历所有的目标
  26. for ix, obj in enumerate(objs):
  27. # 找到bndbox节点
  28. bbox = obj.find('bndbox')
  29. # 获取坐标框的位置信息
  30. x1 = float(bbox.find('xmin').text) - 1
  31. y1 = float(bbox.find('ymin').text) - 1
  32. x2 = float(bbox.find('xmax').text) - 1
  33. y2 = float(bbox.find('ymax').text) - 1
  34. # 将位置信息存储在bbox中,注意boxes是一个np类的矩阵 大小为[num_objs,4]
  35. boxes[ix, :] = [y1, x1, y2, x2]
  36. # 找到class对应的类别信息
  37. cls = self.class_to_ind[obj.find('name').text.lower().strip()]
  38. # 将class信息存入gt_classses中,注意gt_classes也是一个np类的矩阵 大小为[num_objs] 是int值 对应于name
  39. gt_classes[ix] = cls
  40. # 获取图像的存储路径
  41. imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg')
  42. # 返回结果
  43. return {'boxes': boxes, 'gt_classs': gt_classes, 'imname': imname, 'image_size': size_info,
  44. 'image_index': index}

2.4 图像的大小处理

在将图像送入到网络中时,我们需要将其进行大小的调整,在这里我们为了保证长宽比,使最长边resize为1024,短边进行pad:

  1. def prep_im_for_blob(self, im, pixel_means, target_size, max_size):
  2. "对输入的图像进行处理"
  3. im = im.astype(np.float32, copy=False)
  4. # 减去均值
  5. im -= pixel_means
  6. # 图像的大小
  7. im_shape = im.shape
  8. # 最短边
  9. im_size_min = np.min(im_shape[0:2])
  10. # 最长边
  11. im_size_max = np.max(im_shape[0:2])
  12. # 短边变换到800的比例
  13. im_scale = float(target_size) / float(im_size_min) # 600/最短边
  14. # 若长边以上述比例变换后大于1024,则修正变换比例
  15. if np.round(im_scale * im_size_max) > max_size:
  16. im_scale = float(max_size) / float(im_size_max)
  17. # 根据变换比例对图像进行resize
  18. im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
  19. shape = (1024, 1024, im.shape[-1])
  20. pad = np.zeros(shape, dtype=im.dtype)
  21. pad[:im.shape[0], :im.shape[1], ...] = im
  22. # 返回im 和 im_scale
  23. return pad, im_scale, im.shape

2.5 构建数据读取的类

在这里我们使用tf.keras.utils.Sequence来完成数据的读取,继承Sequence类必须重载三个私有方法__init__、len__和__getitem,主要是__getitem__。

  • __init__是构造方法,用于初始化数据的。
  • __len__用于计算样本数据长度。
  • __getitem__用于生成整个batch的数据,送入神经网络模型进行训练,其输出格式是元组。__getitem__相当于生成器的作用。

Sequence是进行多处理的更安全方法。这种结构保证了网络在每个时间段的每个样本上只会训练一次。

例如:

  1. class CIFAR10Sequence(Sequence):
  2. # 定义一个类继承自Sequence
  3. # _init_方法进行初始化数据,指定相应的属性即可
  4. def __init__(self, x_set, y_set, batch_size):
  5. # 数据集
  6. self.x, self.y = x_set, y_set
  7. # batch的大小
  8. self.batch_size = batch_size
  9. # 定义一个epoch中的迭代次数
  10. def __len__(self):
  11. return math.ceil(len(self.x) / self.batch_size)
  12. # 获取一个批次数据
  13. def __getitem__(self, idx):
  14. # 获取一个批次的特征值数据
  15. batch_x = self.x[idx * self.batch_size:(idx + 1) *
  16. self.batch_size]
  17. # 获取一个批次的目标值数据
  18. batch_y = self.y[idx * self.batch_size:(idx + 1) *
  19. self.batch_size]
  20. # 返回结果
  21. return np.array([
  22. resize(imread(file_name))
  23. for file_name in batch_x]), np.array(batch_y)

那在VOC数据集的读取中我们也类似的进行处理:

  1. class pascal_voc(keras.utils.Sequence):
  2. def __init__(self, phase):
  3. # pascal_voc 2007数据的存储路径
  4. self.data_path = os.path.join('../VOCdevkit', 'VOC2007')
  5. # batch_size
  6. self.batch_size = 1
  7. # 图片的最小尺寸
  8. self.target_size = 800
  9. # 图片的最大尺寸
  10. self.max_size = 1024
  11. # 输入网络中的图像尺寸
  12. self.scale = (1024, 1024)
  13. # 类别信息 ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus'....]
  14. self.classes = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse','motorbike', 'person', 'pottedplant', 'sheep', 'sofa','train', 'tvmonitor']
  15. # 构建目标类别的字典{'background': 0, 'aeroplane': 1, "bicycle": 2....}
  16. self.class_to_ind = dict(zip(self.classes, range(len(self.classes))))
  17. # 像素RGB的均值
  18. self.pixel_means = np.array([[[122.7717, 115.9465, 102.9801]]])
  19. # 用来指明获取训练集或者是验证集数据
  20. self.phase = phase
  21. # 获取图像数量,并加载相应的标签
  22. self.load_labels()
  23. # 目标总数量
  24. self.num_gtlabels = len(self.gt_labels)
  25. self.img_transform = transforms.ImageTransform(self.scale, self.pixel_means, [1.,1.,1.], 'fixed')
  26. self.bbox_transform = transforms.BboxTransform()
  27. self.flip_ratio=0.5
  28. def __len__(self):
  29. # 返回迭代次数
  30. return np.round(self.num_image/self.batch_size)
  31. def __getitem__(self, idx):
  32. # 获取当前batch的起始索引值
  33. i = idx * self.batch_size
  34. batch_images = []
  35. batch_imgmeta= []
  36. batch_box = []
  37. bacth_labels = []
  38. for c in range(self.batch_size):
  39. # 获取相应的图像
  40. imname = self.gt_labels[i+c]['imname']
  41. # 读取图像
  42. image = self.image_read(imname)
  43. # 获取原始图像的尺寸
  44. ori_shape = image.shape
  45. # 进行尺度调整后的图像及调整的尺度
  46. image, image_scale,image_shape= self.prep_im_for_blob(image, self.pixel_means, self.target_size, self.max_size)
  47. # 获取尺度变换后图像的尺寸
  48. pad_shape = image.shape
  49. # 将gt_boxlabel与scale相乘获取图像调整后的标注框的大小:boxes.shape=[num_obj,4]
  50. bboxes = self.gt_labels[i+c]['boxes'] * image_scale
  51. # 获取对应的类别信息
  52. labels = self.gt_labels[i+c]['gt_classs']
  53. # print(labels)
  54. # 图像的基本信息
  55. img_meta_dict = dict({
  56. 'ori_shape': ori_shape,
  57. 'img_shape': image_shape,
  58. 'pad_shape': pad_shape,
  59. 'scale_factor': image_scale
  60. })
  61. # 将字典转换为列表的形式
  62. image_meta = self.compose_image_meta(img_meta_dict)
  63. # print(image_meta)
  64. batch_images.append(image)
  65. bacth_labels.append(labels)
  66. batch_imgmeta.append(image_meta)
  67. batch_box.append(bboxes)
  68. # 将图像转换成tensorflow输入的形式:【batch_size,H,W,C】
  69. batch_images = np.reshape(batch_images, (self.batch_size, image.shape[0], image.shape[1], 3))
  70. # 图像元信息
  71. batch_imgmeta = np.reshape(batch_imgmeta,(self.batch_size,11))
  72. # 标注框信息
  73. batch_box = np.reshape(batch_box,(self.batch_size,bboxes.shape[0],4 ))
  74. # 标注类别信息
  75. bacth_labels = np.reshape(bacth_labels,((self.batch_size,labels.shape[0])))
  76. # 返回结果:尺度变换后的图像,图像元信息,目标框位置,目标类别
  77. return batch_images,batch_imgmeta, batch_box, bacth_labels

2.6 数据解析类演示

我们利用上述的数据解析方法来对VOC数据集进行解析:

  • 导入所需的工具包
  1. # 导入数据集 VOC data
  2. from detection.datasets import pascal_voc
  3. from detection.datasets.utils import get_original_image
  4. import numpy as np
  5. # 图像展示
  6. from matplotlib import pyplot as plt
  • 获取图像并设置图像的均值与方差
  1. # 实例化
  2. pascal = pascal_voc.pascal_voc('train')
  3. # 获取图像
  4. image, image_meta, bboxes, labels = pascal[8]
  5. # 图像的均值和标准差
  6. img_mean = (122.7717, 115.9465, 102.9801)
  7. img_std = (1., 1., 1.)
  • 原图像展示
  1. # 获取原图像
  2. ori_img = get_original_image(image[0], image_meta[0], img_mean).astype(np.uint8)
  3. # 图像展示
  4. plt.imshow(ori_img)
  5. plt.show()

  • 送入网络中的图像进行了resize和pasding
  1. # 送入网络中的图像
  2. rgb_img = np.round(image + img_mean).astype(np.uint8)
  3. plt.imshow(rgb_img[0])
  4. plt.show()

  • 将标注信息显示出来
  1. # 显示图像,及对应的标签值
  2. import visualize
  3. visualize.display_instances(rgb_img[0], bboxes[0], labels[0], pascal.classes)

3.模型训练

接下来我们利用已搭建好的网络和数据进行模型训练,在这里我们使用:

  • 定义tf.GradientTape的作用域,计算损失值
  • 使用 tape.gradient(ys, xs) 自动计算梯度
  • 使用 optimizer.apply_gradients(grads_and_vars) 自动更新模型参数

完成网络的训练。我们来看下实现流程:

  1. # 导入工具包
  2. from detection.datasets import pascal_voc
  3. import tensorflow as tf
  4. import numpy as np
  5. from matplotlib import pyplot as plt
  6. from detection.models.detectors import faster_rcnn
  • 加载数据获取数据类别
  1. # 加载数据集
  2. train_dataset = pascal_voc.pascal_voc('train')
  3. # 数据类别
  4. num_classes = len(train_dataset.classes)
  • 模型加载
model = faster_rcnn.FasterRCNN(num_classes=num_classes)
  • 定义优化器
  1. # 优化器
  2. optimizer = tf.keras.optimizers.SGD(1e-3, momentum=0.9, nesterov=True)
  • 模型训练
  1. # 模型优化
  2. loss_his = []
  3. for epoch in range(10):
  4. # 获取样本的index
  5. indices = np.arange(train_dataset.num_gtlabels)
  6. # 打乱
  7. np.random.shuffle(indices)
  8. # 迭代次数
  9. iter = np.round(train_dataset.num_gtlabels /
  10. train_dataset.batch_size).astype(np.uint8)
  11. # 每一次迭代
  12. for idx in range(iter):
  13. # 获取某一个bacth
  14. idx = indices[idx]
  15. # 获取当前batch的结果
  16. batch_imgs, batch_metas, batch_bboxes, batch_labels = train_dataset[idx]
  17. # 定义作用域
  18. with tf.GradientTape() as tape:
  19. # 将数据送入网络中计算损失
  20. rpn_class_loss, rpn_bbox_loss, rcnn_class_loss, rcnn_bbox_loss = \
  21. model((batch_imgs, batch_metas, batch_bboxes,
  22. batch_labels), training=True)
  23. # 求总损失
  24. loss = rpn_class_loss + rpn_bbox_loss + rcnn_class_loss + rcnn_bbox_loss
  25. # 计算梯度值
  26. grads = tape.gradient(loss, model.trainable_variables)
  27. # 更新参数值
  28. optimizer.apply_gradients(zip(grads, model.trainable_variables))
  29. # 打印损失结果
  30. print("epoch:%d, loss:%f" % (epoch + 1, loss))
  31. loss_his.append(loss)
  32. # 每一次迭代中只运行一个图像
  33. continue
  34. # 每一个epoch中只运行一次迭代
  35. continue

结果为:

  1. epoch:1, loss:147.117371
  2. epoch:2, loss:72.580498
  3. epoch:3, loss:79.347351
  4. epoch:4, loss:41.220577
  5. epoch:5, loss:5.238140
  6. epoch:6, loss:2.924250
  7. epoch:7, loss:5.287500

损失函数的变换如下图所示:

  1. # 绘制损失函数变化的曲线
  2. plt.plot(range(len(loss_his)),[loss.numpy() for loss in loss_his])
  3. plt.grid()

4.模型测试

在这里我们首先加载模型,我们来看下RPN网络和fastRCNN网络的输出。

导入工具包

  1. # 导入数据集 VOC data
  2. from detection.datasets import pascal_voc
  3. import numpy as np
  4. # 图像展示
  5. from matplotlib import pyplot as plt
  6. # 显示图像,及对应的标签值
  7. import visualize
  8. # 模型加载
  9. from detection.models.detectors import faster_rcnn
  10. import tensorflow as tf
  11. from detection.core.bbox import transforms
  12. from detection.datasets.utils import get_original_image

4.1 数据和模型加载

首先加载要进行预测的数据和训练好的模型:

  • 数据集加载
  1. # 数据集加载
  2. # 实例化
  3. pascal = pascal_voc.pascal_voc('train')
  4. # 获取图像
  5. image, image_meta, bboxes, labels = pascal[8]
  6. # 图像的均值和标准差
  7. img_mean = (122.7717, 115.9465, 102.9801)
  8. img_std = (1., 1., 1.)
  9. # 获取图像,显示
  10. rgb_img = np.round(image + img_mean).astype(np.uint8)
  11. plt.imshow(rgb_img[0])
  12. plt.show()

  • 模型加载
  1. # 加载模型:模型训练是在COCO数据集中进行的,
  2. # coco数据集中的类别信息
  3. classes = ['bg', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
  4. # 模型加载
  5. model = faster_rcnn.FasterRCNN(num_classes=len(classes))
  6. # 将数据送入到网络中
  7. _ = model((image, image_meta, bboxes, labels), training=True)
  8. # 加载已训练好的权重
  9. model.load_weights('weights/faster_rcnn.h5')

接下来,我们来看下RPN网络和fastRCNN的输出。

4.2 RPN网络

4.2.1 RPN的目标值

  • 获取图像的anchor,并匹配目标值
  1. # 根据图像信息产生anchor
  2. anchors, valid_flags = model.rpn_head.generator.generate_pyramid_anchors(image_meta)
  3. # 并设置anchor对应的目标值
  4. rpn_target_matchs, rpn_target_deltas = model.rpn_head.anchor_target.build_targets(
  5. anchors, valid_flags, bboxes, labels)
  • 获取正负样本,及正样本的回归值
  1. # 获取正样本
  2. positive_anchors = tf.gather(anchors, tf.where(tf.equal(rpn_target_matchs, 1))[:, 1])
  3. # 获取负样本
  4. negative_anchors = tf.gather(anchors, tf.where(tf.equal(rpn_target_matchs, -1))[:, 1])
  5. # 获取非正非负样本
  6. neutral_anchors = tf.gather(anchors, tf.where(tf.equal(rpn_target_matchs, 0))[:, 1])
  7. # 获取正样本的回归值
  8. positive_target_deltas = rpn_target_deltas[0, :tf.where(tf.equal(rpn_target_matchs, 1)).shape[0]]
  9. # 获取anchor修正的目标值
  10. refined_anchors = transforms.delta2bbox(
  11. positive_anchors, positive_target_deltas, (0., 0., 0., 0.), (0.1, 0.1, 0.2, 0.2))
  • 正负样本的结果
  1. # 正负样本的个数
  2. print('rpn_target_matchs:\t', rpn_target_matchs[0].shape.as_list())
  3. print('rpn_target_deltas:\t', rpn_target_deltas[0].shape.as_list())
  4. print('positive_anchors:\t', positive_anchors.shape.as_list())
  5. print('negative_anchors:\t', negative_anchors.shape.as_list())
  6. print('neutral_anchors:\t', neutral_anchors.shape.as_list())
  7. print('refined_anchors:\t', refined_anchors.shape.as_list()
  8. rpn_target_matchs: [261888]
  9. rpn_target_deltas: [256, 4]
  10. positive_anchors: [4, 4]
  11. negative_anchors: [252, 4]
  12. neutral_anchors: [261632, 4]
  13. refined_anchors: [4, 4]
  • 将正样本绘制在图像上
  1. # 将正样本的anchor显示在图像上
  2. visualize.draw_boxes(rgb_img[0],
  3. boxes=positive_anchors.numpy(),
  4. refined_boxes=refined_anchors.numpy())
  5. plt.show()

4.2.2 RPN的预测结果

  • 将图像送入网络中获取预测结果
  1. # 不可训练
  2. training = False
  3. # 获取backbone提取的特征结果
  4. C2, C3, C4, C5 = model.backbone(image,
  5. training=training)
  6. # 获取fcn的特征结果
  7. P2, P3, P4, P5, P6 = model.neck([C2, C3, C4, C5],
  8. training=training)
  • 获取送入到RPN网络中的特征图,并送入RPN网络中
  1. # 获取特征图
  2. rpn_feature_maps = [P2, P3, P4, P5, P6]
  3. # 获取RPN的预测结果:分类和回归结果
  4. rpn_class_logits, rpn_probs, rpn_deltas = model.rpn_head(
  5. rpn_feature_maps, training=training)
  • 将置信度较高的anchor显示在图像上
  1. # [batch_size, num_anchors, (bg prob, fg prob)] rpn的分类结果,在这里我们去第一个batch,所有anchor前景的概率
  2. rpn_probs_tmp = rpn_probs[0, :, 1]
  3. # 将置信度top100的绘制在图像上
  4. limit = 100
  5. ix = tf.nn.top_k(rpn_probs_tmp, k=limit).indices[::-1]
  6. # 绘制在图像上
  7. visualize.draw_boxes(rgb_img[0], boxes=tf.gather(anchors, ix).numpy())

结果如下所示:

4.3 fastRCNN网络

将RPN的结果送入到后续网络中,进行检测:

  • 获取候选区域
  1. # 候选区域
  2. proposals_list = model.rpn_head.get_proposals(
  3. rpn_probs, rpn_deltas, image_meta)
  • 进行ROIPooling
  1. rois_list = proposals_list
  2. # roipooling
  3. pooled_regions_list = model.roi_align(
  4. (rois_list, rcnn_feature_maps, image_meta), training=training)
  • 预测
  1. # 进行预测
  2. rcnn_class_logits_list, rcnn_probs_list, rcnn_deltas_list = \
  3. model.bbox_head(pooled_regions_list, training=training)
  • 获取预测结果
  1. # 获取预测结果
  2. detections_list = model.bbox_head.get_bboxes(
  3. rcnn_probs_list, rcnn_deltas_list, rois_list, image_meta)
  • 获取预测结果的坐标,并绘制在图像上
  1. # 获得坐标值
  2. tmp = detections_list[0][:, :4]
  3. # 将检测检测的框绘制在图像上
  4. visualize.draw_boxes(rgb_img[0], boxes=tmp.numpy())

4.4 目标检测

上述我们是分步进行预测,我们也可以直接在原图像上进行预测:

  1. # 获取原图像
  2. ori_img = get_original_image(image[0], image_meta[0], img_mean)
  3. # 获取候选区域
  4. proposals = model.simple_test_rpn(image[0], image_meta[0])
  5. # 检测结果
  6. res = model.simple_test_bboxes(image[0], image_meta[0], proposals)
  7. # 将检测结果绘制在图像上
  8. visualize.display_instances(ori_img, res['rois'], res['class_ids'],
  9. classes, scores=res['scores'])

最终的检测结果为:


总结

  • 了解VOC数据集的应用理解

Pascal VOC数据集作为基准数据,在目标检测中常被使用到

  • fasterRCNN模型的构成

主要有RPN网络进行候选区域的的生成,然后使用fastRCNN网络进行预测

  • 能够利用fasterRCNN网络模型进行训练和预测
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/113978
推荐阅读
相关标签
  

闽ICP备14008679号