当前位置:   article > 正文

【2023 · CANN训练营第一季】模型的加载与推理_推理场景下,模型加载支持哪些方式

推理场景下,模型加载支持哪些方式

【2023 · CANN训练营第一季】模型的加载与推理

文档参考:CANN文档社区版: 6.0.RC1.alpha001

在模型转换,数据预处理后,我的学习终于终于来到了模型的加载与推理。是时候用我预处理后的图片加上我ATC转换的模型进行一个推理和目标检测的步骤了!

一、接口调用流程

首次进行代码学习的时候,其实我们可以参考官方的文档。

模型加载:

以下为使用不同模型的加载接口的示意图:

如果我们要使用同一个模型加载接口,还可以创建一个config的形式来固定加载接口:

note:

1.当由用户管理内存时,为确保内存不浪费,在申请工作内存、权值内存前,需要调用aclmdlQuerySize接口查询模型运行时所需工作内存、权值内存的大小。

2.如果模型输入数据的Shape不确定,则不能调用aclmdlQuerySize接口查询内存大小,在加载模型时,就无法由用户管理内存,因此需选择由系统管理内存的模型加载接口(例如,aclmdlLoadFromFile、aclmdlLoadFromMem)。

模型执行:

模型执行的参考流程如下:

输入/输出数据类型结构准备:

这里还用到了一个数据类型集合的使用:

使用aclmdlDataset类型的数据描述模型的输入/输出数据,因为模型可能存在多个输入、多个输出

调用aclmdlDataset类型下的操作接口添加aclDataBuffer类型的数据、获取aclDataBuffer的个数等

每个输入/输出的内存地址、内存大小用aclDataBuffer类型的数据来描述。调用aclDataBuffer类型下的操作接口获取内存地址、内存大小等。

上述数据类型结构准备流程:

模型卸载:

模型卸载主要就是要注意不要有漏网之鱼,毕竟我们创建的东西太多了,没有释放会导致内存泄漏等一系列问题,可以参考以下内容检查是否有释放完全:

1.卸载模型

2.释放模型描述信息

3.释放模型运行的工作内存

4.释放模型运行的权值内存

二、接口简介

以下简单介绍一下我用到了的几个API接口:

note:模型加载、模型执行、模型卸载的操作必须在同一个Context下。这就意味着,我们必须在同一个线程下进行使用,因为context无法跨线程使用

/*
函数功能:
从文件加载离线模型数据(适配昇腾AI处理器的离线模型),由用户自行管理模型运行的内存,同步接口。
系统完成模型加载后,返回的模型ID,作为后续操作时用于识别模型的标志。

参数:
modelPath:
离线模型文件路径的指针,路径中包含文件名。运行程序(APP)的用户需要对该存储路径有访问权限。

modelId:
输入图片信息的指针。
模型ID的指针。系统成功加载模型后会返回的模型ID。

workPt:
Device上模型所需工作内存(存放模型执行过程中的临时数据)的地址指针,由用户自行管理,模型执行过程中不能释放该内存。
如果在workPtr参数处传入空指针,表示由系统管理内存。

workSize:
模型所需工作内存的大小,单位Byte。workPtr为空指针时无效。

weightPtr:
Device上模型权值内存(存放权值数据)的地址指针,由用户自行管理,模型执行过程中不能释放该内存。
如果在weightPtr参数处传入空指针,表示由系统管理内存。

weightSize:
模型所需权值内存的大小,单位Byte。weightPtr为空指针时无效。

*/
aclError aclmdlLoadFromFileWithMem(const char *modelPath, uint32_t *modelId, void *workPtr, 
						size_t workSize, void *weightPtr, size_t weightSize)

/*
函数功能:
根据模型文件获取模型执行时所需的权值内存大小、工作内存大小,同步接口。

参数:
fileName:
离线模型文件路径的指针,路径中包含文件名。运行程序(APP)的用户需要对该存储路径有访问权限。

workSize:
模型执行时所需的工作内存大小的指针,单位Byte。

weightSize:
模型执行时所需权值内存大小的指针,单位Byte。
*/
aclError aclmdlQuerySize(const char *fileName, size_t *workSize, size_t *weightSize)

/*
函数功能:
执行模型推理,直到返回推理结果,同步接口。

参数:
modelId:
指定需要执行推理的模型的ID。

input:
模型推理的输入数据的指针。

output:
模型推理的输出数据的指针。

note:
若用户使用aclrtMalloc或aclrtMallocHost接口申请大块内存并自行划分、管理内存时,用户在管理内存时,模型输入数据的内存有对齐和补齐要求,首地址128字节对齐,对齐后再加32字节,然后再向上对齐到128字节。
*/
aclError aclmdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65

创建输入输出类型结构的相关接口可以参考手册了解,太多了这里就不写了。

三、代码验证与学习

代码思路:

1.初始化相关通道,创建输入输出所需的一些内容,加载模型

