要求created和error只能有一个为True :param path: :param created: 当文件夹不存_data_utils">
当前位置:   article > 正文

R-CNN原理详解与代码超详细讲解(三)--data_utils代码讲解

data_utils

R-CNN原理详解与代码超详细讲解(三)–data_utils代码讲解

在这里插入图片描述

check_directory代码详解

def check_directory(path, created=True, error=False):
    """
    检查文件或者文件夹路径path是否存在,如果不存在,根据参数created和error进行操作<br/>
    要求created和error只能有一个为True
    :param path:
    :param created:  当文件夹不存在的时候,进行创建
    :param error:  当path不存在的时候,报错
    :return:
    """
    flag = os.path.exists(path)
    if not flag:
        if created:
            os.makedirs(path)
            flag = True
        elif error:
            raise Exception("Path must exists!!{}".format(path))
    return flag
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

resize_image代码详解

def resize_image(in_image, new_width, new_height, out_image=None, resize_mode=cv.INTER_CUBIC):
    """
    进行图像大小重置操作
    :param in_image:  输入的图像
    :param new_width:  新的宽度
    :param new_height:  新的高度
    :param out_image:  输出对象位置路径
    :param resize_mode:  大小重置方式
    :return:
    """
    image = cv.resize(in_image, (new_width, new_height), resize_mode)
    if out_image:
        cv.imwrite(out_image, image)
    return image
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

iou代码详解

def iou(box1, box2):
    """
    计算边框1和边框2的IoU的值
    :param box1:  边框1的坐标值[左上角坐标的x,左上角坐标的y,右小角坐标的x,右小角坐标的y]
    :param box2:  边框2的坐标值[左上角坐标的x,左上角坐标的y,右小角坐标的x,右小角坐标的y]
    :return:
    """
    # 1. 提取边框信息并排序
    x = [(box1[0], 1), (box1[2], 1), (box2[0], 2), (box2[2], 2)]
    y = [(box1[1], 1), (box1[3], 1), (box2[1], 2), (box2[3], 2)]
    x = sorted(x, key=lambda t: t[0])
    y = sorted(y, key=lambda t: t[0])

    # 2. 计算重叠区域(查看排序之后,边框的顺序是否打乱,也就是第一个值和第二值是否来自一个边框,如果来自一个边框,那么表示没有重叠)
    if x[0][1] == x[1][1] or y[0][1] == y[1][1]:
        union_area = 0.0
    else:
        union_area = (x[2][0] - x[1][0]) * (y[2][0] - y[1][0])

    # 3. 计算IoU的值
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    total_area = box1_area + box2_area - union_area

    # 4. 返回IoU的值
    return 1.0 * union_area / total_area
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

make_training_data代码详解

