
PyTorch Project Examples (1): Loading Pretrained Models (Local | Official)


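Background: we need to upload the model to a cluster and run it there. Once the environment and configuration have been copied into env, nothing outside the project folder is available, so the pretrained model has to sit inside the folder and be loaded from there. In particular, the pretrained resnet101 weights must be placed directly in the project directory.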

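Contents

1. Loading pretrained models
1.1 Loading the model
1.2 Loading flow
1.3 Where the model is stored
1.4 Drawback
1.5 Finding where the pretrained model is stored
2. Loading a model from a specified location
2.1 Example program
2.2 Putting the model file into the project directory
2.3 Our program
3. Verification (optional)
4. Solving the pretrained-model problem on the cluster
4.1 The error
4.2 Locating the model file
4.3 Copying and running on the server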


1. Loading pretrained models

1.1 Loading the model

Load the model directly through torchvision's models module:

```python
class HGAT_FC(nn.Module):
    def __init__(self, backbone, groups, nclasses, nclasses_per_group, group_channels, class_channels):
        super(HGAT_FC, self).__init__()
        self.groups = groups
        self.nclasses = nclasses
        self.nclasses_per_group = nclasses_per_group
        self.group_channels = group_channels
        self.class_channels = class_channels
        if backbone == 'resnet101':
            # Downloads the ImageNet weights on first use (needs network access)
            model = models.resnet101(pretrained=True)
        elif backbone == 'resnet50':
            model = models.resnet50(pretrained=False)
        else:
            raise ValueError('unsupported backbone: {}'.format(backbone))
        # ... (rest of __init__ omitted)
```

The required imports; torchvision.models provides the model constructors:

```python
import torch
import torch.nn.functional as F
import torchvision.models as models
from torch import nn

import mymodels.utils as utils  # project-specific module
```

1.2 Loading flow

```python
import torch
import torchvision.models as models
# ...
if backbone == 'resnet101':
    model = models.resnet101(pretrained=True)
elif backbone == 'resnet50':
    model = models.resnet50(pretrained=False)
else:
    raise Exception()
```

1.3 Where the model is stored

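cd ~ switches to the home directory. The listing below shows that torch caches the downloaded weights under the home directory, in ~/.torch/models: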

```
[xingxiangrui@xx.com ~]$ cd ~/.torch/models
[xingxiangrui@xx.com models]$ pwd
/home/xingxiangrui/.torch/models
[xingxiangrui@xx.com models]$ ls
resnet101-5d3b4d8f.pth
```

1.4 Drawback

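If the weights have not been downloaded before, torchvision downloads them automatically.

Without network access, or without the required permissions, the download cannot happen, so the program cannot run and fails with an error such as:

requests.exceptions.ConnectionError: ('Connection aborted.', TimeoutError(10060, 'A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond.', None, 10060, None))

We therefore use the method below and load the model directly from a directory.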
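As a fail-fast alternative to a silent download attempt, the sketch below looks for an already-downloaded weight file in a few candidate directories and raises a clear error when none of them has it. The helper name find_local_weights and the search order are hypothetical; the two cache directories are the ones that appear in this article (~/.torch/models on older torch versions, ~/.cache/torch/checkpoints on newer ones). It uses str.format instead of f-strings so it also runs on the Python 3.5 seen in the cluster traceback later on.

```python
import os

# Hypothetical search order: project directory first, then the two cache
# locations that appear in this article (older and newer torch versions).
CANDIDATE_DIRS = ['.', '~/.torch/models', '~/.cache/torch/checkpoints']


def find_local_weights(filename, search_dirs=CANDIDATE_DIRS):
    """Return the first existing path to `filename` in `search_dirs`.

    Raises FileNotFoundError instead of falling back to a download, which is
    the behaviour we want on a cluster node without network access.
    """
    tried = []
    for directory in search_dirs:
        path = os.path.join(os.path.expanduser(directory), filename)
        if os.path.isfile(path):
            return path
        tried.append(path)
    raise FileNotFoundError(
        '{} not found; looked in: {}'.format(filename, ', '.join(tried)))
```

With torch available, the result is passed straight to torch.load, e.g. model.load_state_dict(torch.load(find_local_weights('resnet101-5d3b4d8f.pth'))) after building the network with pretrained=False.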

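1.5 Finding where the pretrained model is stored

The cache location differs from one environment to another, so if the model has already been downloaded, you first need to find where it is stored.

With pretrained=True, torchvision runs the following code: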

```python
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
```

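Jump to the definition of load_url (Ctrl+B in PyCharm).

This leads to the function below: if the file already exists locally it is loaded from there; otherwise it is first downloaded from the URL. (In newer PyTorch versions, torch.utils.model_zoo.load_url is simply an alias of torch.hub.load_state_dict_from_url.)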

```python
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file.

    The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
    filesytem layout, with a default value ``~/.cache`` if not set.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename).group(1)
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return torch.load(cached_file, map_location=map_location)
```

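Set a breakpoint inside this function (for example on the final torch.load line) and inspect cached_file in the debugger to see where the model is stored.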
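The debugger works, but the default directory can also be computed directly. The sketch below re-implements the rule stated in the load_state_dict_from_url docstring so you can print the candidate cache directories and the .pth files they contain. It is our own re-implementation under that assumption, not a torch function, and the helper names are hypothetical. Recent PyTorch releases also provide torch.hub.get_dir(); there the weights live in its checkpoints subdirectory.

```python
import os


def default_checkpoint_dir():
    """Re-implement the default model_dir rule from the docstring:
    $TORCH_HOME/checkpoints, where TORCH_HOME defaults to
    $XDG_CACHE_HOME/torch and XDG_CACHE_HOME defaults to ~/.cache."""
    xdg_cache = os.getenv('XDG_CACHE_HOME', os.path.join('~', '.cache'))
    torch_home = os.getenv('TORCH_HOME', os.path.join(xdg_cache, 'torch'))
    return os.path.join(os.path.expanduser(torch_home), 'checkpoints')


def list_cached_weights(directory):
    """Return the sorted .pth filenames in `directory` ([] if it does not exist)."""
    if not os.path.isdir(directory):
        return []
    return sorted(name for name in os.listdir(directory) if name.endswith('.pth'))


if __name__ == '__main__':
    # Older torch versions (see section 1.3) cached weights in ~/.torch/models.
    for d in (default_checkpoint_dir(), os.path.expanduser('~/.torch/models')):
        print(d, list_cached_weights(d))
```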

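2. Loading a model from a specified location

This way there is no need to worry about network access, and you control exactly which weight file is used.

Reference: https://blog.csdn.net/u014264373/article/details/85332181

Load the weights directly from a .pth file, for example: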

```python
import torch
import torchvision.models as models

# pretrained=True would download the official pretrained weights; here we
# build the bare network and load the weights from a local file instead
net = models.squeezenet1_1(pretrained=False)
pthfile = r'E:\anaconda\app\envs\luo\Lib\site-packages\torchvision\models\squeezenet1_1.pth'
net.load_state_dict(torch.load(pthfile))
print(net)
```

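2.1 Example program

The program below reads the weight file directly from the directory it is run in, so the file is placed in the working directory (the code may not be exactly right; treat it as a reference):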

```python
def gcn_resnet101(num_classes, t, pretrained=True, adj_file=None, in_channel=300):
    # fixme
    model = models.resnet101(pretrained=False)
    if pretrained:
        print('load pretrained model...')
        model.load_state_dict(torch.load('./resnet101-5d3b4d8f.pth'))
    return GCNResnet(model, num_classes, t=t, adj_file=adj_file, in_channel=in_channel)
```

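2.2 Putting the model file into the project directory

```
cp ~/.torch/models/resnet101-5d3b4d8f.pth chun-ML_GCN/
```

Note: a relative path passed to torch.load is resolved against the directory the program is run from, so the file location must match the path used in the load_state_dict call.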
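Because './resnet101-5d3b4d8f.pth' is resolved against the current working directory, launching the script from a different directory breaks the load. A minimal sketch of one way around this, assuming the weight file sits at a known position relative to the Python file that loads it: build the path from that file's own location (pass __file__). The helper name weights_path is hypothetical.

```python
import os


def weights_path(filename, anchor_file, *subdirs):
    """Return an absolute path to `filename`, relative to the directory of
    `anchor_file` (normally __file__) instead of the current working directory.

    `subdirs` lets a module inside e.g. models/ point back up with '..'.
    """
    base_dir = os.path.dirname(os.path.abspath(anchor_file))
    return os.path.normpath(os.path.join(base_dir, *subdirs, filename))
```

For example, code in models/senet_origin.py whose weight file lives in the project root could call torch.load(weights_path('se_resnet152-d17c99b7.pth', __file__, '..')), which works no matter where the job is launched from.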

2.3 Our program

```python
if backbone == 'resnet101':
    model = models.resnet101(pretrained=False)
    print('load pretrained model...')
    model.load_state_dict(torch.load('./resnet101-5d3b4d8f.pth'))
elif backbone == 'resnet50':
    model = models.resnet50(pretrained=False)
    print('load pretrained model...')
    # ResNet-50 has its own file hash (see the server listing in section 4.3)
    model.load_state_dict(torch.load('./resnet50-19c8e357.pth'))
else:
    raise ValueError('unsupported backbone: {}'.format(backbone))
```

That is, the model resnet101-5d3b4d8f.pth is loaded directly from the working directory.
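The if/elif chain above repeats the same three lines per backbone, which makes it easy to pair a backbone with the wrong weight file, since a copied hash is hard to spot. One alternative, sketched under the assumption that all weight files sit in the working directory, is a single lookup table. BACKBONE_WEIGHTS and weight_file_for are hypothetical names; the filenames are the ones listed on the server in section 4.3.

```python
# Weight filenames as they appear in this article (server listing in
# section 4.3); the table itself is a hypothetical refactoring.
BACKBONE_WEIGHTS = {
    'resnet101': 'resnet101-5d3b4d8f.pth',
    'resnet50': 'resnet50-19c8e357.pth',
}


def weight_file_for(backbone):
    """Return the local weight filename for `backbone`; raise ValueError if unknown."""
    try:
        return BACKBONE_WEIGHTS[backbone]
    except KeyError:
        raise ValueError('unsupported backbone: {!r} (expected one of {})'.format(
            backbone, sorted(BACKBONE_WEIGHTS))) from None


# Intended use inside the model constructor (requires torch/torchvision):
# model = getattr(models, backbone)(pretrained=False)
# model.load_state_dict(torch.load('./' + weight_file_for(backbone)))
```

Adding a backbone then means adding one dictionary entry, and an unknown or misspelled name fails with a message listing the valid choices.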

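3. Verification (optional)

This section verifies our own program; other readers can skip it, since everyone's model is different.

Apply the changes described above: in general_train.py switch to exp_3, and modify hgat_fc.py as shown above.

Then, from the project directory, run env/bin/python general_train.py. If it runs without errors, the change works.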
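Before launching the job, a cheap way to confirm that a copied weight file is intact (for example, not truncated by an interrupted cp or upload) is to check it against the hash embedded in its filename. The sketch below assumes the naming convention described in the load_state_dict_from_url docstring in section 1.5 (the first eight or more hex digits of the file's SHA-256 follow the last dash); HASH_IN_NAME and sha256_prefix_ok are hypothetical names.

```python
import hashlib
import os
import re

# Assumed convention: <name>-<hash>.<ext>, where <hash> is the first 8+ hex
# digits of the file's SHA-256 (see the docstring in section 1.5).
HASH_IN_NAME = re.compile(r'-([a-f0-9]{8,})\.')


def sha256_prefix_ok(path, chunk_size=1 << 20):
    """Return True if the file's SHA-256 digest starts with the hash in its name."""
    match = HASH_IN_NAME.search(os.path.basename(path))
    if match is None:
        raise ValueError('no hash prefix in filename: {}'.format(path))
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        # Read in chunks so a large weight file is not loaded into memory at once
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest().startswith(match.group(1))
```

Running sha256_prefix_ok('./resnet101-5d3b4d8f.pth') in the project directory should return True for an intact copy.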

4. Solving the pretrained-model problem on the cluster

4.1 The error

The error shows that the cluster job still tries to download the pretrained model:

```
Downloading: "http://xxxxxxxr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth" to /home/xxx/.torch/models/se_resnet152-d17c99b7.pth
Traceback (most recent call last):
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 159, in _new_conn
    (self._dns_host, self.port), self.timeout, **extra_kw)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/connection.py", line 80, in create_connection
    raise err
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/connection.py", line 70, in create_connection
    sock.connect(sa)
OSError: [Errno 101] Network is unreachable

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 600, in urlopen
    chunked=chunked)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 354, in _make_request
    conn.request(method, url, **httplib_request_kw)
  File "/home/sxxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 1107, in request
    self._send_request(method, url, body, headers)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 1152, in _send_request
    self.endheaders(body)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 1103, in endheaders
    self._send_output(message_body)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 934, in _send_output
    self.send(msg)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 877, in send
    self.connect()
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 181, in connect
    conn = self._new_conn()
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 168, in _new_conn
    self, "Failed to establish a new connection: %s" % e)
urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPConnection object at 0x7f03fa52d748>: Failed to establish a new connection: [Errno 101] Network is unreachable

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/xx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/adapters.py", line 449, in send
    timeout=timeout
  File "/home/xxxxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 638, in urlopen
    _stacktrace=sys.exc_info()[2])
  File "/home/xxxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/retry.py", line 398, in increment
    raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPConnectionPool(host='data.lip6.fr', port=80): Max retries exceeded with url: /cadene/pretrainedmodels/se_resnet152-d17c99b7.pth (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f03fa52d748>: Failed to establish a new connection: [Errno 101] Network is unreachable',))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train_se_clsgat.py", line 128, in <module>
    main()
  File "train_se_clsgat.py", line 107, in main
    model = util.get_model(args)
  File "/home/xxx/job/tmp/job-25509/util.py", line 266, in get_model
    class_channels=args.CLASS_CHANNELS)
  File "/home/xxxx/job/tmp/job-25509/models/se_clsgat.py", line 379, in __init__
    model=senet_origin.se_resnet152()
  File "/home/xxx/job/tmp/job-25509/models/senet_origin.py", line 423, in se_resnet152
    initialize_pretrained_model(model, num_classes, settings)
  File "/home/xxx/job/tmp/job-25509/models/senet_origin.py", line 377, in initialize_pretrained_model
    model.load_state_dict(model_zoo.load_url(settings['url']))
  File "/home/slurm/job/tmp/job-25509/torch/lib/python3.5/site-packages/torch/utils/model_zoo.py", line 65, in load_url
    _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/torch/utils/model_zoo.py", line 71, in _download_url_to_file
    u = urlopen(url, stream=True)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/api.py", line 75, in get
    return request('get', url, params=params, **kwargs)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/api.py", line 60, in request
    return session.request(method=method, url=url, **kwargs)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/sessions.py", line 533, in request
    resp = self.send(prep, **send_kwargs)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/sessions.py", line 646, in send
    r = adapter.send(request, **kwargs)
  File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/adapters.py", line 516, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='data.lip6.fr', port=80): Max retries exceeded with url: /cadene/pretrainedmodels/se_resnet152-d17c99b7.pth (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f03fa52d748>: Failed to establish a new connection: [Errno 101] Network is unreachable',))
```

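The pretrained model therefore has to be placed in the project directory so that the cluster does not try to download it.

The job output shows that the program never gets past the model-loading step; it stops right after the datasets are loaded: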

```
==== GLOBAL INFO ====
IPLIST: xx.xx.xx.xx
IP0: xx.xx.xx.xx
====================
==== NODE INFO ====
NODE_RNAK: 0
IP0: xx.xx.xx.xx
NODE_IP: xx.xx.xx
===================
{'ADJ_FILE': 'data/data/coco/coco_adj.pkl',
 'ALPHA': 0.8,
 'BACKBONE': 'resnet150',
 'BATCH_SIZE': 16,
 'CLASS_CHANNELS': 256,
 'CPROB': array([[1.00000000e+00, 8.26410144e-01, 7.04392284e-01, ...,
        4.03311258e-01, 4.45312500e-01, 5.40000000e-01],
       [4.18382255e-02, 1.00000000e+00, 1.02719033e-01, ...,
        1.12582781e-02, 0.00000000e+00, 5.71428571e-03],
       [1.34192234e-01, 3.86532575e-01, 1.00000000e+00, ...,
        3.84105960e-02, 7.81250000e-03, 8.57142857e-03],
       ...,
       [1.34812060e-02, 7.43331876e-03, 6.73948408e-03, ...,
        1.00000000e+00, 2.34375000e-02, 8.57142857e-03],
       [1.26178775e-03, 0.00000000e+00, 1.16198001e-04, ...,
        1.98675497e-03, 1.00000000e+00, 2.57142857e-02],
       [8.36764511e-03, 1.74901618e-03, 6.97188008e-04, ...,
        3.97350993e-03, 1.40625000e-01, 1.00000000e+00]]),
 'DATA': 'data/data/coco',
 'DATA_TYPE': 'coco',
 'DEEPMAR_LOSS': <loss.DeepMarWeights object at 0x7f04044800f0>,
 'DEVICE_IDS': [0, 1, 2, 3, 4, 5, 6, 7],
 'EPOCH': 100,
 'EPOCH_STEP': 30,
 'EVALUATE': False,
 'EXP_NAME': 'se_clsgat',
 'GROUPS': 12,
 'GROUP_CHANNELS': 512,
 'IMAGE_SIZE': 448,
 'INP_NAME': 'data/data/coco/coco_glove_word2vec.pkl',
 'IS_SLURM': False,
 'LOSS_TYPE': 'DeepMarLoss',
 'LR': 0.01,
 'LRP': 0.01,
 'LR_SCHEDULER': None,
 'LR_SCHEDULER_PARAMS': None,
 'MAX_EPOCH': 100,
 'MODEL': 'se_clsgat',
 'MOMENTUM': 0.9,
 'NCLASSES': 80,
 'NCLASSES_PER_GROUP': [1, 8, 5, 10, 5, 10, 7, 10, 6, 6, 5, 7],
 'PRINT_FREQ': 10,
 'RESUME': 'checkpoints/coco/se_clsgat/checkpoint.pth.tar',
 'SAVE_MODEL_PATH': 'checkpoints/coco/se_clsgat',
 'START_EPOCH': 0,
 'WEIGHT_DECAY': 1e-05,
 'WEIGHT_FILE': 'data/coco/coco_rate.pkl',
 'WORKERS': 4}
Compose(
    Resize(size=(512, 512), interpolation=PIL.Image.BILINEAR)
    MultiScaleCrop
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
Compose(
    Warp (size=448, interpolation=2)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
[dataset] Done!
[annotation] Done!
[json] Done!
[dataset] Done!
[annotation] Done!
[json] Done!
-------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code.. Per user-direction, the job has been aborted.
-------------------------------------------------------
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[58771,1],0]
  Exit code:    1
--------------------------------------------------------------------------
```
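The traceback shows the call chain initialize_pretrained_model → model_zoo.load_url(settings['url']), and per the code in section 1.5, load_url only downloads when the target file is missing from model_dir. So, as an alternative to replacing the call with torch.load, one could keep load_url and pass a model_dir that already contains the file, e.g. model_zoo.load_url(settings['url'], model_dir='.'); we have not tested this on the cluster. The hypothetical helper cached_file_for below reproduces how that filename is derived, so you can check what name the copied file must have.

```python
import os
from urllib.parse import urlparse


def cached_file_for(url, model_dir):
    """Return the path load_url checks before downloading: the basename of the
    URL path joined onto model_dir (mirrors the code shown in section 1.5)."""
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)


if __name__ == '__main__':
    # URL assembled from the host and path reported in the traceback above
    url = 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'
    print(cached_file_for(url, '.'))  # on Linux: ./se_resnet152-d17c99b7.pth
```

Query strings are ignored because only the URL path is used, so the expected filename is always the last path component.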

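4.2 Locating the model file

Use the breakpoint method from section 1.5 to find where the model is cached, then copy it into the project directory. On the local machine this looks like: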

```
(torch041) $ cd /Users/baidu/.cache/torch/checkpoints/
(torch041) baidudeMacBook-Pro:checkpoints baidu$ ls
resnet101-5d3b4d8f.pth se_resnet152-d17c99b7.pth
(torch041) baidudeMacBook-Pro:checkpoints baidu$ cp se_resnet152-d17c99b7.pth /Users/baidu/Desktop/code/ML_GAT-master/
```

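After copying, the program runs without errors.

4.3 Copying and running on the server

On the server, the torch cache location is already known: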

```
cd ~/.torch/models/
ls
resnet101-5d3b4d8f.pth resnet50-19c8e357.pth se_resnet152-d17c99b7.pth
```

Then replace senet_origin.py with the modified version:

```python
import torch


def initialize_pretrained_model(model, num_classes, settings):
    assert num_classes == settings['num_classes'], \
        'num_classes should be {}, but is {}'.format(
            settings['num_classes'], num_classes)
    # Original call: downloads the weights, which fails on the offline cluster
    # model.load_state_dict(model_zoo.load_url(settings['url']))
    print('loading pretrained model from local...')
    # Relative to the working directory, so the job must start in the project root
    model.load_state_dict(torch.load('./se_resnet152-d17c99b7.pth'))
    model.input_space = settings['input_space']
    model.input_size = settings['input_size']
    model.input_range = settings['input_range']
    model.mean = settings['mean']
    model.std = settings['std']
```
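Disclaimer: this content was contributed by a user and does not represent the position of the wpsshop blog; copyright belongs to the original author, and the site assumes no legal liability. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/Cpp五条/article/detail/569614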