
Deploying PaddleOCR on PC with NCNN


Preface

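My previous article documented how to deploy PaddleOCR on an ARM development board with Paddle-Lite, and at the end I promised a follow-up on deploying it with NCNN. I had some time today, so this article records the NCNN deployment process. For now it only covers deployment on a PC; a separate article on deploying to an ARM board will follow. The process is basically the same on both, and some readers may not be very familiar with NCNN deployment, so starting from the PC side should make it easier to understand. As before, the model deployed here is PaddleOCR's mobile model.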

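Here is the official introduction:
ncnn is a high-performance neural network inference framework heavily optimized for mobile platforms. From the start, its design has been shaped by deployment and use on mobile phones. It has no third-party dependencies, is cross-platform, and runs faster on mobile CPUs than all currently known open-source frameworks. With ncnn, developers can easily port deep learning algorithms to mobile phones, run them efficiently, and build AI apps that put AI at your fingertips. ncnn is currently used in many Tencent applications, such as QQ, Qzone, WeChat, and Pitu (天天P图).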

1. Preparing the NCNN environment

Since we are only deploying on a PC, all you need is an Ubuntu machine. I assume CMake and the usual build tools are already installed. Build NCNN as follows:

git clone https://github.com/Tencent/ncnn.git
cd ncnn
git submodule update --init
mkdir build
cd build
cmake ..
make
make install

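Compilation usually goes smoothly. Afterwards, the build folder contains the following files:
[Figure: contents of the ncnn build folder]
The NCNN environment is now ready.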

2. Model conversion: paddle->onnx->ncnn

2.1 paddle->onnx

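To run inference with ncnn, the Paddle model must first be converted to the ncnn format. Paddle has no direct conversion path to ncnn, so the conversion goes through ONNX. First, download the Paddle models from the PaddleOCR repository:
[Figure: PaddleOCR model list, with the downloaded inference models circled]
I downloaded the inference models circled in the figure. I did not use the newer ch_PP-OCRv2_xx models for a specific reason. When I converted that version's recognition model, the conversion succeeded without any errors, but the recognition results were rather poor, and I have not yet found the cause. I therefore stayed with the older MobileNetV3-based version, which recognizes text correctly after conversion.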

After downloading, use paddle2onnx to convert the Paddle models to ONNX. paddle2onnx is easy to install:

 pip install paddle2onnx

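paddle2onnx has two conversion modes, one for static-graph models and one for dynamic-graph models. These modes refer to the Paddle model itself, because Paddle supports both static and dynamic graphs. For comparison, PyTorch, which most people know, is a dynamic-graph framework, while early TensorFlow was a static-graph framework.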

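PaddleOCR's pretrained models are dynamic-graph weights. The inference models are static-graph weights exported from them. Since we downloaded the inference models, we use the static-graph conversion: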
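Before converting, you can confirm that a directory really holds a static-graph inference model. The sketch below assumes that an exported inference model ships as the pair inference.pdmodel (graph) and inference.pdiparams (weights). The helper names `is_static_inference_model` and `build_paddle2onnx_cmd` are hypothetical and not part of paddle2onnx. The flags are the same ones used in the commands in this article, so both models can be converted from one script:

```python
import shlex
import subprocess
from pathlib import Path


def is_static_inference_model(model_dir):
    """True if model_dir holds the static-graph pair inference.pdmodel + inference.pdiparams."""
    d = Path(model_dir)
    return (d / "inference.pdmodel").is_file() and (d / "inference.pdiparams").is_file()


def build_paddle2onnx_cmd(model_dir, save_file, opset=11):
    """Assemble a paddle2onnx call with the same flags as the commands in this article."""
    return [
        "paddle2onnx",
        "--model_dir", str(model_dir),
        "--model_filename", "inference.pdmodel",
        "--params_filename", "inference.pdiparams",
        "--save_file", str(save_file),
        "--opset_version", str(opset),
        "--enable_onnx_checker", "True",
    ]


if __name__ == "__main__":
    jobs = [
        ("./inference/ch_ppocr_mobile_v2.0_det_infer", "det.onnx"),
        ("./inference/ch_ppocr_mobile_v2.0_rec_infer", "rec_mbv3.onnx"),
    ]
    for model_dir, onnx_name in jobs:
        if not is_static_inference_model(model_dir):
            print(f"skip {model_dir}: no inference.pdmodel/inference.pdiparams found")
            continue
        cmd = build_paddle2onnx_cmd(model_dir, Path(model_dir) / onnx_name)
        print(shlex.join(cmd))
        subprocess.run(cmd, check=True)  # requires paddle2onnx on PATH
```

Directories that are missing the static-graph files are skipped rather than passed to paddle2onnx. Trained checkpoints would be one example of such a directory, since they are not exported inference models.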

Detection model conversion
paddle2onnx --model_dir ./inference/ch_ppocr_mobile_v2.0_det_infer --model_filename inference.pdmodel --params_filename inference.pdiparams --save_file ./inference/ch_ppocr_mobile_v2.0_det_infer/det.onnx --opset_version 11 --enable_onnx_checker True

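After conversion to ONNX, the detection model's default input shape is (-1, 3, 640, 640). If you need to change the model's input size and batch size, the following Python code uses the onnx package to change the converted model's input shape to (?, 3, ?, ?):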

import onnx

file_path = './inference/ch_ppocr_mobile_v2.0_det_infer/det.onnx'
model = onnx.load(file_path)
# mark batch (dim 0), height (dim 2) and width (dim 3) as dynamic; channels stay fixed at 3
model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = '?'
model.graph.input[0].type.tensor_type.shape.dim[2].dim_param = '?'
model.graph.input[0].type.tensor_type.shape.dim[3].dim_param = '?'
onnx.save(model, './inference/ch_ppocr_mobile_v2.0_det_infer/det.onnx')
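
With dynamic height and width, the model accepts any input size in principle. However, the DB detection network expects both sides to be multiples of 32, and that is what `resize_image_type0` in the appendix code enforces. Below is a minimal sketch of that size computation. It assumes the same rule as the appendix: shrink the image so the longer side is at most `limit_side_len`, then round each side to the nearest multiple of 32. `det_input_size` is a hypothetical name. Clamping each side to at least 32 is an addition not present in the appendix, which returns None in that case:

```python
def det_input_size(h, w, limit_side_len):
    """Return (resize_h, resize_w) for a DB detection input of an h x w image.

    Mirrors the appendix's resize_image_type0: shrink so the longer side is at most
    limit_side_len, then round each side to the nearest multiple of 32.
    """
    if max(h, w) > limit_side_len:
        ratio = float(limit_side_len) / max(h, w)
    else:
        ratio = 1.0
    resize_h = int(round(int(h * ratio) / 32) * 32)
    resize_w = int(round(int(w * ratio) / 32) * 32)
    # clamp very small images to one 32-pixel cell instead of returning None
    return max(resize_h, 32), max(resize_w, 32)


if __name__ == "__main__":
    print(det_input_size(1080, 1920, limit_side_len=2500))  # (1088, 1920)
    print(det_input_size(1080, 1920, limit_side_len=960))   # (544, 960)
```

The appendix code uses `limit_side_len = 2500`. A smaller limit makes detection faster on large photos, at the cost of small text.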

Recognition model conversion
paddle2onnx --model_dir ./inference/ch_ppocr_mobile_v2.0_rec_infer --model_filename inference.pdmodel --params_filename inference.pdiparams --save_file ./inference/ch_ppocr_mobile_v2.0_rec_infer/rec_mbv3.onnx --opset_version 11 --enable_onnx_checker True

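After conversion, the recognition model's default input shape is (-1, 3, 32, 100). The shape can be made dynamic in the same way as above.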
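For the recognition model, only the height is fixed at 32; the width follows the length of the text line. Below is a minimal sketch of the width computation and normalization that `resize_norm_img` performs in the appendix code. The helper names `rec_input_width` and `normalize_pixel` are hypothetical:

```python
import math


def rec_input_width(h, w, img_h=32, max_wh_ratio=None):
    """Return (resized_w, padded_w) for an h x w text crop fed to the CRNN model.

    The crop is scaled to height img_h while keeping its aspect ratio. It is then
    zero-padded on the right up to padded_w = int(img_h * max_wh_ratio). For a single
    crop, max_wh_ratio defaults to the crop's own w / h, so there is essentially no padding.
    """
    if max_wh_ratio is None:
        max_wh_ratio = w / h
    padded_w = int(img_h * max_wh_ratio)
    resized_w = min(padded_w, int(math.ceil(img_h * w / h)))
    return resized_w, padded_w


def normalize_pixel(v):
    """Map a 0..255 pixel value to [-1, 1], i.e. (v / 255 - 0.5) / 0.5."""
    return (v / 255.0 - 0.5) / 0.5


if __name__ == "__main__":
    print(rec_input_width(48, 300))                    # single crop
    print(rec_input_width(32, 64, max_wh_ratio=5.0))   # crop padded to a wider batch width
```

Passing a larger `max_wh_ratio` is how several crops of different widths could share one padded batch width.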

By default, both models use x as the input name and save_infer_model/scale_0.tmp_1 as the output name. You can check this with Netron.

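If you are unsure whether your converted ONNX models are correct, you can check them with a Python script that runs ONNX inference for detection and recognition. The script is rather long, so it is included at the end of this article.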

2.2 onnx->ncnn

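Making the ONNX input shapes dynamic is optional. ncnn itself supports dynamic input, so the conversion also works without the code above. I included it only because it is handy when you want to verify the models with ONNX inference. The conversion produces two ONNX files, det.onnx and rec_mbv3.onnx. Converting them directly to ncnn causes some problems, so first simplify them with onnxsim, which merges some ops. onnxsim is used as follows: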

pip install onnx-simplifier                  # install with pip; skip this step if it is already installed
python -m onnxsim det.onnx det_sim.onnx      # run the simplifier

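cd into the directory that contains the ONNX files and simplify the det and rec_mbv3 models separately. This gives det_sim.onnx and rec_mbv3_sim.onnx. Copy these two files into the ncnn folder you compiled earlier, at ./ncnn/build/tools/onnx.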
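Simplification and the onnx2ncnn conversion can be scripted together so that both models are always processed the same way. Below is a minimal sketch. It assumes it is run from ncnn/build/tools/onnx with the unsimplified det.onnx and rec_mbv3.onnx copied there, not the already simplified files. `conversion_commands` is a hypothetical helper. The only invocations taken from this article are `python -m onnxsim in.onnx out.onnx` and `./onnx2ncnn in.onnx out.param out.bin`:

```python
import subprocess
import sys
from pathlib import Path


def conversion_commands(onnx_files, onnx2ncnn="./onnx2ncnn"):
    """Build onnxsim + onnx2ncnn commands per model, using this article's *_sim naming."""
    cmds = []
    for f in onnx_files:
        src = Path(f)
        sim = src.with_name(src.stem + "_sim.onnx")   # det.onnx -> det_sim.onnx
        cmds.append([sys.executable, "-m", "onnxsim", str(src), str(sim)])
        cmds.append([onnx2ncnn, str(sim),
                     str(sim.with_suffix(".param")), str(sim.with_suffix(".bin"))])
    return cmds


if __name__ == "__main__":
    models = ["det.onnx", "rec_mbv3.onnx"]
    ready = all(Path(m).is_file() for m in models) and Path("./onnx2ncnn").is_file()
    for cmd in conversion_commands(models):
        if ready:
            subprocess.run(cmd, check=True)   # simplify, then convert, model by model
        else:
            print(" ".join(cmd))              # dry run: only show what would be executed
```

If the input files or the onnx2ncnn binary are missing, the script falls back to a dry run. This makes it safe to try from any directory.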

[Figure: the simplified ONNX files in ncnn/build/tools/onnx]
Then open a terminal in that directory and run:

./onnx2ncnn det_sim.onnx det_sim.param det_sim.bin

./onnx2ncnn rec_mbv3_sim.onnx rec_mbv3_sim.param rec_mbv3_sim.bin

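This produces the following four files:
[Figure: det_sim.param, det_sim.bin, rec_mbv3_sim.param and rec_mbv3_sim.bin]
Copy them out for later use. ncnn also supports fp16 and int8 quantization, which you can explore on your own. Quantizing the recognition model causes problems, so consider quantizing only the detection model. For details, see the article "NCNN量化之ncnn2table和ncnn2int8" (NCNN quantization: ncnn2table and ncnn2int8).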
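Before using the files in C++ code, check which blob names the converted .param file actually uses. The inference code has to pass those names to ncnn's `Extractor::input()` and `Extractor::extract()`. The sketch below parses the plain-text .param layout, which has three parts:

- the magic number 7767517;
- a `layer_count blob_count` line;
- one line per layer, giving the type, name, input count, output count, input blob names, output blob names, and `key=value` parameters.

It then reports the graph's input and output blobs. onnx2ncnn generally keeps the ONNX tensor names, so for these models you would expect x and save_infer_model/scale_0.tmp_1. Still, check against your own files. The helper names are hypothetical:

```python
import sys


def parse_ncnn_param(text):
    """Parse a plain-text ncnn .param file into (layer_count, blob_count, layers)."""
    rows = [line.split() for line in text.splitlines() if line.strip()]
    if not rows or rows[0] != ["7767517"]:
        raise ValueError("not a plain-text ncnn .param file (missing magic 7767517)")
    layer_count, blob_count = int(rows[1][0]), int(rows[1][1])
    layers = []
    for fields in rows[2:]:
        n_in, n_out = int(fields[2]), int(fields[3])
        layers.append({
            "type": fields[0],
            "name": fields[1],
            "inputs": fields[4:4 + n_in],
            "outputs": fields[4 + n_in:4 + n_in + n_out],
        })
    if len(layers) != layer_count:
        raise ValueError(f"header says {layer_count} layers, found {len(layers)}")
    return layer_count, blob_count, layers


def graph_io_blobs(layers):
    """Inputs = blobs produced by Input layers; outputs = blobs that no layer consumes."""
    consumed = {b for layer in layers for b in layer["inputs"]}
    inputs = [b for layer in layers if layer["type"] == "Input" for b in layer["outputs"]]
    outputs = [b for layer in layers for b in layer["outputs"] if b not in consumed]
    return inputs, outputs


if __name__ == "__main__":
    path = sys.argv[1] if len(sys.argv) > 1 else "det_sim.param"
    try:
        with open(path, encoding="utf-8") as f:
            _, _, layers = parse_ncnn_param(f.read())
        print(path, graph_io_blobs(layers))
    except FileNotFoundError:
        print(f"{path} not found")
```

Run it as `python check_param.py rec_mbv3_sim.param` (the script name is arbitrary) to print the input and output blob names for that model.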

3. Inference code

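First, the code structure. The IDE is CLion, and the file structure is as follows:
[Figure: project file structure in CLion]
The folders are used as follows:
- dict holds the dictionary file.
- include and lib are copied from ./ncnn/build/install so they are easy to reference.
- The four .cpp source files have their header files in the include folder.
- model holds the ncnn model files generated above.

Create a CMakeLists.txt file with the following contents: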

cmake_minimum_required(VERSION 3.16)
project(ocr_ncnn)

find_package(OpenCV REQUIRED)

include_directories(${OpenCV_INCLUDE_DIRS})

include_directories(include/ncnn)
include_directories(include)
link_directories(lib)

FIND_PACKAGE( OpenMP REQUIRED)
if(OPENMP_FOUND)
    message("OPENMP FOUND")
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
    set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()

set(CMAKE_CXX_STANDARD 14)
add_executable(ocr_ncnn demo.cpp src/clipper.cpp src/DbNet.cpp src/OcrStruct.cpp src/Crnn.cpp)
target_link_libraries(ocr_ncnn ncnn ${OpenCV_LIBS})

Example of calling the interfaces:

//
// Created by cai on 2021/10/13.
//

#include <opencv2/opencv.hpp>
#include <iostream>
#include <opencv2/highgui/highgui.hpp>
#include "net.h"
#include <OcrStruct.h>
#include <DbNet.h>
#include <Crnn.h>

using namespace cv;
using namespace std;

int main(){
    DbNet dbNet;
    CRNN Crnn;

    bool retDbNet = dbNet.initModel("../model/det_int8");
    bool retCrnn = Crnn.initModel("../model/rec_mbv3");

    if (!retDbNet || !retCrnn){
        printf("Failed to load model: %s\n", !retDbNet ? "DBNet" : "CRNN");
        return -1;
    }

    const char*imagepath = "../test_img";
    vector<String> imagesPath;
    cv::glob(imagepath,imagesPath);

    for (size_t i = 0; i < imagesPath.size(); i++) {
        // load the image
        cout << imagesPath[i] << endl;
        Mat image = imread(imagesPath[i]);

        if (image.empty()) {
            cout << "Error: Could not load image" << endl;
            continue;  // skip unreadable files instead of aborting the whole run
        }
        // record the start time
        double time0 = static_cast<double>(getTickCount());

        vector<cv::Mat> crop;
        crop = dbNet.getTextImages(image);

        vector<std::string> result;
        result = Crnn.getRecText(crop);

        // compute and print the elapsed time (detection + recognition)
        time0 = ((double) getTickCount() - time0) / getTickFrequency();   // (end - start), converted to seconds
        cout << "\tOCR time: " << time0 << " s" << endl;

        for (auto &txt : result) {                       // print the recognized text
            cout << txt << "\n" << endl;
        }
    }
    return 0;
}

Sample run:
[Figure: console output of the demo]

Appendix: Python code for ONNX inference
import os
import sys
import cv2
import time
import onnx
import math
import copy
import onnxruntime
import numpy as np
import pyclipper
from shapely.geometry import Polygon

# Image preprocessing classes used by the PaddleOCR detection module
class NormalizeImage(object):
    """ normalize image such as substract mean, divide std
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, data):
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)

        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data


class ToCHWImage(object):
    """ convert hwc image to chw image
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        data['image'] = img.transpose((2, 0, 1))
        return data


class KeepKeys(object):
    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        data_list = []
        for key in self.keep_keys:
            data_list.append(data[key])
        return data_list

class DetResizeForTest(object):
    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.limit_side_len = kwargs['limit_side_len']
        self.limit_type = kwargs.get('limit_type', 'min')

    def __call__(self, data):
        img = data['image']

        src_h, src_w, _ = img.shape
        img, [ratio_h, ratio_w] = self.resize_image_type0(img)

        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_type0(self, img):
        """
        resize image to a size multiple of 32 which is required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, _ = img.shape

        # limit the max side
        if max(h, w) > limit_side_len:
            if h > w:
                ratio = float(limit_side_len) / h
            else:
                ratio = float(limit_side_len) / w
        else:
            ratio = 1.
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)


        resize_h = int(round(resize_h / 32) * 32)
        resize_w = int(round(resize_w / 32) * 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except:
            print(img.shape, resize_w, resize_h)
            sys.exit(0)
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        # return img, np.array([h, w])
        return img, [ratio_h, ratio_w]

### Post-processing of detection results (produces the text boxes)
class DBPostProcess(object):
    """
    The post process for Differentiable Binarization (DB).
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 **kwargs):
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        _bitmap: single map with shape (1, H, W),
                whose values are binarized as {0, 1}
        '''

        bitmap = _bitmap
        height, width = bitmap.shape

        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)
        if len(outs) == 3:
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:
            contours, _ = outs[0], outs[1]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            box = self.unclip(points).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype(np.int16))
            scores.append(score)
        return np.array(boxes, dtype=np.int16), scores

    def unclip(self, box):
        unclip_ratio = self.unclip_ratio
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    def get_mini_boxes(self, contour):
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [
            points[index_1], points[index_2], points[index_3], points[index_4]
        ]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        h, w = bitmap.shape[:2]
        box = _box.copy()
        # np.int was removed in NumPy 1.24; use an explicit integer dtype
        xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def __call__(self, outs_dict, shape_list):
        pred = outs_dict
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]  # original size and resize ratios
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                   src_w, src_h)
            boxes_batch.append({'points': boxes})
        return boxes_batch

## Decode the recognition results from the model output
class process_pred(object):
    def __init__(self, character_dict_path=None, character_type='ch', use_space_char=False):
        self.character_str = ''
        with open(character_dict_path, 'rb') as fin:
            lines = fin.readlines()
            for line in lines:
                line = line.decode('utf-8').strip('\n').strip('\r\n')
                self.character_str += line
        if use_space_char:
            self.character_str += ' '
        dict_character = list(self.character_str)

        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character

    def add_special_char(self, dict_character):
        dict_character = ['blank'] + dict_character
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        result_list = []
        ignored_tokens = [0]
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def __call__(self, preds, label=None):
        if not isinstance(preds, np.ndarray):
            preds = np.array(preds)
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label


class det_rec_functions(object):

    def __init__(self, image, use_large=False):
        self.img = image.copy()
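        # ONNX model paths used by the author; point them at your own det.onnx / rec_mbv3.onnx
        # (the "small" and "large" recognition paths refer to the same file here)
        self.det_file = './weights/det_d.onnx'
        self.small_rec_file = './weights/rec_d.onnx'
        self.large_rec_file = './weights/rec_d.onnx'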
        self.onet_det_session = onnxruntime.InferenceSession(self.det_file)
        if use_large:
            self.onet_rec_session = onnxruntime.InferenceSession(self.large_rec_file)
        else:
            self.onet_rec_session = onnxruntime.InferenceSession(self.small_rec_file)
        self.infer_before_process_op, self.det_re_process_op = self.get_process()
        # character dictionary ppocr_keys_v1.txt (shipped with PaddleOCR); adjust the path to your copy
        self.postprocess_op = process_pred('./torchocr/datasets/alphabets/ppocr_keys_v1.txt', 'ch', True)

    ## Image preprocessing pipeline
    def transform(self, data, ops=None):
        """ transform """
        if ops is None:
            ops = []
        for op in ops:
            data = op(data)
            if data is None:
                return None
        return data

    def create_operators(self, op_param_list, global_config=None):
        """
        create operators based on the config

        Args:
            params(list): a dict list, used to create some operators
        """
        assert isinstance(op_param_list, list), ('operator config should be a list')
        ops = []
        for operator in op_param_list:
            assert isinstance(operator,
                              dict) and len(operator) == 1, "yaml format error"
            op_name = list(operator)[0]
            param = {} if operator[op_name] is None else operator[op_name]
            if global_config is not None:
                param.update(global_config)
            op = eval(op_name)(**param)
            ops.append(op)
        return ops

    ### Post-processing of the detection boxes
    def order_points_clockwise(self, pts):
        """
        reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
        # sort the points based on their x-coordinates
        """
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # grab the left-most and right-most points from the sorted
        # x-coordinate points
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # now, sort the left-most coordinates according to their
        # y-coordinates so we can grab the top-left and bottom-left
        # points, respectively
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
        (tr, br) = rightMost

        rect = np.array([tl, tr, br, bl], dtype="float32")
        return rect

    def clip_det_res(self, points, img_height, img_width):
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    ### Define the image preprocessing and the detection post-processing
    def get_process(self):
        det_db_thresh = 0.3
        det_db_box_thresh = 0.5
        max_candidates = 2000
        unclip_ratio = 1.6
        use_dilation = True

        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': 2500,
                'limit_type': 'max'
            }
        }, {
            'NormalizeImage': {
                'std': [0.5, 0.5, 0.5],
                'mean': [0.5, 0.5, 0.5],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]

        infer_before_process_op = self.create_operators(pre_process_list)
        det_re_process_op = DBPostProcess(det_db_thresh, det_db_box_thresh, max_candidates, unclip_ratio, use_dilation)
        return infer_before_process_op, det_re_process_op

    def sorted_boxes(self, dt_boxes):
        """
        Sort text boxes in order from top to bottom, left to right
        args:
            dt_boxes(array):detected text boxes with shape [4, 2]
        return:
            sorted boxes(array) with shape [4, 2]
        """
        num_boxes = dt_boxes.shape[0]
        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
        _boxes = list(sorted_boxes)

        for i in range(num_boxes - 1):
            if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
                    (_boxes[i + 1][0][0] < _boxes[i][0][0]):
                tmp = _boxes[i]
                _boxes[i] = _boxes[i + 1]
                _boxes[i + 1] = tmp
        return _boxes

    ### Input preprocessing for the recognition model
    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = [int(v) for v in "3, 32, 100".split(",")]
        assert imgC == img.shape[2]
        imgW = int((32 * max_wh_ratio))
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    ## Run text detection on the image
    def get_boxes(self):
        img_ori = self.img
        img_part = img_ori.copy()
        data_part = {'image': img_part}
        data_part = self.transform(data_part, self.infer_before_process_op)
        img_part, shape_part_list = data_part
        img_part = np.expand_dims(img_part, axis=0)
        shape_part_list = np.expand_dims(shape_part_list, axis=0)
        inputs_part = {self.onet_det_session.get_inputs()[0].name: img_part}
        outs_part = self.onet_det_session.run(None, inputs_part)

        post_res_part = self.det_re_process_op(outs_part[0], shape_part_list)
        dt_boxes_part = post_res_part[0]['points']
        dt_boxes_part = self.filter_tag_det_res(dt_boxes_part, img_ori.shape)
        dt_boxes_part = self.sorted_boxes(dt_boxes_part)
        return dt_boxes_part

    ### Crop the text region given by a bounding box
    def get_rotate_crop_image(self, img, points):
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    ### Run recognition on a single cropped image
    def get_img_res(self, onnx_model, img, process_op):
        h, w = img.shape[:2]
        img = self.resize_norm_img(img, w * 1.0 / h)
        img = img[np.newaxis, :]
        inputs = {onnx_model.get_inputs()[0].name: img}
        outs = onnx_model.run(None, inputs)
        result = process_op(outs[0])
        return result

    def recognition_img(self, dt_boxes):
        img_ori = self.img
        img = img_ori.copy()
        ### Recognition
        ## Crop small images according to the bounding boxes
        img_list = []
        for box in dt_boxes:
            tmp_box = copy.deepcopy(box)
            img_crop = self.get_rotate_crop_image(img, tmp_box)
            img_list.append(img_crop)

        ## Recognize the cropped images
        results = []
        results_info = []
        for pic in img_list:
            res = self.get_img_res(self.onet_rec_session, pic, self.postprocess_op)
            results.append(res[0])
            results_info.append(res)
        return results, results_info

if __name__=='__main__':
    import os

    img_path = "../test_img"

    for name in os.listdir(img_path):
        time1 = time.time()
        image = cv2.imread(os.path.join(img_path,name))
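        # read an image
        # image = cv2.imread('./7.png')
        # OCR: detection + recognition (the ONNX sessions are re-created for every image,
        # so the measured time also includes model loading)
        ocr_sys = det_rec_functions(image)
        # get the detected text boxes
        dt_boxes = ocr_sys.get_boxes()

        # recognition; results: recognized text only, results_info: text + confidence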
        results, results_info = ocr_sys.recognition_img(dt_boxes)
        time2 = time.time()
        print(time2-time1)
        print(results)
        print()
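
The recognition decoding in `process_pred` above is greedy CTC decoding. It takes the highest-scoring class at each time step, drops the blank class (index 0, added by `add_special_char`), and collapses consecutive repeats. The same logic presumably also has to exist on the C++ side in Crnn.cpp, whose source is not shown here. Below is a dependency-free sketch of that step. `ctc_greedy_decode` is a hypothetical name. Unlike the appendix code, which would get NaN from `np.mean` of an empty list, it returns a confidence of 0.0 when nothing is recognized:

```python
def ctc_greedy_decode(probs, charset):
    """Greedy CTC decoding.

    probs:   one list of class scores per time step. Class 0 is the CTC blank, and
             class i > 0 maps to charset[i - 1], matching ['blank'] + dict_character.
    Returns (text, mean confidence of the emitted characters).
    """
    chars, confs = [], []
    prev = None
    for step in probs:
        idx = max(range(len(step)), key=step.__getitem__)   # argmax, first index on ties
        # emit only non-blank classes that differ from the previous time step
        if idx != 0 and idx != prev:
            chars.append(charset[idx - 1])
            confs.append(step[idx])
        prev = idx
    conf = sum(confs) / len(confs) if confs else 0.0
    return "".join(chars), conf


if __name__ == "__main__":
    # made-up scores for a 2-character alphabet "ab" over 5 time steps
    scores = [[0.1, 0.8, 0.1], [0.1, 0.7, 0.2], [0.9, 0.05, 0.05],
              [0.2, 0.1, 0.7], [0.3, 0.6, 0.1]]
    print(ctc_greedy_decode(scores, "ab"))   # text "aba"
```

Repeats are collapsed only when they are adjacent. A blank between two identical classes therefore yields the character twice, which is how CTC represents doubled letters.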

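Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog. Copyright belongs to the original author, and this site assumes no legal liability. If you find infringing content, please contact us. Please credit the source when reposting: https://www.wpsshop.cn/w/AllinToyou/article/detail/679064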