当前位置:   article > 正文

torch.hub.load()把联网加载修改为本地加载模型

torch.hub.load

1. torch.load()

torch.load函数用于从磁盘加载已保存的模型或张量,以便进行后续的操作。这也是我们常用的一种导入预训练模型的方式,可以使用以下方式调用该函数:

model = torch.load('model.pth')

其中,model.pth就是我们存放模型的路径。

2.  torch.hub.load()

最近在复现某一个关于yolo的项目中遇到了这个方法,从该方法的hub可以看出,它在每次加载模型时都要联网进行加载。比如:

  1. model = torch.hub.load(
  2. "ultralytics/yolov5",
  3. "custom",
  4. path=f"{local_model_path}/{model_name}",
  5. device=device,
  6. force_reload=[True if "refresh_yolov5" in opt else False][0],
  7. _verbose=True,
  8. )

其中custom表示自定义的模型,path是本地权重文件的路径,而"ultralytics/yolov5"表示该load方法每次加载模型时,都会访问到https://www.wpsshop.cn/w/AllinToyou/article/detail/532009

推荐阅读
相关标签
  

闽ICP备14008679号