2.读取图片,将图片整合进输入流,执行模型

3.读出结果,处理模型输出的数据

4.依次销毁所有创建的环境

接下来我们直接进入我们的代码运行过程!!!

进行代码编译,编译成功!

首先,我们先进行代码验证,在之前的文章中,我自己用ATC工具,转了一个om模型,首次验证,我们先不使用它,因为我无法保证我自己的模型是否正常,所以我的思路是,先找了一个确定能行的我用过的om模型,进行代码验证。如果我要是用我自己的模型,这样排查是我的模型加载相关代码异常还是我的模型本身就有异常会造成一定的弯路。

代码验证如下:

结果是识别到了两个小车车,一个小停止路牌标志,让我们回顾一下,我的resize文章中输出的640x640的原图。

能识别出两个车,效果还算不错了,虽然挡住的那个不太好识别。各位有兴趣的小伙伴也可以用cv画出bounding box看看识别情况

接下来!我们就使用我之前ATC的文章中转好的yolov5s_bs.om进行验证吧!!!!!!

结果如下:

哦豁!有错误,让我们看看这个错误码是什么意思!

查看昇腾的C++应用文档对比错误码,发现是参数校验失败,根据我们上面的验证思路,问题可能出在了模型上。

将两个模型用netron打开进行对比:
这是正常模型的输入这里,可以看到还是进行了AIPP的,但是不知道为啥没有显示这个算子

这是我自己转的om的情况:

接下来我们再查看输出,这是运行正常的模型:

以下这是我的异常模型:

同志们!!!!我们发现了问题,这怎么是三个特征图在输出呢!

这里我们需要用到YoloPreDetection和YoloV5DetectionOutput两个算子,对输出的onnx模型进行调整。

以下参考昇腾社区modelZoo的his版本中的yolov5案例

首先,

(1)我们修改models/common.py文件,对其中Focus的forward函数做修改

class Focus(nn.Module):
    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        # <==== 修改内容
        if torch.onnx.is_in_onnx_export():
            a, b = x[..., ::2, :].transpose(-2, -1), x[..., 1::2, :].transpose(-2, -1)
            c = torch.cat([a[..., ::2, :], b[..., ::2, :], a[..., 1::2, :], b[..., 1::2, :]], 1).transpose(-2, -1)
            return self.conv(c)
        else:
            return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
        # return self.conv(self.contract(x))
        # =====>
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

(2)修改models/yolo.py脚本,使后处理部分不被导出

class Detect(nn.Module):
    def forward(self, x):
        # ...
        # <==== 修改内容
        self.training = True  # v6.0版本需补充该行
        # =====>
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            # <==== 修改内容
            if torch.onnx.is_in_onnx_export():
                continue
            # =====>
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
        # ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

(3)修改models/experimental.py文件,将其中的attempt_download()所在行注释掉

def attempt_load(weights, map_location=None):
    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        # <==== 修改内容
        # attempt_download(w)  # v6.0中attempt_download(w)被下一行代码引用,修改如该处所示
        # =====>
        ckpt = torch.load(w, map_location=map_location)  # load
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

(4)修改models/export.py文件,将转换的onnx算子版本设为11,v6.0无需修改

torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'], do_constant_folding=True,
                  output_names=['classes', 'boxes'] if y is None else ['output'])
                  dynamic_axes={'images': {0: 'batch'},
                                'output': {0: 'batch'}} if opt.dynamic else None)
  • 1
  • 2
  • 3
  • 4

修改完成,我们把模型放在yolov5文件夹下,使用官方的export.py运行如下命令

export PYTHONPATH=`pwd`:$PYTHONPATH
python models/export.py --weights=./yolov5s.pt --img-size=640 --batch-size 1  # 用于v2.0->v5.0
python export.py --weights=./yolov5s.pt --imgsz=640 --batch-size=1 --opset=11  # 用于v6.0
  • 1
  • 2
  • 3

这里我使用的是v6.0版本的git版本,在windows下运行也不用export那个pwd获取的路径

添加后处理算子

python modify_model.py --model=yolov5s.onnx --conf-thres=0.4 --iou-thres=0.5  # 非量化模型
  • 1

在这里生成了一个yolov5s_t.onnx

使用atc命令进行转换:

atc --model=yolov5s_t.onnx --framework=5 --output=yolov5s_t_bs1 --input_format=NCHW --input_shape="images:1,3,640,640;img_info:1,4" --log=error --soc_version=Ascend310 --insert_op_conf=insert_op.cfg
  • 1

转换结果如下:

生成了yolov5s_t_bs1.om

将该模型使用我们代码进行加载和执行:

看到推理成功并且输出了很多很多很多的结果!说明我们终于终于终于成功了。

但是保留一个我自己的小疑问是不知道为啥推理的结果有这么多,而且结果不太对,等后面有空我在讨论一下这个模型本身的问题。本文我们要解决的是如何让我们的模型顺利的加载并推理。我们完成了我们的目标。