def make_training_data(in_file, output_data_file, output_label_file, image_width=227, image_height=227):
    """
    构造最原始的数据集,也就是从训练原始图像数据中提取ROI的候选框区域以及相关的属性值
    :param in_file: 训练用原始数据存在的txt文件
    :param output_data_file: 提取出来的特征属性+目标属性存在的路径
    :param output_label_file:  提取出来的class name和id之间映射关系的数据存在的路径
    :param image_width:  ROI区域图形的最终宽度
    :param image_height:  ROI区域图形的最终宽度
    :return:
    """
    # 0. 检查输入的文件是否存在、输出文件所在的文件夹是否存在
    check_directory(in_file, created=False, error=True)
    check_directory(os.path.dirname(output_data_file)) #output_data_file:存放最后每张图片的候选框与真实框的信息datas路径
    check_directory(os.path.dirname(output_label_file))

    # 1. 获取得到fine tune数据对应的根目录文件夹
    root_dir = os.path.dirname(os.path.abspath(in_file))
    # 2. 遍历fine tune数据对应的文件内容
    class_name_2_index_dict = {}
    current_class_index = 1
    '''训练数据格式: 2flowers/jpg/0/image_0561.jpg 2 90,126,350,434'''
    with open(in_file, 'r', encoding='utf-8') as reader:
        datas = [] #记录每张图片的候选框和自身的一个GT框的信息
        for line in reader:
            # a. 分割数据
            values = line.strip().split(" ")

            # b. 异常数据过滤
            if len(values) != 3:
                continue

            # c. 正常数据做处理
            # 1. 解析参数(得到路径、类别id、坐标信息)
            # a. 获取得到原始图形的路径
            image_file_path = os.path.join(root_dir, values[0])
            # b. 获取得到类别id
            class_name = values[1].strip() #class_name是实际物体类别
            try:
                image_label = class_name_2_index_dict[class_name] #image_label是训练时候的物体类别
            except KeyError:
                image_label = current_class_index
                class_name_2_index_dict[class_name] = image_label
                current_class_index += 1
            # c. 获取坐标信息(左上角坐标、右小角坐标、宽度、高度、中心点坐标)
            l_x, l_y, gw, gh = list(map(int, values[2].split(",")))
            r_x = l_x + gw
            r_y = l_y + gh
            gx = (l_x + r_x) // 2
            gy = (l_y + r_y) // 2


            # 2. 读取图像形成Image对象
            image = cv.imread(image_file_path)


            # 3. 获取Ground Truth的边框区域图形
            ground_truth_image = image[l_y:r_y, l_x:r_x]

            # 4. 通过Selective Search获取ROI候选区域边框
            _, regions = selective_search(image, scale=500, sigma=0.9, min_size=10)

            # 5. 遍历所有的ROI区域,计算各个区域和GT的IoU值等信息
            candidate = set()
            for idx, region in enumerate(regions):
                # a. 提取当前候选框对应的特征信息
                rect = region['rect']
                size = region['size']
                # 提取坐标信息(左上角坐标、右小角坐标、宽度、高度、中心点坐标)
                lr_x, lr_y, pw, ph = rect
                rr_x = lr_x + pw #候选框的右下角坐标
                rr_y = lr_y + ph #候选框的右下角坐标
                px = (lr_x + rr_x) // 2 #候选框的中心点
                py = (lr_y + rr_y) // 2 #候选框的中心点

                # b. 做过滤操作
                if rect in candidate:
                    continue
                if size < 200:
                    continue
                if pw * ph < 500:
                    continue

                # c. 记录当前区域经过处理的候选框
                candidate.add(rect)

                # d. 获取当前ROI区域候选框对应的特征信息
                # 1. 获取区域图像建议
                region_proposal = image[lr_y:rr_y, lr_x:rr_x]

                # 2. 计算RoI区域和GT区域的IoU的值
                region_iou = iou(
                    box1=[l_x, l_y, r_x, r_y],
                    box2=[lr_x, lr_y, rr_x, rr_y]
                )

                # 3. 对区域建议图形进行resize大小重置
                region_proposal = resize_image(region_proposal,
                                               new_width=image_width,
                                               new_height=image_height)

                # 3. 计算偏移量信息(真实框与候选框的偏移)
                tx = (gx - px) / pw
                ty = (gy - py) / ph
                tw = np.log(gw / pw)
                th = np.log(gh / ph)
                offset_box = [tx, ty, tw, th]

                # c. 添加特征信息的数据
                # 需要的数据格式: 区域图像image、区域图形所属物体类别(原始图像的所属类别)、候选框类型(0真实框、1区域建议候选框)、IoU、offset_box(偏移量的值)
                data = []
                # 添加图形数据
                data.append(region_proposal)
                # 添加图形所属类别
                data.append(image_label)
                # 添加候选框的类型
                data.append(1)
                # 添加IoU
                data.append(region_iou)
                # 添加回归训练用的目标属性: offset box
                data.append(offset_box)
                datas.append(data)

            # 6. 添加真实框的信息
            # 需要的数据格式: 区域图像image、区域图形所属物体类别(原始图像的所属类别)、候选框类型(0真实框、1区域建议候选框)、IoU、offset_box(偏移量的值)
            data = []
            # 添加图形数据(GT框大小不能保证是227*227,所以要记得resize)
            ground_truth_image = resize_image(ground_truth_image,
                                              new_width=image_width,
                                              new_height=image_height)
            data.append(ground_truth_image)
            # 添加图形所属类别
            data.append(image_label)
            # 添加候选框的类型
            data.append(0)
            # 添加IoU
            data.append(1.0)
            # 添加回归训练用的目标属性: offset box
            data.append([0, 0, 0, 0])
            datas.append(data)

    # 3. 数据持久化操作
    np.save(output_data_file, datas)
    with open(output_label_file, 'wb') as writer:
        pickle.dump(class_name_2_index_dict, writer)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144

FlowerDataLoader代码详解

