赞
踩
代码链接:https://github.com/xiguanlezz/Faster-RCNN
因为Faster-RCNN的loss值是包含两部分的,第一部分是先验框即anchors和对应anchors_target的loss;第二部分是建议框即proposals和对应的proposals_target的loss。
其实就是先剔除掉在图片外面的先验框,然后根据IOU来创建标签,在__call__函数里计算了回归值以及标签信息。
import numpy as np

from utils.util import calculate_iou, get_inside_index, box2loc


class AnchorTargetCreator:
    """Build the RPN training targets (regression offsets and labels) for anchors."""

    def __init__(self, n_sample=256, pos_iou_thresh=0.7, neg_iou_thresh=0.3, pos_ratio=0.5):
        """
        function description: AnchorTargetCreator constructor

        :param n_sample: total number of anchors kept as targets (default 256)
        :param pos_iou_thresh: anchors whose best IoU with a ground-truth box is
                               >= this value are "positive" (label 1)
        :param neg_iou_thresh: anchors whose best IoU with a ground-truth box is
                               < this value are "negative" (label 0)
        :param pos_ratio: maximum fraction of n_sample that may be positive
        """
        self.n_sample = n_sample
        self.pos_iou_thresh = pos_iou_thresh
        self.neg_iou_thresh = neg_iou_thresh
        # If there are fewer positives than pos_ratio * n_sample, the remaining
        # slots are filled with negatives.
        self.pos_ratio = pos_ratio

    def __call__(self, boxes, anchors, img_size):
        """
        function description: compute regression targets and labels for the anchors

        :param boxes: ground-truth boxes of one image, shape [boxes_num, 4]
        :param anchors: all anchors generated from the feature map, shape [anchors_num, 4]
        :param img_size: size of the input image, unpacked as (img_width, img_height);
                         used to discard anchors that cross the image border
        :return:
            anchor_locs: regression targets, shape [anchors_num, 4]
                         (rows of anchors outside the image are 0)
            anchor_labels: labels, shape [anchors_num]
                           (1 positive, 0 negative, -1 ignored / outside the image)
        """
        img_width, img_height = img_size

        # Only anchors lying inside the image take part in the labelling.
        inside_index = get_inside_index(anchors, img_width, img_height)
        inside_anchors = anchors[inside_index]

        # For every inside anchor: index of its best ground-truth box and its label.
        argmax_ious, labels = self._create_label(inside_anchors, boxes)

        # Regression target of every inside anchor w.r.t. its best ground-truth box.
        locs = box2loc(inside_anchors, boxes[argmax_ious])

        # Scatter the results back to the full anchor set so they line up with
        # the RPN outputs when the loss is computed.
        anchors_num = len(anchors)
        anchor_labels = np.full((anchors_num,), -1, dtype=labels.dtype)
        anchor_labels[inside_index] = labels

        anchor_locs = np.zeros((anchors_num,) + locs.shape[1:], dtype=locs.dtype)
        anchor_locs[inside_index, :] = locs
        return anchor_locs, anchor_labels

    def _create_label(self, inside_anchors, boxes):
        """
        function description: create a label for every inside anchor
            (1 = positive, 0 = negative, -1 = ignored). Rules:
            (a) for each ground-truth box, the anchor(s) with the highest IoU are positive;
            (b) anchors whose best IoU is >= pos_iou_thresh are positive;
            (c) anchors whose best IoU is < neg_iou_thresh are negative.
            Afterwards positives/negatives are randomly sub-sampled to n_sample in total.

        :param inside_anchors: anchors inside the image, shape [inside_anchors_num, 4]
        :param boxes: ground-truth boxes, shape [boxes_num, 4]
        :return:
            argmax_ious: index of the best ground-truth box per anchor, shape [inside_anchors_num]
            label: label per anchor, shape [inside_anchors_num]
        """
        # Everything starts as "ignored".
        label = np.full((len(inside_anchors),), -1, dtype=np.int32)

        argmax_ious, max_ious, gt_argmax_ious = self._calculate_iou(inside_anchors, boxes)

        # Rule (c) must be applied FIRST so that the positive rules below can
        # override it. Applying it last (as before) turned the best anchor of a
        # ground-truth box back into a negative whenever that best IoU was below
        # neg_iou_thresh, so small objects could end up with no positive anchor.
        label[max_ious < self.neg_iou_thresh] = 0
        # Rule (a): every ground-truth box gets at least one positive anchor.
        label[gt_argmax_ious] = 1
        # Rule (b)
        label[max_ious >= self.pos_iou_thresh] = 1

        # Balance the samples: at most pos_ratio * n_sample positives ...
        pos_standard = int(self.pos_ratio * self.n_sample)
        pos_index = np.where(label == 1)[0]
        if len(pos_index) > pos_standard:
            # replace=False: never pick the same index twice
            disable_index = np.random.choice(
                pos_index, size=(len(pos_index) - pos_standard), replace=False)
            label[disable_index] = -1

        # ... and the rest of the n_sample budget is filled with negatives.
        neg_standard = self.n_sample - np.sum(label == 1)
        neg_index = np.where(label == 0)[0]
        if len(neg_index) > neg_standard:
            disable_index = np.random.choice(
                neg_index, size=(len(neg_index) - neg_standard), replace=False)
            label[disable_index] = -1

        return argmax_ious, label

    def _calculate_iou(self, inside_anchors, boxes):
        """
        function description: from the IoU matrix, get the best ground-truth box of
            every anchor and the best anchor(s) of every ground-truth box

        :param inside_anchors: anchors inside the image
        :param boxes: ground-truth boxes
        :return:
            argmax_ious: per anchor, index of the ground-truth box with the highest IoU,
                         shape [inside_anchors_num]
            max_ious: per anchor, that highest IoU, shape [inside_anchors_num]
            gt_argmax_ious: indices of the anchors that reach the highest IoU of some
                            ground-truth box; variable length (ties are all kept)
        """
        # Rows are anchors, columns are ground-truth boxes.
        ious = calculate_iou(inside_anchors, boxes)

        argmax_ious = ious.argmax(axis=1)
        max_ious = ious[np.arange(len(ious)), argmax_ious]

        # Highest IoU reached for each ground-truth box, shape [boxes_num].
        gt_max_ious = ious.max(axis=0)
        # Keep every anchor tying for that maximum. The `> 0` guard stops a
        # ground-truth box that overlaps no anchor at all (max IoU == 0) from
        # matching every anchor in its column.
        gt_argmax_ious = np.where((ious == gt_max_ious) & (ious > 0))[0]

        return argmax_ious, max_ious, gt_argmax_ious
这个部分主要逻辑就是保持正负样本的均衡性,在__call__函数里计算了回归值并将rois打上相应的标签信息。
import numpy as np

from utils.util import calculate_iou, box2loc


class ProposalTargetCreator:
    """Sample a balanced set of RoIs and build their class / regression targets."""

    def __init__(self, n_sample=128, pos_ratio=0.25, pos_iou_thresh=0.5,
                 neg_iou_thresh_hi=0.5, neg_iou_thresh_lo=0.0):
        """
        function description: configure the RoI sampler that feeds the Fast R-CNN head

        :param n_sample: number of RoIs to sample
        :param pos_ratio: fraction of n_sample that may be positive
        :param pos_iou_thresh: RoIs with best IoU >= this value are positive candidates
        :param neg_iou_thresh_hi: upper (exclusive) IoU bound for negative candidates
        :param neg_iou_thresh_lo: lower (inclusive) IoU bound for negative candidates
        """
        self.n_sample = n_sample
        self.pos_ratio = pos_ratio
        self.pos_iou_thresh = pos_iou_thresh
        self.neg_iou_thresh_hi = neg_iou_thresh_hi
        self.neg_iou_thresh_lo = neg_iou_thresh_lo

    def __call__(self, rois, boxes, labels,
                 loc_normalize_mean=(0., 0., 0., 0.),
                 loc_normalize_std=(0.1, 0.1, 0.2, 0.2)):
        """
        function description: sample RoIs and return them with their labels and
            regression targets

        :param rois: RoIs proposed by the RPN
        :param boxes: ground-truth boxes of one image
        :param labels: ground-truth class labels of one image
        :param loc_normalize_mean: accepted for API compatibility; NOT used here
        :param loc_normalize_std: accepted for API compatibility; NOT used here
        :return:
            sample_rois: the sampled RoIs
            gt_roi_labels: class label per sampled RoI (sampled negatives are set to 0)
            gt_roi_locs: regression target of each sampled RoI w.r.t. its best box
        """
        # Unpacking doubles as a check that `boxes` is two-dimensional.
        _num_boxes, _ = boxes.shape

        # Upper bound on the number of positives (rounded).
        max_positives = np.round(self.n_sample * self.pos_ratio)

        overlap = calculate_iou(rois, boxes)
        best_box = overlap.argmax(axis=1)   # one entry per RoI
        best_iou = overlap.max(axis=1)
        # NOTE(review): labels are used as-is; this presumes class ids already
        # reserve 0 for background — confirm against the dataset loader.
        roi_labels = labels[best_box]

        # Positive candidates: IoU reaches the positive threshold.
        fg_candidates = np.where(best_iou >= self.pos_iou_thresh)[0]
        n_fg = int(min(max_positives, fg_candidates.size))
        if fg_candidates.size > 0:
            fg_candidates = np.random.choice(fg_candidates, size=n_fg, replace=False)

        # Negative candidates: IoU in [neg_iou_thresh_lo, neg_iou_thresh_hi).
        in_negative_band = (best_iou < self.neg_iou_thresh_hi) & \
                           (best_iou >= self.neg_iou_thresh_lo)
        bg_candidates = np.where(in_negative_band)[0]
        bg_budget = self.n_sample - n_fg
        n_bg = int(min(bg_candidates.size, bg_budget))
        if bg_candidates.size > 0:
            bg_candidates = np.random.choice(bg_candidates, size=n_bg, replace=False)

        # Positives first, negatives after: everything from n_fg on is background.
        chosen = np.append(fg_candidates, bg_candidates)
        sampled_labels = roi_labels[chosen]
        sampled_labels[n_fg:] = 0

        sampled_rois = rois[chosen]
        sampled_locs = box2loc(sampled_rois, boxes[best_box[chosen]])

        return sampled_rois, sampled_labels, sampled_locs
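先来看一下论文中对总loss公式的定义:

$$L(\{p_i\},\{t_i\}) = \frac{1}{N_{cls}}\sum_i L_{cls}(p_i, p_i^*) + \lambda\,\frac{1}{N_{reg}}\sum_i p_i^*\,L_{reg}(t_i, t_i^*)$$

其中 $p_i$ 是第 $i$ 个anchor被预测为前景的概率,$p_i^*$ 是对应的标签(正样本为1,负样本为0),$t_i$ 和 $t_i^*$ 分别是预测的回归值和标注的回归值,$L_{reg}$ 使用smooth L1损失。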
至于代码中的实现相当于是加了不同的权重,总的loss值主要是第二部分loss中的,而且只计算label中为正样本的loss值(因为label为0表示背景,会忽略背景的loss计算)。
def smooth_l1_loss(x, t, in_weight, sigma):
    """
    function description: summed smooth-L1 loss between predictions and targets

    With d = in_weight * (x - t) the per-element loss is
        0.5 * sigma^2 * d^2        if |d| < 1 / sigma^2
        |d| - 0.5 / sigma^2        otherwise

    :param x: predicted localisation values
    :param t: target localisation values
    :param in_weight: mask tensor, 0 where the element must not contribute
    :param sigma: controls where the loss switches from quadratic to linear
    :return: scalar tensor, the sum of all element losses
    """
    sigma2 = sigma ** 2
    diff = in_weight * (x - t)
    abs_diff = diff.abs()
    # The branch selector is a constant w.r.t. autograd, hence detach()
    # (replaces the legacy `.data` access).
    flag = (abs_diff.detach() < (1. / sigma2)).float()
    y = (flag * (sigma2 / 2.) * (diff ** 2) +
         (1 - flag) * (abs_diff - 0.5 / sigma2))
    return y.sum()


def loc_loss(pred_loc, gt_loc, gt_label, sigma):
    """
    function description: localisation loss computed on positive samples only

    :param pred_loc: predicted localisation values, same shape as gt_loc
    :param gt_loc: target localisation values
    :param gt_label: label per row of gt_loc; rows with label > 0 contribute to
                     the loss, rows with label >= 0 count towards the normaliser
    :param sigma: smooth-L1 sigma, see smooth_l1_loss
    :return: scalar tensor
    """
    # Selection mask with the same shape as gt_loc. It is created on gt_loc's
    # own device rather than a hard-coded `.cuda()`, so the loss also works on
    # CPU and on non-default GPUs.
    in_weight = torch.zeros(gt_loc.shape, device=gt_loc.device)
    in_weight[(gt_label > 0).view(-1, 1).expand_as(in_weight)] = 1

    loss = smooth_l1_loss(pred_loc, gt_loc, in_weight.detach(), sigma)

    # Normalise by the number of non-ignored samples. The clamp avoids
    # 0 / 0 = NaN when every label is -1.
    normalizer = (gt_label >= 0).sum().float().clamp(min=1.)
    return loss / normalizer
可以将上篇文章中的网络再看看,这里就是将之前讲过的网络组合起来,并计算一个loss。
from torch import nn
import torch.nn.functional as F
from nets.vgg16 import decom_VGG16
from nets.rpn import RPN
from nets.anchor_target_creator import AnchorTargetCreator
from nets.proposal_target_creator import ProposalTargetCreator
from nets.fast_rcnn import FastRCNN
from utils.util import loc_loss
from collections import namedtuple
import torch
from utils.util import loc2box, non_maximum_suppression
import numpy as np
from configs.config import class_num, device_name

LossTuple = namedtuple('LossTuple',
                       ['rpn_loc_loss',
                        'rpn_cls_loss',
                        'roi_loc_loss',
                        'roi_cls_loss',
                        'total_loss'
                        ])
device = torch.device(device_name)


class FasterRCNN(nn.Module):
    """Feature extractor + RPN + Fast R-CNN head, with loss computation and inference."""

    def __init__(self, path):
        """
        :param path: passed to decom_VGG16 (presumably the pretrained VGG16 weights
                     file — confirm against nets.vgg16)
        """
        super(FasterRCNN, self).__init__()
        self.extractor, classifier = decom_VGG16(path)
        self.rpn = RPN()
        self.anchor_target_creator = AnchorTargetCreator()
        self.sample_rois = ProposalTargetCreator()
        self.fast_rcnn = FastRCNN(n_class=class_num, roi_size=7, spatial_scale=1. / 16, classifier=classifier)

        # sigma coefficients of the smooth-L1 localisation losses
        self.rpn_sigma = 3.
        self.roi_sigma = 1.

    def forward(self, x, gt_boxes, labels):
        """
        function description: run one training step's forward pass and return the losses

        Only the first element of the batch (gt_boxes[0], labels[0], rpn outputs [0])
        is used, i.e. the implementation effectively assumes batch size 1.

        :param x: input image tensor
        :param gt_boxes: ground-truth boxes (tensor), indexed by batch element
        :param labels: ground-truth class labels (tensor), indexed by batch element
        :return: LossTuple(rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss, total_loss)
        """
        # ----------------- part 1: feature extraction -----------------
        h = self.extractor(x)

        # ----------------- part 2: RPN (output_1) ---------------------
        img_size = (x.size(2), x.size(3))
        # rpn_locs / rpn_scores are torch tensors; anchors is a numpy array;
        # rois has shape [n_rois, 4] (per the original author's notes).
        rpn_locs, rpn_scores, anchors, rois = self.rpn(h, img_size)
        # gt_anchor_locs: [anchors_num, 4]; gt_anchor_labels: [anchors_num]
        # label 1 = anchor contains an object, 0 = background, -1 = ignored
        gt_anchor_locs, gt_anchor_labels = self.anchor_target_creator(
            gt_boxes[0].detach().cpu().numpy(), anchors, img_size)

        # ----------------- part 3: RoI sampling -----------------------
        # gt_roi_labels holds the class each sampled RoI belongs to
        sample_rois, gt_roi_labels, gt_roi_locs = self.sample_rois(
            rois,
            gt_boxes[0].detach().cpu().numpy(),
            labels[0].detach().cpu().numpy())

        # ----------------- part 4: Fast R-CNN head (output_2) ---------
        roi_locs, roi_scores = self.fast_rcnn(h, sample_rois)

        # RPN loss
        gt_anchor_locs = torch.from_numpy(gt_anchor_locs).to(device)
        gt_anchor_labels = torch.from_numpy(gt_anchor_labels).long().to(device)
        # Anchors labelled -1 do not take part in the classification loss.
        rpn_cls_loss = F.cross_entropy(rpn_scores[0], gt_anchor_labels, ignore_index=-1)
        rpn_loc_loss = loc_loss(rpn_locs[0], gt_anchor_locs, gt_anchor_labels, self.rpn_sigma)

        # RoI loss
        gt_roi_labels = torch.from_numpy(gt_roi_labels).long().to(device)
        gt_roi_locs = torch.from_numpy(gt_roi_locs).float().to(device)
        roi_cls_loss = F.cross_entropy(roi_scores, gt_roi_labels)

        # Pick, for every sampled RoI, the 4 offsets predicted for its target class.
        n_sample = roi_locs.shape[0]
        roi_cls_locs = roi_locs.view(n_sample, -1, 4)
        roi_locs = roi_cls_locs[torch.arange(0, n_sample).long(), gt_roi_labels]
        roi_loc_loss = loc_loss(roi_locs.contiguous(), gt_roi_locs, gt_roi_labels, self.roi_sigma)

        losses = [rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss]
        losses = losses + [sum(losses)]
        return LossTuple(*losses)

    @torch.no_grad()
    def predict(self, x):
        """
        function description: detect objects in an image

        :param x: input image tensor
        :return: (final_boxes, labels, scores) as numpy arrays, see _suppress
        """
        # Run in eval mode (per the original author this also changes the RPN's
        # n_post_nms to 300), then restore whatever mode the module was in
        # instead of unconditionally forcing train mode.
        was_training = self.training
        self.eval()

        # ----------------- part 1: feature extraction -----------------
        h = self.extractor(x)
        img_size = (x.size(2), x.size(3))

        # ----------------- part 2: RPN --------------------------------
        rpn_locs, rpn_scores, anchors, rois = self.rpn(h, img_size)

        # ----------------- part 3: Fast R-CNN head --------------------
        roi_locs, roi_scores = self.fast_rcnn(h, np.asarray(rois))
        n_sample = roi_locs.shape[0]

        # ----------------- part 4: decode boxes -----------------------
        roi_cls_locs = roi_locs.view(n_sample, -1, 4)
        rois = torch.from_numpy(rois).to(device)
        # Repeat every RoI once per class so it lines up with the per-class offsets.
        rois = rois.view(-1, 1, 4).expand_as(roi_cls_locs)
        boxes = loc2box(rois.cpu().numpy().reshape((-1, 4)), roi_cls_locs.cpu().numpy().reshape((-1, 4)))
        boxes = torch.from_numpy(boxes).to(device)
        # Clip the coordinates so that the boxes stay inside the image.
        boxes[:, [0, 2]] = (boxes[:, [0, 2]]).clamp(min=0, max=img_size[0])
        boxes[:, [1, 3]] = (boxes[:, [1, 3]]).clamp(min=0, max=img_size[1])
        boxes = boxes.view(n_sample, -1)

        # Class scores -> probabilities.
        prob = F.softmax(roi_scores, dim=1)

        # ----------------- part 5: filtering --------------------------
        raw_boxes = boxes.cpu().numpy()
        raw_prob = prob.cpu().numpy()
        final_boxes, labels, scores = self._suppress(raw_boxes, raw_prob)

        self.train(was_training)
        return final_boxes, labels, scores

    def _suppress(self, raw_boxes, raw_prob):
        """
        function description: per-class score thresholding followed by NMS

        :param raw_boxes: decoded boxes, reshaped here to [n_rois, class_num, 4]
        :param raw_prob: class probabilities, indexed as [n_rois, class]
        :return:
            box: kept boxes (float32)
            label: class index of each kept box, shifted by -1 so that the first
                   foreground class is 0 (int32)
            score: probability of each kept box (float32)
        """
        score_thresh = 0.7
        nms_thresh = 0.3
        n_class = class_num

        # Loop-invariant: one row of 4 coordinates per (roi, class).
        boxes_per_class = raw_boxes.reshape((-1, n_class, 4))

        box = list()
        label = list()
        score = list()
        # Class 0 is skipped (background).
        for i in range(1, class_num):
            box_i = boxes_per_class[:, i, :]
            prob_i = raw_prob[:, i]

            mask = prob_i > score_thresh
            box_i = box_i[mask]
            prob_i = prob_i[mask]

            # Sort by score, highest first. The scores must be reordered together
            # with the boxes: `keep` indexes into the array handed to NMS (the
            # sorted one), so indexing the unsorted scores returned scores that
            # belonged to other boxes.
            order = prob_i.argsort()[::-1]
            box_i = box_i[order]
            prob_i = prob_i[order]

            box_i_after_nms, keep = non_maximum_suppression(box_i, nms_thresh)
            box.append(box_i_after_nms)
            label.append(np.full((len(keep),), i - 1, dtype=np.float64))
            score.append(prob_i[keep])

        box = np.concatenate(box, axis=0).astype(np.float32)
        label = np.concatenate(label, axis=0).astype(np.int32)
        score = np.concatenate(score, axis=0).astype(np.float32)
        return box, label, score
写了两种生成txt文件的代码。
第一种基于的前提是数据集给的是txt标注,那可以用下面的函数生成4个txt并生成对应的xml文件。
from lxml import etree as ET
import glob
import os
import cv2
import random
from configs.config import classes_for_label, xml_root_dir, img_root_dir, txt_root_dir, pic_format
import numpy as np
from PIL import Image


def write_xml(filename, saveimg, typename, boxes, xmlpath):
    """
    function description: write one VOC-style XML annotation file

    :param filename: image file name stored in the <filename> node
    :param saveimg: image array as read by OpenCV; its shape supplies
                    height (shape[0]), width (shape[1]) and depth (shape[2])
    :param typename: sequence of class names, one per object
    :param boxes: sequence of (xmin, ymin, xmax, ymax), one per object
    :param xmlpath: path of the XML file to write
    """
    # root node
    root = ET.Element("annotation")

    # folder node
    folder_node = ET.SubElement(root, 'folder')
    folder_node.text = 'kitti'

    # filename node
    filename_node = ET.SubElement(root, 'filename')
    filename_node.text = filename

    # source node
    source_node = ET.SubElement(root, 'source')
    database_node = ET.SubElement(source_node, 'database')
    database_node.text = 'kitti Database'
    annotation_node = ET.SubElement(source_node, 'annotation')
    annotation_node.text = 'kitti'
    image_node = ET.SubElement(source_node, 'image')
    image_node.text = 'flickr'
    flickrid_node = ET.SubElement(source_node, 'flickrid')
    flickrid_node.text = '-1'

    # owner node
    owner_node = ET.SubElement(root, 'owner')
    flickrid_node = ET.SubElement(owner_node, 'flickrid')
    flickrid_node.text = 'muke'
    name_node = ET.SubElement(owner_node, 'name')
    name_node.text = 'muke'

    # size node
    size_node = ET.SubElement(root, 'size')
    width_node = ET.SubElement(size_node, 'width')
    width_node.text = str(saveimg.shape[1])
    height_node = ET.SubElement(size_node, 'height')
    height_node.text = str(saveimg.shape[0])
    depth_node = ET.SubElement(size_node, 'depth')
    depth_node.text = str(saveimg.shape[2])

    # segmented node (only relevant for segmentation)
    segmented_node = ET.SubElement(root, 'segmented')
    segmented_node.text = '0'

    # one object node per annotated box
    for class_name, box in zip(typename, boxes):
        object_node = ET.SubElement(root, 'object')
        name_node = ET.SubElement(object_node, 'name')
        name_node.text = class_name
        pose_node = ET.SubElement(object_node, 'pose')
        pose_node.text = 'Unspecified'
        # truncated flag
        truncated_node = ET.SubElement(object_node, 'truncated')
        truncated_node.text = '1'
        difficult_node = ET.SubElement(object_node, 'difficult')
        difficult_node.text = '0'
        bndbox_node = ET.SubElement(object_node, 'bndbox')
        xmin_node = ET.SubElement(bndbox_node, 'xmin')
        xmin_node.text = str(box[0])
        ymin_node = ET.SubElement(bndbox_node, 'ymin')
        ymin_node.text = str(box[1])
        xmax_node = ET.SubElement(bndbox_node, 'xmax')
        xmax_node.text = str(box[2])
        ymax_node = ET.SubElement(bndbox_node, 'ymax')
        ymax_node.text = str(box[3])

    tree = ET.ElementTree(root)
    tree.write(xmlpath, pretty_print=True)


def split_dataset_byTXT():
    """
    function description: split a dataset annotated with txt files into
        train / val / test id lists and write an XML annotation per image

    Roughly the first 70% of the shuffled samples go to train.txt, the next 20%
    to val.txt (both are also listed in trainval.txt) and the last 10% to
    train_test.txt.

    :raises FileNotFoundError: if an image belonging to an annotation cannot be read
    """
    # NOTE(review): `train_label_path` is not among the names imported from
    # configs.config in this snippet; it must be defined elsewhere, otherwise
    # the next line raises NameError.
    list_anno_files = glob.glob(train_label_path + "*")
    random.shuffle(list_anno_files)
    total = len(list_anno_files)

    # `with` guarantees the four list files are closed even if a sample fails.
    with open(txt_root_dir + 'trainval.txt', 'w') as trainval, \
            open(txt_root_dir + 'train.txt', 'w') as train, \
            open(txt_root_dir + 'val.txt', 'w') as val, \
            open(txt_root_dir + 'train_test.txt', 'w') as test:
        for index, anno_file in enumerate(list_anno_files, start=1):
            boxes = []
            typename = []
            with open(anno_file) as file:
                for anno_item in file:
                    if not anno_item.strip():
                        continue  # tolerate blank lines
                    anno_new_infos = anno_item.split(" ")
                    # drop the "Misc" and "DontCare" categories
                    if anno_new_infos[0] in ("Misc", "DontCare"):
                        continue
                    # fields 4..7 hold left, top, right, bottom
                    box = tuple(int(float(value)) for value in anno_new_infos[4:8])
                    boxes.append(box)
                    typename.append(anno_new_infos[0])

            # os.path works with both "/" and "\\" separators; splitting on "\\"
            # only worked on Windows.
            sample_id = os.path.splitext(os.path.basename(anno_file))[0]
            filename = sample_id + pic_format
            xmlpath = xml_root_dir + sample_id + ".xml"
            imgpath = img_root_dir + 'training/' + filename
            print(imgpath)

            saveimg = cv2.imread(imgpath)
            if saveimg is None:
                # cv2.imread signals failure by returning None
                raise FileNotFoundError("cannot read image: " + imgpath)
            write_xml(filename, saveimg, typename, boxes, xmlpath)

            if index > total * 0.9:
                test.write(sample_id + "\n")
            else:
                trainval.write(sample_id + "\n")
                if index > total * 0.7:
                    val.write(sample_id + "\n")
                else:
                    train.write(sample_id + "\n")
第二种基于的前提是数据集的标注直接是xml文件,那直接根据文件名生成txt文件就OK了。
def split_dataset_byXML():
    """
    function description: split a dataset that already has XML annotations into
        train / val / test id lists

    Every list contains one bare sample id per line (file name without directory
    and without extension), the same format split_dataset_byTXT produces.
    Roughly the first 70% of the shuffled samples go to train.txt, the next 20%
    to val.txt (both are also listed in trainval.txt) and the last 10% to
    train_test.txt.
    """
    # Local import: this function is the only user of os in its module.
    import os

    # Only annotation files; a plain "*" would also pick up unrelated files.
    list_anno_files = glob.glob(xml_root_dir + "*.xml")
    random.shuffle(list_anno_files)
    total = len(list_anno_files)

    # `with` guarantees the four list files are closed even on error.
    with open(txt_root_dir + 'trainval.txt', 'w') as trainval, \
            open(txt_root_dir + 'train.txt', 'w') as train, \
            open(txt_root_dir + 'val.txt', 'w') as val, \
            open(txt_root_dir + 'train_test.txt', 'w') as train_test:
        for index, anno_file in enumerate(list_anno_files, start=1):
            # glob returns the path including xml_root_dir; writing that whole
            # path (as before) broke consumers that prepend their own root
            # directories to the id.
            sample_id = os.path.splitext(os.path.basename(anno_file))[0]

            if index > total * 0.9:
                train_test.write(sample_id + "\n")
            else:
                trainval.write(sample_id + "\n")
                if index > total * 0.7:
                    val.write(sample_id + "\n")
                else:
                    train.write(sample_id + "\n")
因为真正的测试集是没有标注的,所以__getitem__函数对测试集和训练集返回的内容也不一样。为了让测试集和训练集最大程度地复用代码,我在构造函数里面传了一个标记位,用来区分是train还是test。另外Faster-RCNN训练对图片的尺寸是有要求的,最小的边必须超过600px,否则在Roi pooling的时候会出现问题,导致预测不准确,所以我在代码里面还是用了reshape函数。
注意:对于训练集直接reshape那就大错特错了,你需要在缩放图片的同时等比例缩放标注框的位置!下面给的代码都实现了,都是直接将张量和标注拉到内存,所以占用的内存空间会很大。
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from data.process_data import parse_xml, reshape
import numpy as np
from PIL import Image
from configs.config import pic_format


class ImageDataset(Dataset):
    """Detection dataset driven by an id-list txt file (train and test variants)."""

    def __init__(self, xml_root_dir, img_root_dir, txt_root_dir, txt_file, isTest=False, transform=None):
        """
        class description: for the training variant the images are rescaled by
            `reshape` (shortest side 600px per the original author) and the
            annotated boxes are rescaled accordingly

        :param xml_root_dir: root directory of the XML annotation files
        :param img_root_dir: root directory of the images
        :param txt_root_dir: root directory of the id-list txt files
        :param txt_file: name of the id-list txt file
        :param isTest: True for the (unannotated) test set
        :param transform: transform applied to each PIL image; defaults to
                          ToTensor followed by ImageNet mean/std normalisation
        """
        super(ImageDataset, self).__init__()
        self.xml_root_dir = xml_root_dir
        self.img_root_dir = img_root_dir
        self.txt_root_dir = txt_root_dir
        self.txt_file = txt_file
        self.isTest = isTest

        if transform is None:
            # Deliberately no Resize((224, 224)): per the original author,
            # squashing the image to the VGG16 input size was the source of a bug.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    [0.485, 0.456, 0.406],
                    [0.229, 0.224, 0.225])
            ])
        # Previously a user-supplied transform was never stored, so __getitem__
        # failed with AttributeError.
        self.transform = transform

        # NOTE: the training variant loads every image and annotation into
        # memory here, which can take a lot of RAM for large datasets.
        if self.isTest:
            self.images = self.load_txt(self.txt_file)
        else:
            boxes, labels, images = self.load_txt(self.txt_file)
            self.boxes = boxes
            self.labels = labels
            self.images = images

        id_list_files = os.path.join(txt_root_dir, txt_file)
        with open(id_list_files) as id_file:
            # Blank lines are skipped here and in load_txt so that self.ids
            # stays aligned with self.images.
            self.ids = [id_.strip() for id_ in id_file if id_.strip()]

    def load_txt(self, filename):
        """
        function description: read the id list and load what each id refers to

        :param filename: name of the id-list txt file (inside txt_root_dir)
        :return: (boxes, labels, images) for the training variant, where images
                 holds the loaded images; a list of image file names for the
                 test variant
        """
        print('-------------the file name is ', filename)
        boxes = []
        labels = []
        images = []
        print(os.path.join(self.txt_root_dir, filename))
        with open(os.path.join(self.txt_root_dir, filename), mode='r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if self.isTest:
                    image = (line + pic_format)
                else:
                    box, label, image = self.load_xml(line + ".xml")
                    boxes.append(box)
                    labels.append(label)
                images.append(image)

        if self.isTest:
            return images
        print('the length of boxes is ', len(boxes))
        print('the length of labels is ', len(labels))
        print('the length of images is ', len(images))
        return boxes, labels, images

    def load_xml(self, filename):
        """
        function description: load boxes, labels and the rescaled image that
            belong to one XML annotation

        :param filename: XML file name (inside xml_root_dir)
        :return: (boxes as float32 array, labels as int32 array, image), or
                 None when the XML file does not exist
        """
        path = os.path.join(self.xml_root_dir, filename)
        if not os.path.exists(path):
            # Callers that unpack the result fail on a missing annotation.
            return None
        boxes, labels = parse_xml(path)
        img_name = filename.replace(".xml", pic_format)
        # reshape rescales the image and the boxes together.
        images, boxes = reshape(Image.open(self.img_root_dir + img_name), boxes)
        return np.stack(boxes).astype(np.float32), \
               np.stack(labels).astype(np.int32), \
               images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """
        :param index: sample index
        :return: dict with "img_name" and "img_tensor"; the training variant
                 additionally has "img_classes" and "img_gt_boxes"
        """
        if self.isTest:
            img = Image.open(self.img_root_dir + self.images[index])
            img_tensor = self.transform(img)
            # [channel, height, width] -> [channel, width, height]
            img_tensor = img_tensor.permute(0, 2, 1)
            return {
                "img_name": self.images[index],
                "img_tensor": img_tensor,
            }

        sample_id = self.ids[index]
        box, label, image = self.load_xml('{0}.xml'.format(sample_id))
        img_tensor = self.transform(image)
        # [channel, height, width] -> [channel, width, height]
        img_tensor = img_tensor.permute(0, 2, 1)
        return {
            "img_name": sample_id + pic_format,
            "img_tensor": img_tensor,
            "img_classes": label,
            "img_gt_boxes": box
        }
下图是我kitti数据集在我代码上面跑了一个epoch之后进行预测的结果。
效果还是不错的。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。