以下附上我的测试的代码。写的有些粗糙,因为最近时间比较紧张。

runInfer.cpp

#include "inference.h"

Result paramsInit(int32_t& deviceId,aclrtRunMode& runMode_,aclrtContext& context,aclrtStream& stream)
{
    const char *aclConfigPath = "./acl.json";
  
    aclError ret = aclInit(aclConfigPath);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("acl init failed, errorCode = %d", static_cast<int32_t>(ret));
        return FAILED;
    }
    INFO_LOG("acl init success");
  
    ret = aclrtSetDevice(deviceId);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("acl set device %d failed, errorCode = %d", deviceId, static_cast<int32_t>(ret));
        return FAILED;
    }
    INFO_LOG("set device %d success", deviceId);

    // create context (set current)
    ret = aclrtCreateContext(&context, deviceId);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("acl create context failed, deviceId = %d, errorCode = %d",
                  deviceId, static_cast<int32_t>(ret));
        return FAILED;
    }
    INFO_LOG("create context success");

    // create stream
    ret = aclrtCreateStream(&stream);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("acl create stream failed, deviceId = %d, errorCode = %d",
                  deviceId, static_cast<int32_t>(ret));
        return FAILED;
    }
    INFO_LOG("create stream success");

     // get run mode
    ret = aclrtGetRunMode(&runMode_);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("acl get run mode failed, errorCode = %d", static_cast<int32_t>(ret));
        return FAILED;
    }
    INFO_LOG("get run mode success");

    return SUCCESS;
}

Result modelsInit(string modelsPath,uint32_t modelW,uint32_t modelH,aclrtRunMode& runMode_)
{
    Result ret;

    ret = LoadModelFromFileWithMem(modelsPath);
    if (ret != SUCCESS) {
        ERROR_LOG("execute LoadModelFromFileWithMem failed");
        return FAILED;
    }

    ret = CreateDesc();
    if (ret != SUCCESS) {
        ERROR_LOG("execute CreateDesc failed");
        return FAILED;
    }

    ret = CreateOutput();
    if (ret != SUCCESS) {
        ERROR_LOG("execute CreateOutput failed");
        return FAILED;
    }

    ret = CreateImageInfoBuffer(modelW,modelH,runMode_);
    if (ret != SUCCESS) {
        ERROR_LOG("Create image info buf failed");
        return FAILED;
    }
    return SUCCESS;
}

int32_t destroyFunc(int32_t& deviceId,aclrtRunMode& runMode_,aclrtContext& context,aclrtStream& stream)
{
    aclError ret;

    ret =  DestroyResource();
    if (ret != ACL_SUCCESS) {
            ERROR_LOG("destroy models resource failed");
    }

    if (stream != nullptr) {
        ret = aclrtDestroyStream(stream);
        if (ret != ACL_SUCCESS) {
            ERROR_LOG("destroy stream failed");
        }
        stream = nullptr;
    }
    INFO_LOG("end to destroy stream");

    if (context != nullptr) {
        ret = aclrtDestroyContext(context);
        if (ret != ACL_SUCCESS) {
            ERROR_LOG("destroy context failed");
        }
        context = nullptr;
    }
    INFO_LOG("end to destroy context");

    ret = aclrtResetDevice(deviceId);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("reset device failed");
    }
    INFO_LOG("end to reset device is %d", deviceId);
  
    ret = aclFinalize();
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("finalize acl failed");
    }
    INFO_LOG("end to finalize acl");
    return 0;
}


