当前位置:   article > 正文

快速使用OpenVINO的 Anomalib实现训练和推理_anomalib 怎么训练

anomalib 怎么训练

快速使用OpenVINO的 Anomalib实现训练和推理

代码


import os
from pathlib import Path
from anomalib.data import MVTec
from anomalib import TaskType
from anomalib.deploy import ExportType, OpenVINOInferencer
from anomalib.engine import Engine
from anomalib.models import Padim, Patchcore,Stfpm
from matplotlib import pyplot as plt
from anomalib.data.utils import read_image
import time



def  train_and_export_model ( object_type, model, transform= None ): 
    """
    训练并导出MVTec数据集上的模型为OpenVINO格式。
    
    Args:
        object_type (str): MVTec数据集的类别,如'bottle'、'cap'等。
        model (torch.nn.Module): 待训练的深度学习模型。
        transform (Callable, optional): 数据预处理函数,默认为None。
    
    Returns:
        str: 导出模型保存的根目录路径。
    
    """

    
    datamodule = MVTec()
    datamodule.category=object_type
    engine = Engine(task=TASK) 
    engine.fit(model=model, datamodule=datamodule) 

    ## 将模型导出为 OpenVINO 格式以进行快速推理
    engine.export( 
        model=model, 
        export_type=ExportType.OPENVINO, 
    ) 
    print(f"Model save to {engine.trainer.default_root_dir}).") 
    return  engine.trainer.default_root_dir




if __name__ == '__main__':

    # Initialize the datamodule, model and engine
    OBJECT = "transistor"  ## 要训练的对象
    TASK = TaskType.SEGMENTATION ## 模型的任务类型

    # model = Padim()
    # model =Patchcore()
    # model=Stfpm()
    
    # train_and_export_model(OBJECT, model)

    output_path=Path("results/Padim/MVTec/transistor/latest")
    openvino_model_path = output_path / "weights" / "openvino" / "model.bin"
    metadata_path = output_path / "weights" / "openvino" / "metadata.json"
    print(openvino_model_path.exists(), metadata_path.exists())

    inferencer = OpenVINOInferencer(
        path=openvino_model_path,  # Path to the OpenVINO IR model.
        metadata=metadata_path,  # Path to the metadata file.
        device="AUTO",  # We would like to run it on an Intel CPU.
    )

    # 定义文件夹路径
    folder_path = "./datasets/MVTec/transistor/test/bent_lead/"

    # 获取文件夹中所有的.png文件
    png_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]

    for file_name in png_files:
         
        image = read_image(path=folder_path +'/'+ file_name)

        # 记录开始时间
        start_time = time.time()
        predictions = inferencer.predict(image=image)
        # 记录结束时间
        end_time = time.time()

        # 计算耗时
        elapsed_time = end_time - start_time
        print(f"Prediction took {elapsed_time:.4f} seconds.")
        print(predictions.pred_score, predictions.pred_label)
        # 创建一个新的图形窗口
        fig, axs = plt.subplots(1, 3, figsize=(18, 8))  # 创建一个1行3列的子图网格

        # 原始图像
        axs[0].imshow(image)
        axs[0].set_title('Original Image')
        axs[0].axis('off')  # 关闭坐标轴

        # 热图
        axs[1].imshow(predictions.heat_map, cmap='hot', interpolation='nearest')
        axs[1].set_title('Heat Map')
        axs[1].axis('off')  # 关闭坐标轴

        # 预测掩模
        axs[2].imshow(predictions.pred_mask, cmap='gray', interpolation='nearest')
        axs[2].set_title('Predicted Mask')
        axs[2].axis('off')  # 关闭坐标轴


        # 添加文本信息到图形的上方中间位置
        fig_text_x = 0.1  # x坐标在图形宽度的中心位置
        fig_text_y = 0.95  # y坐标稍微靠近图形的顶部,避免与子图重叠
        fig.text(fig_text_x, fig_text_y,
                f'Prediction Time: {elapsed_time:.4f} s\n'
                f'Predicted Class: {predictions.pred_label}\n'
                # f'Score: {predictions.pred_score:.4f}\n'
                f'Threshold: {predictions.pred_score:.4f}' if hasattr(predictions, 'pred_score') else '',
                ha='left', va='center', fontsize=12,
                bbox=dict(boxstyle="round", fc="w", ec="0.5", alpha=0.5))  

        # 显示整个图形
        plt.tight_layout()  # 调整子图间的间距
        plt.show()
    print("Done")

    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124

运行的结果截图

运行结果

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小桥流水78/article/detail/1011718
推荐阅读
相关标签
  

闽ICP备14008679号