当前位置:   article > 正文

Pytorch预训练模型(torch.hub)缓存地址修改

torch.hub

前言

在Pytorch中,有一些预训练模型或者预先封装的功能往往通过torch.hub模块中的一些方法进行加载,会保存一些文件在本地,通常默认地址是在C盘。考虑到某些预加载的资源很大,保存在C盘十分的占用存储空间,因此有时候需要修改这个保存地址。

注意!本文有较长篇幅分析Pytorch缓存路径的设置逻辑,若无相关需求,可直接跳到总结部分查看具体配置方法。

分析

其实不论是使用torch.hub.load()或者是Pytorch提供的预训练模型的服务,通过对源码的跟踪分析,会发现它们下载资源的方式都是通过torch.hub模块进行完成的,以最常见的预训练模型下载函数load_state_dict_from_url() 为例,可以在其函数声明部分看到 model_dir 参数。

def load_state_dict_from_url(
    url: str,
    model_dir: Optional[str] = None,
    map_location: Optional[Union[Callable[[str], str], Dict[str, str]]] = None,
    progress: bool = True,
    check_hash: bool = False,
    file_name: Optional[str] = None
) -> Dict[str, Any]:
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

model_dir 参数处于缺省状态时,该函数会调用同一模块下的 get_dir() 函数获取默认缓存地址。

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')
  • 1
  • 2
  • 3

进入 get_dir() 函数可以看到,其调用了一个私有方法 _get_torch_home() 获取默认路径。

def get_dir():
    r"""
    Get the Torch Hub cache directory used for storing downloaded models & weights.

    If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
    filesystem layout, with a default value ``~/.cache`` if the environment
    variable is not set.
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_HUB'):
        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')

    if _hub_dir is not None:
        return _hub_dir
    return os.path.join(_get_torch_home(), 'hub')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

函数的相关代码以及一些常量定义如下:

ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'

def _get_torch_home():
    torch_home = os.path.expanduser(
        os.getenv(ENV_TORCH_HOME,
                  os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
                                         DEFAULT_CACHE_DIR), 'torch')))
    return torch_home
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

从其调用了os.getenv() 来看,显然是通过读取环境变量来确定默认目录的。接下来我们依次分析该代码段。首先是对于最外层的os.getenv(ENV_TORCH_HOME) ,其获取常量ENV_TORCH_HOME所指向的环境变量中的值,即环境变量TORCH_HOME中的值,若未找到该环境变量的值,则返回如下值:

os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
                                         DEFAULT_CACHE_DIR), 'torch')
  • 1
  • 2

该值由第一个参数与第二个参数拼接而成,其中第一个参数同样是尝试获取环境变量,而第二个参数则是常量文本 ‘torch’ ,对于第一个参数而言,其获取常量ENV_XDG_CACHE_HOME所指向的环境变量XDG_CACHE_HOME的值,若获取失败则返回常量DEFAULT_CACHE_DIR的值,即 ‘~/.cache’

同时,当返回路径如果是 **'~/.cache/torch’时,最外层 os.path.expanduser会自动替换~**关键符号为当前计算机的用户路径,例如 C:\Users\Administrator.cache\torch

总结

在hub.py 文件中

_get_torch_home 函数获取缓存默认存储位置

通常优先取环境变量 ‘TORCH_HOME’ 中的值,在代码中期被声明为ENV_TORCH_HOME 变量。

若不存在则取环境变量 ‘XDG_CACHE_HOME’ 的值拼接 ‘torch’ 为默认位置,其中 ‘XDG_CACHE_HOME’ 在代码中被声明为变量ENV_XDG_CACHE_HOME。

若依旧不存在,则返回 ‘~/.cache’ +‘torch’ ,并替换 ~ 为本地用户路径。

因此,可以通过配置环境变量来修改Pytorch的默认缓存位置,具体如下:

‘XDG_CACHE_HOME’ = Pytorch相关包存放缓存的默认位置

‘TORCH_HOME’ = %XDG_CACHE_HOME%\torch

具体步骤如下:
首先打开计算机的属性面板
在这里插入图片描述
接着在属性面板右上角打开 “高级系统设置”
在这里插入图片描述
从高级设置中进入环境变量设置界面在这里插入图片描述
通过点击新建,完成对环境变量的新增,其中用户变量仅对当前用户有效,而环境变量则对本机器所有用户生效。在这里插入图片描述
在我的设置中,我设置如下:
XDG_CACHE_HOME=D:\Python\cache
TORCH_HOME=%XDG_CACHE_HOME%\torch

即用D:\Python\cache存储Pytorch相关包下载的缓存,并使用D:\Python\cache\torch缓存Pytorch本身下载的一些缓存文件。其中 %XDG_CACHE_HOME% 可看做对环境变量XDG_CACHE_HOME的引用。在这里插入图片描述

或者在项目运行的代码前加上临时环境变量设置
os.environ[‘TORCH_HOME’]=‘E:/Data/torch-model’

// 全文完

因笔者能力有限,若文章内容存在错误或不恰当之处,欢迎留言、私信批评指正。
Email:YePeanut[at]foxmail.com

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/569603
推荐阅读
相关标签
  

闽ICP备14008679号