赞
踩
def seg_inference_on_test_dataset(model,
                                  data_loader,
                                  evaluate,
                                  aug_eval=False,
                                  scales=1.0,
                                  flip_horizontal=False,
                                  flip_vertical=False,
                                  is_slide=False,
                                  stride=None,
                                  crop_size=None,
                                  precision='fp32',
                                  amp_level='O1',
                                  print_detail=True,
                                  auc_roc=False):
    """
    Run segmentation inference on a test loader and dump polygons to JSON.

    Every batch is pushed through `inference` (or `aug_inference` when
    `aug_eval` is True), the predicted label map is converted to per-class
    polygons with `mask2polygon`, and the collected results are written to
    ``pred_seg.json`` in the current working directory as
    ``{"seg": [{"<image stem>.png": {class_id: polygons}}, ...]}``.

    Args:
        model (nn.Layer): A semantic segmentation model. It is switched to
            eval mode before inference.
        data_loader: Iterable of batches. Each batch must expose
            ``data['segmentation']`` with the keys ``trans_info``, ``im_id``
            and ``id2path``. When `print_detail` is True it must also provide
            ``task_loaders`` (a mapping whose first value has a ``dataset``)
            and support ``len()``.
        evaluate: Accepted for interface compatibility; not used here.
        aug_eval (bool, optional): Whether to use multi-scale and flip
            augmentation for inference. Default: False.
        scales (list|float, optional): Scales for augmentation. Only used
            when `aug_eval` is True. Default: 1.0.
        flip_horizontal (bool, optional): Whether to use horizontal flip
            augmentation. Only used when `aug_eval` is True. Default: False.
        flip_vertical (bool, optional): Whether to use vertical flip
            augmentation. Only used when `aug_eval` is True. Default: False.
        is_slide (bool, optional): Whether to infer by sliding window.
            Default: False.
        stride (tuple|list, optional): The stride of the sliding window, the
            first is width and the second is height. It should be provided
            when `is_slide` is True.
        crop_size (tuple|list, optional): The crop size of the sliding
            window, the first is width and the second is height. It should be
            provided when `is_slide` is True.
        precision (str, optional): Accepted for interface compatibility; not
            used here. Default: 'fp32'.
        amp_level (str, optional): Accepted for interface compatibility; not
            used here. Default: 'O1'.
        print_detail (bool, optional): Whether to log the number of samples
            and iterations before starting. Default: True.
        auc_roc (bool, optional): Accepted for interface compatibility; not
            used here. Default: False.

    Returns:
        dict: Always an empty dict; the predictions are only written to
        ``pred_seg.json``.
    """
    if print_detail:
        first_loader = list(data_loader.task_loaders.values())[0]
        logger.info("Start evaluating (total_samples: {}, total_iters: {})...".
                    format(len(first_loader.dataset), len(data_loader)))

    model.eval()

    # all_gather is a collective operation; with a single process the local
    # tensors already are the complete result, so the collective call is
    # skipped (same guard PaddleSeg's own evaluation loop uses).
    is_distributed = paddle.distributed.get_world_size() > 1

    pred_res = []
    with paddle.no_grad():
        for data in tqdm(data_loader):
            seg_batch = data['segmentation']
            trans_info = seg_batch['trans_info']
            im_id = seg_batch['im_id'][0]
            id2path = seg_batch['id2path']

            if aug_eval:
                pred, _ = aug_inference(
                    model,
                    data,
                    trans_info=trans_info,
                    scales=scales,
                    flip_horizontal=flip_horizontal,
                    flip_vertical=flip_vertical,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)
            else:
                pred, _ = inference(
                    model,
                    data,
                    trans_info=trans_info,
                    is_slide=is_slide,
                    stride=stride,
                    crop_size=crop_size)

            if is_distributed:
                results = []
                results_id = []
                paddle.distributed.all_gather(results, pred)
                paddle.distributed.all_gather(results_id, im_id)
            else:
                results = [pred]
                results_id = [im_id]

            for result, gathered_id in zip(results, results_id):
                # The two leading axes are squeezed away before polygon
                # extraction (presumably batch and channel -- TODO confirm).
                label_map = result.numpy().squeeze(0).squeeze(0).astype(
                    np.uint8)
                polygons = mask2polygon(label_map)

                image_id = gathered_id.numpy()[0]
                # NOTE(review): the path is always looked up in this rank's
                # own id2path, even for ids gathered from other ranks --
                # confirm id2path covers the whole dataset.
                source_path = id2path[0][image_id][0]
                imgname = os.path.splitext(
                    os.path.basename(source_path))[0] + '.png'
                pred_res.append({imgname: polygons})

    # NOTE(review): in a multi-process run every rank writes this same file.
    with open("pred_seg.json", 'w') as f:
        json.dump({'seg': pred_res}, f)

    return {}
def mask2polygon(mask_image, num_classes=19):
    """
    Extract contour polygons for every class id found in a label mask.

    Args:
        mask_image (np.ndarray): 2-D array of integer class ids (one id per
            pixel). Anything accepted by ``np.asarray`` works.
        num_classes (int, optional): Class ids ``0 .. num_classes - 1`` are
            scanned. Default: 19.

    Returns:
        dict: Maps each class id to a list of polygons. Every polygon is a
        list of two-element points as produced by ``cv2.findContours``
        (x, y order per the OpenCV documentation). Classes that do not occur
        in the mask map to an empty list.
    """
    labels = np.asarray(mask_image)
    cls_2_polygon = {}
    for cls_id in range(num_classes):
        # Build the binary mask from a comparison. Zeroing the other pixels
        # first and then setting "== cls_id" to 1 marks the whole image as
        # foreground when cls_id is 0.
        binary = (labels == cls_id).astype(np.uint8)

        # findContours returns (contours, hierarchy) on OpenCV 2.x/4.x and
        # (image, contours, hierarchy) on 3.x; contours is always
        # second-to-last.
        contours = cv2.findContours(binary, cv2.RETR_TREE,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]

        # reshape(-1, 2) keeps a single-point contour as [[x, y]], whereas
        # squeeze() would flatten it to [x, y].
        cls_2_polygon[cls_id] = [
            contour.reshape(-1, 2).tolist() for contour in contours
        ]
    return cls_2_polygon
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。