int main()
{
    int32_t  deviceId=0;
    aclrtRunMode runMode_;
    aclrtContext context;
    aclrtStream stream;
    PicDesc picDesc;
    uint32_t modelWidth_ = 640;
    uint32_t modelHeight_ =640;
    aclmdlDataset* inferenceOutput;
    string modelPath_ = "/home/HwHiAiUser/cann_learn/infer/models/yolov5s_t_bs1.om";
    string img_path = "/home/HwHiAiUser/cann_learn/infer/img/";

    struct  timeval rStart,rEnd;
    double rTimeuse;

    Result res = paramsInit(deviceId,runMode_,context,stream);/*acl相关初始化*/
    if (res != SUCCESS)
    {
        ERROR_LOG("ACL Init Failed! proccess end...");
        return -1;
    }
    res = modelsInit(modelPath_,modelWidth_,modelHeight_,runMode_);/*模型加载和相关初始化*/
    if (res != SUCCESS)
    {
        ERROR_LOG("Models Init Failed! proccess end...");
        return -2;
    }

    picDesc.picName = img_path+"1.yuv";
    picDesc.width = 640;
    picDesc.height = 640;
    picDesc.dataSize = picDesc.width * picDesc.height * 3 /2;
    picDesc.data = getDeviceBufferOfPicture(picDesc,picDesc.dataSize,runMode_);

#if 1/*时间记录*/
    gettimeofday(&rStart,NULL);
#endif
    res = inferExecute(inferenceOutput,picDesc);

#if 1/*时间记录*/
    gettimeofday(&rEnd,NULL);
    rTimeuse = 1000000*(rEnd.tv_sec - rStart.tv_sec) + \
                (rEnd.tv_usec - rStart.tv_usec);
#endif
    INFO_LOG("models execute time: %f ms",(rTimeuse/1000));

    Postprocess(picDesc,inferenceOutput,runMode_,modelWidth_,modelHeight_);

    destroyFunc(deviceId,runMode_,context,stream);

    return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174

inference.cpp(推理功能实现文件)

#include "inference.h"

using namespace std;

bool g_loadFlag_;  // model load flag
uint32_t g_modelId_;
void *g_modelMemPtr_;
size_t g_modelMemSize_;
void *g_modelWeightPtr_;
size_t g_modelWeightSize_;
aclmdlDesc *g_modelDesc_;
aclmdlDataset *g_input_;
aclmdlDataset *g_output_;

void* imageInfoBuf_ = nullptr;
void* imageDataBuf_ = nullptr ;
uint32_t imageInfoSize_ = 0;

namespace {
    const static std::vector<std::string> yolov3Label = { "person", "bicycle", "car", "motorbike",
        "aeroplane", "bus", "train", "truck", "boat",
        "traffic light", "fire hydrant", "stop sign", "parking meter",
        "bench", "bird", "cat", "dog", "horse",
        "sheep", "cow", "elephant", "bear", "zebra",
        "giraffe", "backpack", "umbrella", "handbag", "tie",
        "suitcase", "frisbee", "skis", "snowboard", "sports ball",
        "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
        "tennis racket", "bottle", "wine glass", "cup",
        "fork", "knife", "spoon", "bowl", "banana",
        "apple", "sandwich", "orange", "broccoli", "carrot",
        "hot dog", "pizza", "donut", "cake", "chair",
        "sofa", "potted plant", "bed", "dining table", "toilet",
        "TV monitor", "laptop", "mouse", "remote", "keyboard",
        "cell phone", "microwave", "oven", "toaster", "sink",
        "refrigerator", "book", "clock", "vase", "scissors",
        "teddy bear", "hair drier", "toothbrush" };

    const uint32_t g_bBoxDataBufId = 0;
    const uint32_t g_boxNumDataBufId = 1;

    enum BBoxIndex { TOPLEFTX = 0, TOPLEFTY, BOTTOMRIGHTX, BOTTOMRIGHTY, SCORE, LABEL };
}

void initParams()
{
    g_loadFlag_=false;
    g_modelId_=0;
    g_modelMemPtr_=nullptr; 
    g_modelMemSize_=0;
    g_modelWeightPtr_=nullptr; 
    g_modelWeightSize_=0;
    g_modelDesc_=nullptr;
    g_input_=nullptr;
    g_output_=nullptr; 
}

void* getDeviceBufferOfPicture(PicDesc &picDesc, uint32_t &devPicBufferSize,aclrtRunMode runMode){
    if (picDesc.picName.empty()) {
        ERROR_LOG("picture file name is empty");
        return nullptr;
        }

    FILE *fp = fopen(picDesc.picName.c_str(), "rb");
    if (fp == nullptr) {
        ERROR_LOG("open file %s failed", picDesc.picName.c_str());
        return nullptr;
        }

    fseek(fp, 0, SEEK_END);
    uint32_t fileLen = ftell(fp);
    fseek(fp, 0, SEEK_SET);

    uint32_t inputBuffSize = fileLen;

    char* inputBuff = new(std::nothrow) char[inputBuffSize];
    size_t readSize = fread(inputBuff, sizeof(char), inputBuffSize, fp);
    if (readSize < inputBuffSize) {
        ERROR_LOG("need read file %s %u bytes, but only %zu readed",
        picDesc.picName.c_str(), inputBuffSize, readSize);
        delete[] inputBuff;
		fclose(fp);
        return nullptr;
        }

    INFO_LOG("get yuv image info successed, width=%d, height=%d, picSize=%d", picDesc.width, picDesc.height, picDesc.dataSize);
  
    void *inBufferDev = nullptr;
    aclError ret = acldvppMalloc(&inBufferDev, inputBuffSize);
    if (ret !=  ACL_SUCCESS) {
        delete[] inputBuff;
        ERROR_LOG("malloc device data buffer failed, aclRet is %d", ret);
		fclose(fp);
        return nullptr;
    }

    if (runMode == ACL_HOST) {
        ret = aclrtMemcpy(inBufferDev, inputBuffSize, inputBuff, inputBuffSize, ACL_MEMCPY_HOST_TO_DEVICE);
    }
    else {
        ret = aclrtMemcpy(inBufferDev, inputBuffSize, inputBuff, inputBuffSize, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("memcpy failed. Input host buffer size is %u",
        inputBuffSize);
        acldvppFree(inBufferDev);

        delete[] inputBuff;
		fclose(fp);
        return nullptr;
    }
  
    delete[] inputBuff;
    devPicBufferSize = inputBuffSize;
	fclose(fp);
    return inBufferDev; 
}

uint32_t SaveOutputImg(const void *devPtr, uint32_t dataSize,aclrtRunMode runMode)
{
    DIR *dp;
    // string output_path = "/home/HwHiAiUser/cann_learn/resize/outYuv";
    string output_path = "/home/HwHiAiUser/resize/outYuv";
    string saveName = output_path+"/outResizeImg.yuv";

    dp = opendir(output_path.c_str());
	if (dp == NULL)
	{
		INFO_LOG("no this dir: %s ,creat it",output_path.c_str());
        mkdir(output_path.c_str(),0777);
	}
    else
    {
        closedir(dp);
    }

    FILE* outFileFp = fopen(saveName.c_str(), "wb+");
    if (runMode == ACL_HOST) {
        void* hostPtr = nullptr;
        aclrtMallocHost(&hostPtr, dataSize);
        aclrtMemcpy(hostPtr, dataSize, devPtr, dataSize, ACL_MEMCPY_DEVICE_TO_HOST);
        fwrite(hostPtr, sizeof(char), dataSize, outFileFp);
        (void)aclrtFreeHost(hostPtr);
    } else {
        fwrite(devPtr, sizeof(char), dataSize, outFileFp);
    }
    fflush(outFileFp);
    fclose(outFileFp);
    INFO_LOG("save img success!img in Path:%s",saveName.c_str());
    return SUCCESS;
}

Result LoadModelFromFileWithMem(const string modelPath)
{
    if (g_loadFlag_) {
        ERROR_LOG("has already loaded a model");
        return FAILED;
    }

    aclError ret = aclmdlQuerySize(modelPath.c_str(), &g_modelMemSize_, &g_modelWeightSize_);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("query model failed, model file is %s", modelPath.c_str());
        return FAILED;
    }

    ret = aclrtMalloc(&g_modelMemPtr_, g_modelMemSize_, ACL_MEM_MALLOC_HUGE_FIRST);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("malloc buffer for mem failed, require size is %zu", g_modelMemSize_);
        return FAILED;
    }

    ret = aclrtMalloc(&g_modelWeightPtr_, g_modelWeightSize_, ACL_MEM_MALLOC_HUGE_FIRST);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("malloc buffer for weight failed, require size is %zu", g_modelWeightSize_);
        return FAILED;
    }

    ret = aclmdlLoadFromFileWithMem(modelPath.c_str(), &g_modelId_, g_modelMemPtr_,
    g_modelMemSize_, g_modelWeightPtr_, g_modelWeightSize_);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("load model from file failed, model file is %s", modelPath.c_str());
        return FAILED;
    }

    g_loadFlag_ = true;
    INFO_LOG("load model %s success,model id is %d", modelPath.c_str(),g_modelId_);
    return SUCCESS;
}

