当前位置:   article > 正文

【MindStudio训练营第一期】垃圾分类案例代码详解_aclliteresource

aclliteresource

该代码基于基于垃圾分类案例
直接看classify_test.py中的main函数了解流程:

    acl_resource = AclLiteResource() #创建一个AclResource类的实例
    acl_resource.init() ##AscendCL资源初始化(申请device、创建Context、创建Stream)
    model = AclLiteModel(MODEL_PATH)
  • 1
  • 2
  • 3

这里主要做的事情:

  1. 运行管理资源申请:用于初始化系统内部资源,固定的调用流程。
  2. 加载模型文件并构建输出内存

之后是读取需要处理的图像

dvpp = AclLiteImageProc(acl_resource)

使用DVPP处理图像的代码
 resized_image = pre_process(image, dvpp)
# 这个函数我们再看看它做了什么
def pre_process(image, dvpp):
    """preprocess"""
    image_input = image.copy_to_dvpp()
    yuv_image = dvpp.jpegd(image_input)

    print("decode jpeg end")
    resized_image = dvpp.resize(yuv_image, 
                    MODEL_WIDTH, MODEL_HEIGHT)

    print("resize yuv end")
    return resized_image
# 不难看出,做了两件事,使用DVPP的JPEGD接口解码图片,
# JPEGD解码出来的是YUV格式。然后将图片进行缩放

# 其中resize的尺寸由全局变量得到
MODEL_WIDTH = 224
MODEL_HEIGHT = 224
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

这里使用DVPP的JPEGD接口解码图片,JPEGD解码出来的是YUV格式。然后将图片进行缩放。
DVPP(数字视觉预处理) 是昇腾AI软件栈中的编解码和图像转换模块,包含6个模块,视频解码(VDEC)模块、视频编码(VENC)模块、JPEG解码(JPEGD)模块、JPEG编码(JPEGE)模块、PNG解码(PNGD)模块和视觉预处理(VPC)模块。

模型推理

result = model.execute([resized_image,])  
  • 1

后处理
得到模型推理结果后只是一些向量数字,数字的大小代表了属于不同类的概率,选择最大置信度即可。

post_process(result, image_file)    
def post_process(infer_output, image_file):
    print("post process")
    data = infer_output[0] #获取推理输出
    vals = data.flatten() # 展平
    top_k = vals.argsort()[-1:-6:-1] # 排序结果
    object_class = get_image_net_class(top_k[0]) #选择置信度最高的结果
    output_path = os.path.join(os.path.join(SRC_PATH, "../out"), os.path.basename(image_file))
    origin_image = Image.open(image_file)
    draw = ImageDraw.Draw(origin_image)
    font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size=20)
    font.size =50
    draw.text((10, 50), object_class, font=font, fill=255) #将预测的结果放置到图片上
    origin_image.save(output_path) #保存有预测结果的输出
    object_class = get_image_net_class(top_k[0])        
    return 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/338975
推荐阅读
相关标签
  

闽ICP备14008679号