赞
踩
2021SC@SDUSC
源码:utils.py
本篇主要分析utils.py中的class GFPGANer ( )的初始化以及load_file_from_url( )方法
目录
(4)增加了一个model路径是网址时的处理,然而需要的model已经下载到本地,并没有用到
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.path.abspath(file)获取当前文件的绝对路径C:\Users\Vaifer\Desktop\GFPGAN-v.0.2.1\gfpgan\utils.py
os.path.dirname()再获取该文件所在的目录路径C:\Users\Vaifer\Desktop\GFPGAN-v.0.2.1\gfpgan
最终应该得到C:\xxx\GFPGAN-v.0.2.
参数:(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- if arch == 'clean':
- self.gfpgan = GFPGANv1Clean(
- out_size=512,
- num_style_feat=512,
- channel_multiplier=channel_multiplier,
- decoder_load_path=None,
- fix_decoder=False,
- num_mlp=8,
- input_is_latent=True,
- different_w=True,
- narrow=1,
- sft_half=True)
- else:
- self.gfpgan = GFPGANv1(
- out_size=512,
- num_style_feat=512,
- channel_multiplier=channel_multiplier,
- decoder_load_path=None,
- fix_decoder=True,
- num_mlp=8,
- input_is_latent=True,
- different_w=True,
- narrow=1,
- sft_half=True)
可以看到分别调用了GFPGANv1Clean与GFPGANv1进行初始化,之后我们会具体分析这两个类
这边就使用到了facexlib包中的face restoration helper
- self.face_helper = FaceRestoreHelper(
- upscale,
- face_size=512,
- crop_ratio=(1, 1),
- det_model='retinaface_resnet50',
- save_ext='png',
- device=self.device)
if model_path.startswith('https://'):
- loadnet = torch.load(model_path)
- if 'params_ema' in loadnet:
- keyname = 'params_ema'
- else:
- keyname = 'params'
- self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
- self.gfpgan.eval()
- self.gfpgan = self.gfpgan.to(self.device)
从指定url中下载文件并读取的一个函数,简单介绍下
在读取model时如果路径是网址,就会调用这个函数下载相应的model
参数:(url, model_dir=None, progress=True, file_name=None)
- def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
- """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
- """
- hub_dir = get_dir()
- model_dir = os.path.join(hub_dir, 'checkpoints')
- print('hub_dir',hub_dir)
- print('model_dir',model_dir)
- if model_dir is None:
- hub_dir = get_dir()
- model_dir = os.path.join(hub_dir, 'checkpoints')
- #做路径的拼接,并递归创建目录
- os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
-
- parts = urlparse(url)
- filename = os.path.basename(parts.path)
- if file_name is not None:
- filename = file_name
- cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
- if not os.path.exists(cached_file):
- print(f'Downloading: "{url}" to {cached_file}\n')
- download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
- return cached_file
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。