# 赞 (like)
# 踩 (dislike)
"""Classify a single image with an Ascend OM model through ais_bench."""
import argparse
import time

import cv2
import numpy as np
from ais_bench.infer.interface import InferSession

# Class-id -> label mapping. Replace with the class list of your model.
# (A bare ``...`` entry inside a dict literal is a SyntaxError, so list
# only real entries here.)
CLASSES = {0: 'class_0', 1: 'class_1', 2: 'class_2'}


def preprocess_image(image_path, target_size=(224, 224)):
    """Load an image and convert it into a float32 batch for the model.

    Args:
        image_path: Path to an image file readable by OpenCV.
        target_size: ``(width, height)`` passed to ``cv2.resize``. The
            224x224 default is an assumption; match your model's input.

    Returns:
        ``np.float32`` array of shape ``(1, height, width, 3)`` with pixel
        values scaled to ``[0, 1]``. Channels stay in OpenCV's BGR order and
        the layout is NHWC. NOTE(review): if the model expects RGB and/or
        NCHW input, convert here; confirm against the model's input spec.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error surfaces later inside cv2.resize.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
    image = image / 255.0  # normalize 8-bit pixels to [0, 1]
    image = image.astype(np.float32)
    image = np.expand_dims(image, axis=0)  # add the batch dimension
    return image


def classify_image(session, image_path):
    """Run inference on one image and print the top-1 prediction.

    Args:
        session: An ``InferSession`` already loaded with the OM model.
        image_path: Path to the image to classify.

    Returns:
        ``(predicted_class, confidence)`` for the highest-scoring class.
    """
    image_data = preprocess_image(image_path)

    begin_time = time.perf_counter()
    # ais_bench's InferSession.infer takes a list with one array per model input.
    outputs = session.infer(feeds=[image_data], mode="static")
    end_time = time.perf_counter()
    print("OM infer time:", end_time - begin_time)

    # Assumes the first output holds one score per class for this single
    # image (e.g. shape (1, num_classes)); flattening makes argmax index the
    # class axis and the score lookup yield a scalar.
    prediction = np.asarray(outputs[0]).reshape(-1)
    predicted_class_id = int(np.argmax(prediction))
    # Fall back to the raw id when CLASSES does not cover the model's outputs.
    predicted_class = CLASSES.get(predicted_class_id,
                                  f"unknown({predicted_class_id})")
    # Raw output score: a probability only if the model ends in a softmax.
    confidence = float(prediction[predicted_class_id])
    print(f"Predicted Class: {predicted_class} with confidence {confidence:.2f}")
    return predicted_class, confidence


def main(om_model, input_image):
    """Load ``om_model`` on device 0 and classify ``input_image``.

    Returns:
        The ``(predicted_class, confidence)`` pair from ``classify_image``.
    """
    session = InferSession(device_id=0, model_path=om_model)
    return classify_image(session, input_image)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Classify an image with an OM model.")
    parser.add_argument("--model", default="classification_model.om",
                        help="Input your OM model for classification.")
    parser.add_argument("--img", default="path_to_your_image.jpg",
                        help="Path to input image.")
    args = parser.parse_args()
    main(args.model, args.img)
# Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。 (All rights reserved.)