
Segmentation Mask2Polygon

This snippet runs semantic-segmentation inference on a test set with PaddlePaddle, converts each predicted mask into per-class polygons, and writes the results to pred_seg.json. It consists of two functions: a test-set inference loop and a mask2polygon helper.
import json
import os

import cv2
import numpy as np
import paddle
from tqdm import tqdm

# `logger`, `aug_inference`, and `inference` are assumed to be the usual
# PaddleSeg utilities from the surrounding project; `mask2polygon` is
# defined below.


def seg_inference_on_test_dataset(model,
                                  data_loader,
                                  evaluate,
                                  aug_eval=False,
                                  scales=1.0,
                                  flip_horizontal=False,
                                  flip_vertical=False,
                                  is_slide=False,
                                  stride=None,
                                  crop_size=None,
                                  precision='fp32',
                                  amp_level='O1',
                                  print_detail=True,
                                  auc_roc=False):
    """
    Launch evalution.

    Args:
        model(nn.Layer): A semantic segmentation model.
        eval_dataset (paddle.io.Dataset): Used to read and process validation datasets.
        aug_eval (bool, optional): Whether to use mulit-scales and flip augment for evaluation. Default: False.
        scales (list|float, optional): Scales for augment. It is valid when `aug_eval` is True. Default: 1.0.
        flip_horizontal (bool, optional): Whether to use flip horizontally augment. It is valid when `aug_eval` is True. Default: True.
        flip_vertical (bool, optional): Whether to use flip vertically augment. It is valid when `aug_eval` is True. Default: False.
        is_slide (bool, optional): Whether to evaluate by sliding window. Default: False.
        stride (tuple|list, optional): The stride of sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        crop_size (tuple|list, optional):  The crop size of sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        precision (str, optional): Use AMP if precision='fp16'. If precision='fp32', the evaluation is normal.
        amp_level (str, optional): Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don’t support fp16 kernel and batchnorm. Default is O1(amp)
        num_workers (int, optional): Num workers for data loader. Default: 0.
        print_detail (bool, optional): Whether to print detailed information about the evaluation process. Default: True.
        auc_roc(bool, optional): whether add auc_roc metric

    Returns:
        float: The mIoU of validation datasets.
        float: The accuracy of validation datasets.
    """
    
    if print_detail:
        logger.info(
            "Start evaluating (total_samples: {}, total_iters: {})...".format(
                len(list(data_loader.task_loaders.values())[0].dataset),
                len(data_loader)))

    model.eval()

    pred_res = []
    with paddle.no_grad():
        for data in tqdm(data_loader):
            trans_info = data['segmentation']['trans_info']
            im_id = data['segmentation']['im_id'][0]
            id2path = data['segmentation']['id2path']
            if aug_eval:
                # Multi-scale and/or flipped test-time augmentation.
                pred, _ = aug_inference(
                    model,
                    data,
                    trans_info=trans_info,
                    scales=scales,
                    flip_horizontal=flip_horizontal,
                    flip_vertical=flip_vertical,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            else:
                # Plain single-scale inference (optionally sliding-window).
                pred, _ = inference(
                    model,
                    data,
                    trans_info=trans_info,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)

            # Gather predictions and image ids from all ranks.
            results = []
            results_id = []
            paddle.distributed.all_gather(results, pred)
            paddle.distributed.all_gather(results_id, im_id)
            for k, result in enumerate(results):
                # Convert the predicted mask to per-class polygons, keyed by
                # the original image's file name (with a .png extension).
                res = mask2polygon(
                    result.numpy().squeeze(0).squeeze(0).astype(np.uint8))
                tmp = dict()
                im_id_val = results_id[k].numpy()[0]
                imgname = os.path.splitext(
                    os.path.basename(id2path[0][im_id_val][0]))[0] + '.png'
                tmp[imgname] = res
                pred_res.append(tmp)

    with open("pred_seg.json", 'w') as f:
        f.write(json.dumps({'seg': pred_res}))
    return {}
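For context, a minimal, hypothetical driver might look like the sketch below. `build_model` and `build_test_loader` are placeholders for however the surrounding project constructs these objects, and the distributed environment must be initialized because the function calls paddle.distributed.all_gather.

import paddle.distributed as dist

# Hypothetical driver; build_model / build_test_loader are placeholders.
dist.init_parallel_env()  # required by the all_gather calls above
model = build_model()
test_loader = build_test_loader()

seg_inference_on_test_dataset(
    model,
    test_loader,
    evaluate=False,
    aug_eval=True,                 # enable multi-scale + flip TTA
    scales=[0.75, 1.0, 1.25],
    flip_horizontal=True)
# Polygonized predictions are now in ./pred_seg.json as {"seg": [...]}.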
def mask2polygon(mask_image):
    """
    Convert a class-ID mask into per-class polygons.

    :param mask_image: grayscale mask (np.ndarray) whose pixel values are
        class IDs in [0, 19)
    :return: dict mapping each class ID to a list of polygons, where each
        polygon is a list of labelme-style points
    """
    cls_2_polygon = {}
    for i in range(19):
        # Binarize: 1 where the pixel belongs to class i, 0 elsewhere.
        # (astype() returns a new array, so the result must be assigned.)
        mask = (mask_image == i).astype(np.uint8)

        # Two-value return signature of cv2.findContours (OpenCV 4.x).
        contours, _ = cv2.findContours(
            mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        results = [item.squeeze().tolist() for item in contours]
        cls_2_polygon[i] = results

    return cls_2_polygon
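A quick way to exercise mask2polygon on its own is to feed it a saved prediction mask. This sketch assumes a grayscale PNG whose pixel values are class IDs; the file names are illustrative.

import json

import cv2

# Illustrative file names; any grayscale mask whose pixels are class IDs works.
mask = cv2.imread('pred_mask.png', cv2.IMREAD_GRAYSCALE)
polygons = mask2polygon(mask)

# polygons[i] is a list of contours for class i, each a list of [x, y] points.
with open('pred_mask_polygons.json', 'w') as f:
    json.dump(polygons, f)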