Result CreateDesc()
{
    g_modelDesc_ = aclmdlCreateDesc();
    if (g_modelDesc_ == nullptr) {
        ERROR_LOG("create model description failed");
        return FAILED;
    }

    aclError ret = aclmdlGetDesc(g_modelDesc_, g_modelId_);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("get model description failed");
        return FAILED;
    }

    INFO_LOG("create model description success");
    return SUCCESS;
}

Result CreateOutput()
{
    if (g_modelDesc_ == nullptr) {
        ERROR_LOG("no model description, create ouput failed");
        return FAILED;
    }

    g_output_ = aclmdlCreateDataset();
    if (g_output_ == nullptr) {
        ERROR_LOG("can't create dataset, create output failed");
        return FAILED;
    }

    size_t outputSize = aclmdlGetNumOutputs(g_modelDesc_);
    for (size_t i = 0; i < outputSize; ++i) {
        size_t buffer_size = aclmdlGetOutputSizeByIndex(g_modelDesc_, i);

        void *outputBuffer = nullptr;
        aclError ret = aclrtMalloc(&outputBuffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
        if (ret != ACL_SUCCESS) {
            ERROR_LOG("can't malloc buffer, size is %zu, create output failed", buffer_size);
            return FAILED;
        }

        aclDataBuffer* outputData = aclCreateDataBuffer(outputBuffer, buffer_size);
        if (ret != ACL_SUCCESS) {
            ERROR_LOG("can't create data buffer, create output failed");
            aclrtFree(outputBuffer);
            return FAILED;
        }

        ret = aclmdlAddDatasetBuffer(g_output_, outputData);
        if (ret != ACL_SUCCESS) {
            ERROR_LOG("can't add data buffer, create output failed");
            aclrtFree(outputBuffer);
            aclDestroyDataBuffer(outputData);
            return FAILED;
        }
    }

    INFO_LOG("create model output success");
    return SUCCESS;
}

void* CopyDataToDevice(void* data, uint32_t dataSize, aclrtMemcpyKind policy)
{
    void* buffer = nullptr;
    aclError aclRet = aclrtMalloc(&buffer, dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
    if (aclRet != ACL_SUCCESS) {
        ERROR_LOG("malloc device data buffer failed, aclRet is %d", aclRet);
        return nullptr;
    }

    aclRet = aclrtMemcpy(buffer, dataSize, data, dataSize, policy);
    if (aclRet != ACL_SUCCESS) {
        ERROR_LOG("Copy data to device failed, aclRet is %d", aclRet);
        (void)aclrtFree(buffer);
        return nullptr;
    }

    return buffer;
}

void* CopyDataDeviceToLocal(void* deviceData, uint32_t dataSize)
{
    uint8_t* buffer = new uint8_t[dataSize];
    if (buffer == nullptr) {
        ERROR_LOG("New malloc memory failed");
        return nullptr;
    }

    aclError aclRet = aclrtMemcpy(buffer, dataSize, deviceData, dataSize, ACL_MEMCPY_DEVICE_TO_HOST);
    if (aclRet != ACL_SUCCESS) {
        ERROR_LOG("Copy device data to local failed, aclRet is %d", aclRet);
        delete[](buffer);
        return nullptr;
    }

    return (void*)buffer;
}

void* CopyDataDeviceToDevice(void* deviceData, uint32_t dataSize)
{
    return CopyDataToDevice(deviceData, dataSize, ACL_MEMCPY_DEVICE_TO_DEVICE);
}

void* CopyDataHostToDevice(void* deviceData, uint32_t dataSize)
{
    return CopyDataToDevice(deviceData, dataSize, ACL_MEMCPY_HOST_TO_DEVICE);
}

Result CreateImageInfoBuffer(uint32_t modelWidth,uint32_t modelHeight,aclrtRunMode& runMode_)
{
    const float imageInfo[4] = {(float)modelWidth, (float)modelWidth,
                                (float)modelHeight, (float)modelHeight};
    imageInfoSize_ = sizeof(imageInfo);
    aclError ret = aclrtGetRunMode(&runMode_);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("acl get run mode failed, errorCode = %d", static_cast<int32_t>(ret));
        return FAILED;
    }
    if (runMode_ == ACL_HOST)
        imageInfoBuf_ = CopyDataHostToDevice((void *)imageInfo, imageInfoSize_);
    else
        imageInfoBuf_ = CopyDataDeviceToDevice((void *)imageInfo, imageInfoSize_);
    if (imageInfoBuf_ == nullptr) {
        ERROR_LOG("Copy image info to device failed");
        return FAILED;
    }

    return SUCCESS;
}

Result DestroyResource()
{
  
    if (!g_loadFlag_) {
        WARN_LOG("no model had been loaded, unload failed");
        return FAILED;
    }

    aclError ret = aclmdlUnload(g_modelId_);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("unload model failed, modelId is %u", g_modelId_);
    }

    if (g_modelDesc_ != nullptr) {
        (void)aclmdlDestroyDesc(g_modelDesc_);
        g_modelDesc_ = nullptr;
    }

    if (g_modelMemPtr_ != nullptr) {
        aclrtFree(g_modelMemPtr_);
        g_modelMemPtr_ = nullptr;
        g_modelMemSize_ = 0;
    }

    if (g_modelWeightPtr_ != nullptr) {
        aclrtFree(g_modelWeightPtr_);
        g_modelWeightPtr_ = nullptr;
        g_modelWeightSize_ = 0;
    }

    g_loadFlag_ = false;
    INFO_LOG("unload model success, modelId is %u", g_modelId_);

    if (g_modelDesc_ != nullptr) {
        (void)aclmdlDestroyDesc(g_modelDesc_);
        g_modelDesc_ = nullptr;
    }

    if (g_output_ == nullptr) {
        return FAILED;
    }

    for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(g_output_); ++i) {
        aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(g_output_, i);
        void* data = aclGetDataBufferAddr(dataBuffer);
        (void)aclrtFree(data);
        (void)aclDestroyDataBuffer(dataBuffer);
    }

    (void)aclmdlDestroyDataset(g_output_);
    g_output_ = nullptr;
    return SUCCESS;
}

