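Background: we need to upload the model to a cluster and run it there, so the pretrained weights must be loaded from inside the project folder. Once the environment and configuration have been copied into env, nothing outside that folder can be used. The pretrained ResNet-101 weights therefore have to be placed directly in the project directory and loaded from there.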
Loading the model directly through torchvision.models
- class HGAT_FC(nn.Module):
-     def __init__(self, backbone, groups, nclasses, nclasses_per_group, group_channels, class_channels):
-         super(HGAT_FC, self).__init__()
-         self.groups = groups
-         self.nclasses = nclasses
-         self.nclasses_per_group = nclasses_per_group
-         self.group_channels = group_channels
-         self.class_channels = class_channels
-         if backbone == 'resnet101':
-             # pretrained=True makes torchvision fetch the ImageNet weights (downloading them if not cached)
-             model = models.resnet101(pretrained=True)
-         elif backbone == 'resnet50':
-             model = models.resnet50(pretrained=False)
-         else:
-             raise ValueError('unsupported backbone: {}'.format(backbone))
This requires importing torchvision.models:
- import torch
- import torch.nn.functional as F
- import torchvision.models as models
- from torch import nn
- import mymodels.utils as utils
- # ...
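cd ~ goes to the home directory. The listing below shows that torchvision caches the downloaded weights under the home directory, in ~/.torch/models.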
- [xingxiangrui@xx.com ~]$ cd ~/.torch/models
- [xingxiangrui@xx.com models]$ pwd
- /home/xingxiangrui/.torch/models
- [xingxiangrui@xx.com models]$ ls
- resnet101-5d3b4d8f.pth
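If the weights have not been downloaded before, torchvision downloads them automatically over the network.
Without network access, or without the necessary permissions, the download fails, so the program cannot run and aborts with an error such as:
requests.exceptions.ConnectionError: ('Connection aborted.', TimeoutError(10060, '由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败。', None, 10060, None))
(Windows error 10060: the connection attempt timed out because the remote host did not respond.)
The workaround is therefore to load the model directly from a local directory, as described below.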
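The storage location differs between environments, so if the model has already been downloaded, first find where it is stored (in this post it is ~/.torch/models on the server and ~/.cache/torch/checkpoints on a Mac).
With pretrained=True, torchvision's resnet101 runs the following code: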
- def resnet101(pretrained=False, **kwargs):
-     """Constructs a ResNet-101 model.
-     Args:
-         pretrained (bool): If True, returns a model pre-trained on ImageNet
-     """
-     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
-     if pretrained:
-         model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
-     return model
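Jump to the definition of load_url (Ctrl+B, "Go to Declaration", in PyCharm).
This leads to the function that does the work: if the model file already exists locally, it is loaded from there; otherwise it is downloaded from the URL.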
- def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True):
-     r"""Loads the Torch serialized object at the given URL.
-     If the object is already present in `model_dir`, it's deserialized and
-     returned. The filename part of the URL should follow the naming convention
-     ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
-     digits of the SHA256 hash of the contents of the file. The hash is used to
-     ensure unique names and to verify the contents of the file.
-     The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
-     environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
-     ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
-     filesytem layout, with a default value ``~/.cache`` if not set.
-     Args:
-         url (string): URL of the object to download
-         model_dir (string, optional): directory in which to save the object
-         map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
-         progress (bool, optional): whether or not to display a progress bar to stderr
-     Example:
-         >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
-     """
-     # Issue warning to move data if old env is set
-     if os.getenv('TORCH_MODEL_ZOO'):
-         warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
-
-     if model_dir is None:
-         torch_home = _get_torch_home()
-         model_dir = os.path.join(torch_home, 'checkpoints')
-
-     try:
-         os.makedirs(model_dir)
-     except OSError as e:
-         if e.errno == errno.EEXIST:
-             # Directory already exists, ignore.
-             pass
-         else:
-             # Unexpected OSError, re-raise.
-             raise
-
-     parts = urlparse(url)
-     filename = os.path.basename(parts.path)
-     cached_file = os.path.join(model_dir, filename)
-     if not os.path.exists(cached_file):
-         sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
-         hash_prefix = HASH_REGEX.search(filename).group(1)
-         _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
-     return torch.load(cached_file, map_location=map_location)
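The docstring above says the hex string at the end of a weight file's name is a prefix of the file's SHA-256 hash. When weights are copied to the cluster by hand, checking that property is a cheap way to catch a truncated or wrong file. Below is a minimal stdlib sketch of such a check; verify_hash_prefix is a hypothetical helper (not a torch function), and its regex only approximates the naming convention. torchvision's own files follow the convention; whether a third-party weight file does is an assumption worth confirming.

```python
import hashlib
import os
import re

# Approximation of the ``filename-<sha256 prefix>.ext`` convention from the
# docstring above: hex digits between a '-' and the following '.'.
_HASH_IN_NAME = re.compile(r'-([a-f0-9]+)\.')


def verify_hash_prefix(path, chunk_size=1 << 20):
    """Return True if the file's SHA-256 hex digest starts with the hash in its name."""
    match = _HASH_IN_NAME.search(os.path.basename(path))
    if match is None:
        raise ValueError('no hash prefix in file name: {}'.format(path))
    prefix = match.group(1)
    sha256 = hashlib.sha256()
    with open(path, 'rb') as f:
        # Read in chunks so a large .pth file does not need to fit in memory at once.
        for chunk in iter(lambda: f.read(chunk_size), b''):
            sha256.update(chunk)
    return sha256.hexdigest().startswith(prefix)
```

For example, verify_hash_prefix('./resnet101-5d3b4d8f.pth') should return True for an intact copy of the weights cached on the server.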
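Set a breakpoint in this function and use the debugger to read where the model is stored (the value of cached_file).
This way network access is no longer a concern, and you can pin exactly which weight file is used.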
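Instead of stepping through the debugger, the candidate cache directories can also be computed from the same rules. The sketch below is an approximation under stated assumptions: it checks $TORCH_HOME/checkpoints as documented in the docstring above, $TORCH_HOME/hub/checkpoints (assumed here to be the layout used by newer PyTorch releases), and the older ~/.torch/models layout seen on the server in this post, which TORCH_MODEL_ZOO is assumed to override. torch_cache_dirs and find_cached_weights are hypothetical helpers, not torch APIs.

```python
import os


def torch_cache_dirs():
    """Return candidate directories where torch may have cached downloaded weights."""
    # Documented in the docstring: $TORCH_HOME defaults to $XDG_CACHE_HOME/torch,
    # and $XDG_CACHE_HOME defaults to ~/.cache.
    xdg_cache = os.getenv('XDG_CACHE_HOME', '~/.cache')
    torch_home = os.path.expanduser(os.getenv('TORCH_HOME', os.path.join(xdg_cache, 'torch')))
    checkpoints = os.path.join(torch_home, 'checkpoints')
    # Assumption: newer releases nest the cache under a 'hub' subdirectory.
    hub_checkpoints = os.path.join(torch_home, 'hub', 'checkpoints')
    # Assumption: older releases default to ~/.torch/models, overridable via TORCH_MODEL_ZOO.
    old_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
    old_models = os.path.expanduser(os.getenv('TORCH_MODEL_ZOO', os.path.join(old_home, 'models')))
    return [checkpoints, hub_checkpoints, old_models]


def find_cached_weights():
    """List the .pth files present in any of the candidate cache directories."""
    found = []
    for cache_dir in torch_cache_dirs():
        if os.path.isdir(cache_dir):
            for name in sorted(os.listdir(cache_dir)):
                if name.endswith('.pth'):
                    found.append(os.path.join(cache_dir, name))
    return found
```

Printing find_cached_weights() on each machine lists the weight files already downloaded there, which can then be copied into the project folder.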
https://blog.csdn.net/u014264373/article/details/85332181
Loading directly from a .pth file
For example:
- import torch
- import torchvision.models as models
-
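- # pretrained=True would load torchvision's pretrained weights (downloading them if needed); here they come from a local file instead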
- net = models.squeezenet1_1(pretrained=False)
- pthfile = r'E:\anaconda\app\envs\luo\Lib\site-packages\torchvision\models\squeezenet1_1.pth'
- net.load_state_dict(torch.load(pthfile))
- print(net)
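In other words, the program reads the weight file directly from a directory.
Loading from the working directory
Put the file in the directory the program is run from (the snippet below is only a reference and may need adjusting):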
- def gcn_resnet101(num_classes, t, pretrained=True, adj_file=None, in_channel=300):
-     # fixme
-     model = models.resnet101(pretrained=False)
-     if pretrained:
-         print('load pretrained model...')
-         model.load_state_dict(torch.load('./resnet101-5d3b4d8f.pth'))
-     return GCNResnet(model, num_classes, t=t, adj_file=adj_file, in_channel=in_channel)
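Copy the cached weights into the project directory:
- cp ~/.torch/models/resnet101-5d3b4d8f.pth chun-ML_GCN/
Note: the path passed to load_state_dict is relative, so it must match the directory the program is run from.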
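- if backbone == 'resnet101':
-     model = models.resnet101(pretrained=False)
-     print('load pretrained model...')
-     # relative path: resolved against the directory the program is launched from
-     model.load_state_dict(torch.load('./resnet101-5d3b4d8f.pth'))
- elif backbone == 'resnet50':
-     model = models.resnet50(pretrained=False)
-     print('load pretrained model...')
-     # torchvision's ResNet-50 weight file is resnet50-19c8e357.pth (see the server listing below)
-     model.load_state_dict(torch.load('./resnet50-19c8e357.pth'))
- else:
-     raise ValueError('unsupported backbone: {}'.format(backbone))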
That is, the resnet101-5d3b4d8f.pth file in the working directory is loaded directly.
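Because './resnet101-5d3b4d8f.pth' is resolved against the current working directory, launching the script from another directory breaks the load. One way around this, sketched below under the assumption that the weights sit next to the source file that loads them, is to resolve the name relative to that file. weights_path is a hypothetical helper, not part of the original code.

```python
import os


def weights_path(filename, anchor_file):
    """Resolve a weight file name relative to the directory containing anchor_file.

    Pass __file__ as anchor_file so the result does not depend on the directory
    the program is launched from.
    """
    base_dir = os.path.dirname(os.path.abspath(anchor_file))
    path = os.path.join(base_dir, filename)
    if not os.path.isfile(path):
        # Fail with an explicit message instead of silently falling back to a download.
        raise FileNotFoundError('pretrained weights not found: {}'.format(path))
    return path
```

Inside the model file this would be used as model.load_state_dict(torch.load(weights_path('resnet101-5d3b4d8f.pth', __file__))).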
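The rest of this part verifies the change on our own program and can be skipped, since everyone's model is different.
Apply the changes as described above:
in general_train.py, switch to exp_3, and modify hgat_fc.py as shown above.
Then, from the project directory, run env/bin/python general_train.py; if it runs without errors, the change works.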
Fixing the pretrained model on the cluster
The error shows that the cluster job is still trying to download the pretrained model:
- Downloading: "http://xxxxxxxr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth" to /home/xxx/.torch/models/se_resnet152-d17c99b7.pth
- Traceback (most recent call last):
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 159, in _new_conn
- (self._dns_host, self.port), self.timeout, **extra_kw)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/connection.py", line 80, in create_connection
- raise err
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/connection.py", line 70, in create_connection
- sock.connect(sa)
- OSError: [Errno 101] Network is unreachable
-
- During handling of the above exception, another exception occurred:
-
- Traceback (most recent call last):
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 600, in urlopen
- chunked=chunked)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 354, in _make_request
- conn.request(method, url, **httplib_request_kw)
- File "/home/sxxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 1107, in request
- self._send_request(method, url, body, headers)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 1152, in _send_request
- self.endheaders(body)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 1103, in endheaders
- self._send_output(message_body)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 934, in _send_output
- self.send(msg)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 877, in send
- self.connect()
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 181, in connect
- conn = self._new_conn()
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 168, in _new_conn
- self, "Failed to establish a new connection: %s" % e)
- urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPConnection object at 0x7f03fa52d748>: Failed to establish a new connection: [Errno 101] Network is unreachable
-
- During handling of the above exception, another exception occurred:
-
- Traceback (most recent call last):
- File "/home/xx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/adapters.py", line 449, in send
- timeout=timeout
- File "/home/xxxxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 638, in urlopen
- _stacktrace=sys.exc_info()[2])
- File "/home/xxxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/retry.py", line 398, in increment
- raise MaxRetryError(_pool, url, error or ResponseError(cause))
- urllib3.exceptions.MaxRetryError: HTTPConnectionPool(host='data.lip6.fr', port=80): Max retries exceeded with url: /cadene/pretrainedmodels/se_resnet152-d17c99b7.pth (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f03fa52d748>: Failed to establish a new connection: [Errno 101] Network is unreachable',))
-
- During handling of the above exception, another exception occurred:
-
- Traceback (most recent call last):
- File "train_se_clsgat.py", line 128, in <module>
- main()
- File "train_se_clsgat.py", line 107, in main
- model = util.get_model(args)
- File "/home/xxx/job/tmp/job-25509/util.py", line 266, in get_model
- class_channels=args.CLASS_CHANNELS)
- File "/home/xxxx/job/tmp/job-25509/models/se_clsgat.py", line 379, in __init__
- model=senet_origin.se_resnet152()
- File "/home/xxx/job/tmp/job-25509/models/senet_origin.py", line 423, in se_resnet152
- initialize_pretrained_model(model, num_classes, settings)
- File "/home/xxx/job/tmp/job-25509/models/senet_origin.py", line 377, in initialize_pretrained_model
- model.load_state_dict(model_zoo.load_url(settings['url']))
- File "/home/slurm/job/tmp/job-25509/torch/lib/python3.5/site-packages/torch/utils/model_zoo.py", line 65, in load_url
- _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/torch/utils/model_zoo.py", line 71, in _download_url_to_file
- u = urlopen(url, stream=True)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/api.py", line 75, in get
- return request('get', url, params=params, **kwargs)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/api.py", line 60, in request
- return session.request(method=method, url=url, **kwargs)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/sessions.py", line 533, in request
- resp = self.send(prep, **send_kwargs)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/sessions.py", line 646, in send
- r = adapter.send(request, **kwargs)
- File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/adapters.py", line 516, in send
- raise ConnectionError(e, request=request)
- requests.exceptions.ConnectionError: HTTPConnectionPool(host='data.lip6.fr', port=80): Max retries exceeded with url: /cadene/pretrainedmodels/se_resnet152-d17c99b7.pth (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f03fa52d748>: Failed to establish a new connection: [Errno 101] Network is unreachable',))
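The pretrained model needs to be placed inside the project directory so that the cluster does not try to download it again.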
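An alternative that avoids editing the model code, sketched here as an untested assumption, is to point torch's cache at a folder inside the project before the model is built, so that load_url finds the copied file instead of downloading it. This relies on TORCH_HOME being read when load_url runs (the docstring above documents TORCH_HOME, but the exact subdirectory depends on the torch version), and use_project_cache is a hypothetical helper.

```python
import os


def use_project_cache(project_dir):
    """Point torch's download cache at project_dir/torch_cache.

    Assumption: the installed torch reads TORCH_HOME each time load_url runs,
    so this must be called before the model is constructed. Returns the
    subdirectories where different torch versions are assumed to look.
    """
    cache = os.path.abspath(os.path.join(project_dir, 'torch_cache'))
    os.environ['TORCH_HOME'] = cache
    # A leftover TORCH_MODEL_ZOO would take precedence in older torch versions.
    os.environ.pop('TORCH_MODEL_ZOO', None)
    models_dir = os.path.join(cache, 'models')                  # older torch, e.g. the one in the traceback
    checkpoints_dir = os.path.join(cache, 'checkpoints')        # layout documented in the docstring above
    hub_checkpoints_dir = os.path.join(cache, 'hub', 'checkpoints')  # assumed newer layout
    return models_dir, checkpoints_dir, hub_checkpoints_dir
```

Call it at the top of the training script, then copy se_resnet152-d17c99b7.pth into the returned directory that matches the installed torch version. The torch in the traceback downloaded into ~/.torch/models, which suggests the models subdirectory for that setup.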
In the following run, the program did not get as far as loading the model:
- ==== GLOBAL INFO ====
- IPLIST: xx.xx.xx.xx
- IP0: xx.xx.xx.xx
- ====================
- ==== NODE INFO ====
- NODE_RNAK: 0
- IP0: xx.xx.xx.xx
- NODE_IP: xx.xx.xx
- ===================
- {'ADJ_FILE': 'data/data/coco/coco_adj.pkl',
- 'ALPHA': 0.8,
- 'BACKBONE': 'resnet150',
- 'BATCH_SIZE': 16,
- 'CLASS_CHANNELS': 256,
- 'CPROB': array([[1.00000000e+00, 8.26410144e-01, 7.04392284e-01, ...,
- 4.03311258e-01, 4.45312500e-01, 5.40000000e-01],
- [4.18382255e-02, 1.00000000e+00, 1.02719033e-01, ...,
- 1.12582781e-02, 0.00000000e+00, 5.71428571e-03],
- [1.34192234e-01, 3.86532575e-01, 1.00000000e+00, ...,
- 3.84105960e-02, 7.81250000e-03, 8.57142857e-03],
- ...,
- [1.34812060e-02, 7.43331876e-03, 6.73948408e-03, ...,
- 1.00000000e+00, 2.34375000e-02, 8.57142857e-03],
- [1.26178775e-03, 0.00000000e+00, 1.16198001e-04, ...,
- 1.98675497e-03, 1.00000000e+00, 2.57142857e-02],
- [8.36764511e-03, 1.74901618e-03, 6.97188008e-04, ...,
- 3.97350993e-03, 1.40625000e-01, 1.00000000e+00]]),
- 'DATA': 'data/data/coco',
- 'DATA_TYPE': 'coco',
- 'DEEPMAR_LOSS': <loss.DeepMarWeights object at 0x7f04044800f0>,
- 'DEVICE_IDS': [0, 1, 2, 3, 4, 5, 6, 7],
- 'EPOCH': 100,
- 'EPOCH_STEP': 30,
- 'EVALUATE': False,
- 'EXP_NAME': 'se_clsgat',
- 'GROUPS': 12,
- 'GROUP_CHANNELS': 512,
- 'IMAGE_SIZE': 448,
- 'INP_NAME': 'data/data/coco/coco_glove_word2vec.pkl',
- 'IS_SLURM': False,
- 'LOSS_TYPE': 'DeepMarLoss',
- 'LR': 0.01,
- 'LRP': 0.01,
- 'LR_SCHEDULER': None,
- 'LR_SCHEDULER_PARAMS': None,
- 'MAX_EPOCH': 100,
- 'MODEL': 'se_clsgat',
- 'MOMENTUM': 0.9,
- 'NCLASSES': 80,
- 'NCLASSES_PER_GROUP': [1, 8, 5, 10, 5, 10, 7, 10, 6, 6, 5, 7],
- 'PRINT_FREQ': 10,
- 'RESUME': 'checkpoints/coco/se_clsgat/checkpoint.pth.tar',
- 'SAVE_MODEL_PATH': 'checkpoints/coco/se_clsgat',
- 'START_EPOCH': 0,
- 'WEIGHT_DECAY': 1e-05,
- 'WEIGHT_FILE': 'data/coco/coco_rate.pkl',
- 'WORKERS': 4}
- Compose(
- Resize(size=(512, 512), interpolation=PIL.Image.BILINEAR)
- MultiScaleCrop
- RandomHorizontalFlip(p=0.5)
- ToTensor()
- Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- )
- Compose(
- Warp (size=448, interpolation=2)
- ToTensor()
- Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- )
- [dataset] Done!
- [annotation] Done!
- [json] Done!
- [dataset] Done!
- [annotation] Done!
- [json] Done!
- -------------------------------------------------------
- Primary job terminated normally, but 1 process returned
- a non-zero exit code.. Per user-direction, the job has been aborted.
- -------------------------------------------------------
- --------------------------------------------------------------------------
- mpirun detected that one or more processes exited with non-zero status, thus causing
- the job to be terminated. The first process to do so was:
-
- Process name: [[58771,1],0]
- Exit code: 1
- --------------------------------------------------------------------------
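In the log above, the job got through printing its configuration and loading the datasets before it aborted. A cheap safeguard, sketched here as an assumption rather than part of the original code, is to check at startup that every required weight file is present locally, so that a missing file stops the job immediately with a readable message instead of a network error deep inside model construction. REQUIRED_WEIGHTS, check_weights and require_weights are hypothetical names.

```python
import os
import sys

# Hypothetical list of the weight files this project expects in its run directory.
REQUIRED_WEIGHTS = ('se_resnet152-d17c99b7.pth',)


def check_weights(base_dir, names=REQUIRED_WEIGHTS):
    """Return the required weight files that are missing from base_dir."""
    return [name for name in names if not os.path.isfile(os.path.join(base_dir, name))]


def require_weights(base_dir, names=REQUIRED_WEIGHTS):
    """Exit with a readable message if any required weight file is missing."""
    missing = check_weights(base_dir, names)
    if missing:
        sys.exit('missing pretrained weights in {}: {}'.format(base_dir, ', '.join(missing)))
```

Calling require_weights(os.getcwd()) as the first statement of main() in train_se_clsgat.py would stop the job before any data is loaded.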
Use the breakpoint method described earlier to find where the model is stored, then copy it over.
Locally, this can be done as follows:
- (torch041) $ cd /Users/baidu/.cache/torch/checkpoints/
- (torch041) baidudeMacBook-Pro:checkpoints baidu$ ls
- resnet101-5d3b4d8f.pth se_resnet152-d17c99b7.pth
- (torch041) baidudeMacBook-Pro:checkpoints baidu$ cp se_resnet152-d17c99b7.pth /Users/baidu/Desktop/code/ML_GAT-master/
The program then ran without errors.
On the server, the torch cache directory is already known:
- cd ~/.torch/models/
- ls
- resnet101-5d3b4d8f.pth resnet50-19c8e357.pth se_resnet152-d17c99b7.pth
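The manual cp step can also be scripted, so the project folder always carries the weights it needs before it is uploaded. A minimal sketch, assuming the cache directory listed above; stage_weights is a hypothetical helper.

```python
import os
import shutil


def stage_weights(cache_dir, project_dir, names):
    """Copy the named weight files from cache_dir into project_dir.

    Files already present in project_dir are skipped. Returns the names that
    were actually copied.
    """
    copied = []
    for name in names:
        src = os.path.join(cache_dir, name)
        dst = os.path.join(project_dir, name)
        if os.path.isfile(dst):
            continue  # already staged; avoid re-copying large files
        if not os.path.isfile(src):
            raise FileNotFoundError('not in cache: {}'.format(src))
        shutil.copy2(src, dst)
        copied.append(name)
    return copied
```

On the server this would be stage_weights(os.path.expanduser('~/.torch/models'), '.', ['resnet101-5d3b4d8f.pth', 'se_resnet152-d17c99b7.pth']), run from the project directory.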
Then replace initialize_pretrained_model in senet_origin.py with the modified version:
- def initialize_pretrained_model(model, num_classes, settings):
-     assert num_classes == settings['num_classes'], \
-         'num_classes should be {}, but is {}'.format(
-             settings['num_classes'], num_classes)
-     # model.load_state_dict(model_zoo.load_url(settings['url']))
-     # load from the working directory instead of downloading settings['url'] (needs `import torch` in this file)
-     print('loading pretrained model from local...')
-     model.load_state_dict(torch.load('./se_resnet152-d17c99b7.pth'))
-     model.input_space = settings['input_space']
-     model.input_size = settings['input_size']
-     model.input_range = settings['input_range']
-     model.mean = settings['mean']
-     model.std = settings['std']
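The modified function hard-codes the file name, which then has to be kept in sync with settings['url']. Since the local file name is simply the last path component of that URL (the same way load_state_dict_from_url builds cached_file), it could be derived instead. A small sketch; local_weight_file is a hypothetical helper and the base_dir default of '.' is an assumption matching the working-directory approach above.

```python
import os
from urllib.parse import urlparse


def local_weight_file(settings, base_dir='.'):
    """Map settings['url'] to the file with the same name inside base_dir."""
    filename = os.path.basename(urlparse(settings['url']).path)
    return os.path.join(base_dir, filename)
```

Inside initialize_pretrained_model the load would then read model.load_state_dict(torch.load(local_weight_file(settings))).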
Copyright © 2003-2013 www.wpsshop.cn. All rights reserved.