class FlowerDataLoader(object):
    '''数据生成器'''
    def __init__(self, one_hot=True):
        # 设置相关属性
        self.image_width = cfg.IMAGE_WIDTH #图片宽
        self.image_height = cfg.IMAGE_HEIGHT #图片高

        self.fine_tune_positive_batch_size = cfg.FINE_TUNE_POSITIVE_BATCH_SIZE #微调训练的时候正样本的批次大小,默认为8.但是在文献中是24
        self.fine_tune_negative_batch_size = cfg.FINE_TUNE_NEGATIVE_BATCH_SIZE #微调训练的时候负样本的批次大小,默认为24,但是在文献中是96
        self.fine_tune_iou_threshold = cfg.FINE_TUNE_IOU_THRESHOLD #微调训练的时候正负样本区别的iou的大小

        self.higher_features_iou_threshold = 0.3 #svm训练时候正负样本区别的iou大小

        # Fine Tune训练相关变量
        fine_tune_X = []  #存放微调x数据
        fine_tune_Y = []  #存放微调y数据
        total_fine_tune_samples = 0  #记录微调数据数量
        fine_tune_positive_samples = 0 #记录正样本微调数据数量
        fine_tune_negative_samples = 0 #记录负样本微调数据数量
        fine_tune_positive_samples_index = [] #存放正样本数据的下标
        fine_tune_negative_samples_index = [] #存放负样本数据的下标

        higher_features_Y = [] #存放svm分类器数据的标签
        higher_features_label_2_samples_index = defaultdict(list)
        higher_features_label_2_negative_samples = 0 #记录svm训练时候负样本数量
        higher_features_label_2_positive_samples = 0 #记录svm训练时候正样本数量

        # 一、加载Fine Tune微调训练用的X、Y的值形成的NumPy的对象
        # 1. 提取训练数据对应的Selective Search的边框值(检查SS的结果磁盘文件是否存在,如果不存在,就构造生成SS结果数据,如果存在,就直接加载)
        print("Start load training data.....")
        if not check_directory(cfg.TRAIN_DATA_FILE_PATH, False, False):
            '''如果training_data文件不存在,那么就进行make_training_data'''
            print("Training data file not exists, So load traning data and save to file.....")
            make_training_data(in_file=cfg.ORIGINAL_FINE_TUNE_DATA_FILE_PATH,
                               output_data_file=cfg.TRAIN_DATA_FILE_PATH,
                               output_label_file=cfg.TRAIN_LABEL_DICT_FILE_PATH,
                               image_width=self.image_width, image_height=self.image_height)
        datas = np.load(cfg.TRAIN_DATA_FILE_PATH,allow_pickle=True)

        # 2. 遍历边框值,得到训练可用的X和Y的值
        for idx, (image, label, box_type, region_iou, box) in enumerate(datas):
            # 加载图像数据
            fine_tune_X.append(image)
            total_fine_tune_samples += 1

            # 加载fine tune的标签相关数据
            if region_iou > self.fine_tune_iou_threshold:#找出所有iou大于阈值的框作为微调的正样本
                # 正样本
                fine_tune_Y.append(label)
                fine_tune_positive_samples_index.append(idx)
                fine_tune_positive_samples += 1
            else: #找出所有iou大于阈值的框作为微调的负样本
                # 负样本
                fine_tune_Y.append(0)
                fine_tune_negative_samples_index.append(idx)
                fine_tune_negative_samples += 1

            # 加载svm高阶特征训练相关的数据
            if region_iou < self.higher_features_iou_threshold:#找出所有iou大于阈值的框作为svm的正样本
                # 负例样本: IoU < 0.3
                higher_features_label_2_negative_samples += 1
                higher_features_Y.append(0)
                higher_features_label_2_samples_index[label].append(idx)
            else:#找出所有iou大于阈值的框作为svm的负样本
                higher_features_Y.append(label)
                if int(box_type) == 0:
                    higher_features_label_2_positive_samples += 1
                    higher_features_label_2_samples_index[label].append(idx)
        print("Complete load training data!!!! Total samples:{}".format(total_fine_tune_samples))
        print("Fine tune positive example:{}, negative example:{}".format(fine_tune_positive_samples,
                                                                          fine_tune_negative_samples))
        print("Higher Features positive sample:{}, negative example:{}".format(higher_features_label_2_positive_samples,
                                                                               higher_features_label_2_negative_samples))
        print('*'*30)
        print('higher_features_label_2_samples_index:',higher_features_label_2_samples_index)

        # 进行Fine Tune相关数据赋值操作
        self.fine_tune_x = np.asarray(fine_tune_X)
        if one_hot:
            # 对需要做哑编码操作的进行对应操作
            one_hot_encoder = OneHotEncoder(sparse=False, categories='auto')
            self.fine_tune_y = np.asarray(one_hot_encoder.fit_transform(np.reshape(fine_tune_Y, (-1, 1))))
            pass
        else:
            self.fine_tune_y = np.asarray(fine_tune_Y).reshape((-1, 1))
        self.total_fine_tune_samples = total_fine_tune_samples
        self.fine_tune_positive_samples = fine_tune_positive_samples
        self.fine_tune_negative_samples = fine_tune_negative_samples
        self.fine_tune_positive_cursor = 0 #创建正样本计数游标
        self.fine_tune_negative_cursor = 0 #创建负样本计数游标
        self.fine_tune_positive_samples_index = np.asarray(fine_tune_positive_samples_index)
        self.fine_tune_negative_samples_index = np.asarray(fine_tune_negative_samples_index)
        np.random.shuffle(self.fine_tune_positive_samples_index)
        np.random.shuffle(self.fine_tune_negative_samples_index)
        print('fine_tune_positive_samples_index:',fine_tune_positive_samples_index)
        print('fine_tune_negative_samples_index:',fine_tune_negative_samples_index)
        print('higher_features_Y',higher_features_Y)

        # 进行svm高阶特征获取的相关数据给定
        self.higher_features_y = np.asarray(higher_features_Y)
        self.higher_features_label_2_samples_index = higher_features_label_2_samples_index

    def __fetch_batch(self, batch_size, cursor, total_samples, x, y, index):
        """
        基于给定的数据获取当前批次的数据(X\Y)以及下一个批次获取前是否需要进行数据的重置操作
        :param batch_size:
        :param cursor:
        :param total_samples:
        :param x:
        :param y:
        :param index:
        :return:
        """
        need_reset_data = False
        # 1. 计算当前开始的下标、结束的下标
        start_idx = cursor * batch_size
        end_idx = start_idx + batch_size

        # 2. 如果已经数据的尾巴的位置的话,那么需要在下一次获取数据之前,重置数据
        if end_idx >= total_samples:
            need_reset_data = True

        # 3. 获取样本下标
        sample_index = index[start_idx:end_idx]

        # 4. 基于下标获取对象的样本
        images = x[sample_index]
        labels = y[sample_index]

        # 5. 返回结果
        return images, labels, need_reset_data

    def get_fine_tune_batch(self):
        """
        按照给定的属性获取正样本和负样本,并合并返回
        :return:
        """
        # 一、获取正样本
        positive_images, positive_labels, flag = self.__fetch_batch(
            batch_size=self.fine_tune_positive_batch_size,
            cursor=self.fine_tune_positive_cursor,
            total_samples=self.fine_tune_positive_samples,
            x=self.fine_tune_x,
            y=self.fine_tune_y,
            index=self.fine_tune_positive_samples_index)

        if flag:
            print("Reset fine tune positive samples!!!")
            self.fine_tune_positive_cursor = 0
            np.random.shuffle(self.fine_tune_positive_samples_index)
        else:
            self.fine_tune_positive_cursor += 1

        # 二、获取负样本
        negative_images, negative_labels, flag = self.__fetch_batch(
            batch_size=self.fine_tune_negative_batch_size,
            cursor=self.fine_tune_negative_cursor,
            total_samples=self.fine_tune_negative_samples,
            x=self.fine_tune_x,
            y=self.fine_tune_y,
            index=self.fine_tune_negative_samples_index)
        if flag:
            print("Reset fine tune negative samples!!!")
            self.fine_tune_negative_cursor = 0
            np.random.shuffle(self.fine_tune_negative_samples_index)
        else:
            self.fine_tune_negative_cursor += 1

        # 三、数据合并
        images = np.concatenate([positive_images, negative_images], axis=0)
        labels = np.concatenate([positive_labels, negative_labels], axis=0)

        return images, labels

    def get_structure_higher_features(self, label):
        """
        基于给定的标签获取训练svm用的原始数据
        :param label:
        :return:
        """
        if label in self.higher_features_label_2_samples_index:
            index = self.higher_features_label_2_samples_index[label]
            # 加下面这行代码的原因是,代码有bug
            index = index[:10]
            return self.fine_tune_x[index], self.higher_features_y[index]
        else:
            return None, None
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/337062
推荐阅读
相关标签
  

闽ICP备14008679号