Result CreateInput(void *input1, size_t input1size,void* input2, size_t input2size)
{   
    g_input_ = aclmdlCreateDataset();
    if (g_input_ == nullptr) {
        ERROR_LOG("can't create dataset, create input failed");
        return FAILED;
    }

    aclDataBuffer* inputData = aclCreateDataBuffer(input1, input1size);
    if (inputData == nullptr) {
        ERROR_LOG("can't create data buffer, create input failed");
        return FAILED;
    }

    aclmdlAddDatasetBuffer(g_input_, inputData);
    if (inputData == nullptr) {
        ERROR_LOG("can't add data buffer, create input failed");
        aclDestroyDataBuffer(inputData);
        inputData = nullptr;
        return FAILED;
    }

    aclDataBuffer* inputData2 = aclCreateDataBuffer(input2, input2size);
    if (inputData2 == nullptr) {
        ERROR_LOG("can't create data buffer, create input failed");
        return FAILED;
    }

    aclmdlAddDatasetBuffer(g_input_, inputData2);
    if (inputData2 == nullptr) {
        ERROR_LOG("can't add data buffer, create input failed");
        aclDestroyDataBuffer(inputData2);
        inputData = nullptr;
        return FAILED;
    }

    return SUCCESS;
}

void DestroyInput()
{
    if (g_input_ == nullptr) {
        return;
    }

    for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(g_input_); ++i) {
        aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(g_input_, i);
        aclDestroyDataBuffer(dataBuffer);
    }
    aclmdlDestroyDataset(g_input_);
    g_input_ = nullptr;
}

Result inferExecute(aclmdlDataset*& inferenceOutput, PicDesc &picDesc)
{  
    aclError ret = CreateInput(picDesc.data,picDesc.dataSize,imageInfoBuf_, imageInfoSize_);
    if (ret != SUCCESS) {
        ERROR_LOG("Create mode input dataset failed");
        return FAILED;
    }

    ret = aclmdlExecute(g_modelId_, g_input_, g_output_);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("execute model failed, modelId is %u, error code is %d", g_modelId_,ret);
        return FAILED;
    }

    INFO_LOG("model execute success");

    inferenceOutput = g_output_;

    DestroyInput();
    (void)acldvppFree(picDesc.data);
    picDesc.data = nullptr;
    return SUCCESS;
}


void* GetInferenceOutputItem(uint32_t& itemDataSize,
                                           aclmdlDataset* inferenceOutput,
                                           uint32_t idx,
                                           aclrtRunMode& runMode_)
{
    aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(inferenceOutput, idx);
    if (dataBuffer == nullptr) {
        ERROR_LOG("Get the %dth dataset buffer from model "
                  "inference output failed", idx);
        return nullptr;
    }

    void* dataBufferDev = aclGetDataBufferAddr(dataBuffer);
    if (dataBufferDev == nullptr) {
        ERROR_LOG("Get the %dth dataset buffer address "
                  "from model inference output failed", idx);
        return nullptr;
    }

    size_t bufferSize = aclGetDataBufferSizeV2(dataBuffer);
    if (bufferSize == 0) {
        ERROR_LOG("The %dth dataset buffer size of "
                  "model inference output is 0", idx);
        return nullptr;
    }
    void* data = nullptr;
    if (runMode_ == ACL_HOST) {
        data = CopyDataDeviceToLocal(dataBufferDev, bufferSize);
        if (data == nullptr) {
            ERROR_LOG("Copy inference output to host failed");
            return nullptr;
        }
    } else {
        data = dataBufferDev;
    }

    itemDataSize = bufferSize;
    return data;
}

Result Postprocess(PicDesc &picDesc, 
                        aclmdlDataset*& modelOutput,aclrtRunMode& runMode_,
                        uint32_t modelWidth,uint32_t modelHeight)
{   
    struct  timeval tstart,tend;
    double timeuse;
    gettimeofday(&tstart,NULL);
    aclError ret = aclrtGetRunMode(&runMode_);
    if (ret != ACL_SUCCESS) {
        ERROR_LOG("acl get run mode failed, errorCode = %d", static_cast<int32_t>(ret));
        return FAILED;
    }
    uint32_t dataSize = 0;
    float* detectData = (float *)GetInferenceOutputItem(dataSize, modelOutput,
    g_bBoxDataBufId,runMode_);
    uint32_t* boxNum = (uint32_t *)GetInferenceOutputItem(dataSize, modelOutput,
    g_boxNumDataBufId,runMode_);

    if (boxNum == nullptr) return FAILED;
    uint32_t totalBox = boxNum[0];
    vector<BoundingBox> detectResults;

    float widthScale = (float)(picDesc.width) / modelWidth;
    float heightScale = (float)(picDesc.height) / modelHeight;

    for (uint32_t i = 0; i < totalBox; i++) {
        BoundingBox boundBox;

        uint32_t score = uint32_t(detectData[totalBox * SCORE + i] * 100);
        boundBox.lt_x = detectData[totalBox * TOPLEFTX + i] * widthScale;
        boundBox.lt_y = detectData[totalBox * TOPLEFTY + i] * heightScale;
        boundBox.rb_x = detectData[totalBox * BOTTOMRIGHTX + i] * widthScale;
        boundBox.rb_y = detectData[totalBox * BOTTOMRIGHTY + i] * heightScale;

        uint32_t objIndex = (uint32_t)detectData[totalBox * LABEL + i];
        boundBox.text = yolov3Label[objIndex] + std::to_string(score) + "\%";
        printf("%d %d %d %d %s\n", boundBox.lt_x, boundBox.lt_y,
               boundBox.rb_x, boundBox.rb_y, boundBox.text.c_str());

        detectResults.emplace_back(boundBox);
    }
  
    if (runMode_ == ACL_HOST) {
        delete[]((uint8_t *)detectData);
        delete[]((uint8_t*)boxNum);
    }
    gettimeofday(&tend,NULL);
    timeuse = 1000000*(tend.tv_sec - tstart.tv_sec) + \
				(tend.tv_usec - tstart.tv_usec);
    INFO_LOG("Postprocess time: %f ms",timeuse/1000);
    return SUCCESS;

}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398
  • 399
  • 400
  • 401
  • 402
  • 403
  • 404
  • 405
  • 406
  • 407
  • 408
  • 409
  • 410
  • 411
  • 412
  • 413
  • 414
  • 415
  • 416
  • 417
  • 418
  • 419
  • 420
  • 421
  • 422
  • 423
  • 424
  • 425
  • 426
  • 427
  • 428
  • 429
  • 430
  • 431
  • 432
  • 433
  • 434
  • 435
  • 436
  • 437
  • 438
  • 439
  • 440
  • 441
  • 442
  • 443
  • 444
  • 445
  • 446
  • 447
  • 448
  • 449
  • 450
  • 451
  • 452
  • 453
  • 454
  • 455
  • 456
  • 457
  • 458
  • 459
  • 460
  • 461
  • 462
  • 463
  • 464
  • 465
  • 466
  • 467
  • 468
  • 469
  • 470
  • 471
  • 472
  • 473
  • 474
  • 475
  • 476
  • 477
  • 478
  • 479
  • 480
  • 481
  • 482
  • 483
  • 484
  • 485
  • 486
  • 487
  • 488
  • 489
  • 490
  • 491
  • 492
  • 493
  • 494
  • 495
  • 496
  • 497
  • 498
  • 499
  • 500
  • 501
  • 502
  • 503
  • 504
  • 505
  • 506
  • 507
  • 508
  • 509
  • 510
  • 511
  • 512
  • 513
  • 514
  • 515
  • 516
  • 517
  • 518
  • 519
  • 520
  • 521
  • 522
  • 523
  • 524
  • 525
  • 526
  • 527
  • 528
  • 529
  • 530
  • 531
  • 532
  • 533
  • 534
  • 535
  • 536
  • 537
  • 538
  • 539
  • 540
  • 541
  • 542
  • 543
  • 544

inference.h

#ifndef DVPP_DECODE_H_
#define DVPP_DECODE_H_

#include <string>
#include <memory>
#include <iostream>
#include <sys/time.h>
#include <dirent.h>
#include <sys/stat.h>
#include <vector>
#include "acl/acl.h"
#include "acl/ops/acl_dvpp.h"

using namespace std;

#define INFO_LOG(fmt, args...) fprintf(stdout, "[INFO]  " fmt "\n", ##args)
#define WARN_LOG(fmt, args...) fprintf(stdout, "[WARN]  " fmt "\n", ##args)
#define ERROR_LOG(fmt, args...) fprintf(stdout, "[ERROR] " fmt "\n", ##args)

enum Result {
    SUCCESS = 0,
    FAILED = 1
};

typedef struct PicDesc {
    void *data;
    string picName;   
    uint32_t width;
    uint32_t height;
    uint32_t dataSize;
} PicDesc;

struct BoundingBox {
    uint32_t lt_x;
    uint32_t lt_y;
    uint32_t rb_x;
    uint32_t rb_y;
    uint32_t attribute;
    float score;
    std::string text;
};

Result LoadModelFromFileWithMem(const string modelPath);
Result CreateDesc();
Result DestroyResource();
Result CreateOutput();
Result CreateImageInfoBuffer(uint32_t modelWidth,uint32_t modelHeight,aclrtRunMode& runMode_);
Result inferExecute(aclmdlDataset*& inferenceOutput, PicDesc &picDesc);
void* getDeviceBufferOfPicture(PicDesc &picDesc, uint32_t &devPicBufferSize,aclrtRunMode runMode);
Result Postprocess(PicDesc &picDesc, aclmdlDataset*& modelOutput,
                        aclrtRunMode& runMode_, uint32_t modelWidth,uint32_t modelHeight);

#endif

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54

本文实现了模型的加载,推理,和后处理结果,如果有同学有需要,也可以使用cv,将后处理后的推理结果的boundingbox画出来看看效果哦。

ps:该文仅是为了记录CANN训练营的学习过程所用,不参与任何商业用途,有任何代码问题可以和我一起讨论修改哦!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/534000
推荐阅读
相关标签
  

闽ICP备14008679号