当前位置:   article > 正文

GFPGAN源码分析—第三篇_python facerestorehelper

python facerestorehelper

2021SC@SDUSC

源码:utils.py

本篇主要分析utils.py中的class GFPGANer ( )的初始化以及load_file_from_url( )方法

目录

1.获取当前项目路径

2.class GFOGANer ( )——init()

(1)优先选择在cupy+gpu上运行

(2)根据参数arch选择性初始化GFP-GAN

(3)初始化face helper

(4)增加了一个model路径是网址时的处理,然而需要的model已经下载到本地,并没有用到

(5)读取model并继续初始化

3.load_file_from_url( )


1.获取当前项目路径

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

os.path.abspath(file)获取当前文件的绝对路径C:\Users\Vaifer\Desktop\GFPGAN-v.0.2.1\gfpgan\utils.py

os.path.dirname()再获取该文件所在的目录路径C:\Users\Vaifer\Desktop\GFPGAN-v.0.2.1\gfpgan

最终应该得到C:\xxx\GFPGAN-v.0.2.

2.class GFOGANer ( )——init()

参数:(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None)

(1)优先选择在cupy+gpu上运行

self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(2)根据参数arch选择性初始化GFP-GAN

  1. if arch == 'clean':
  2. self.gfpgan = GFPGANv1Clean(
  3. out_size=512,
  4. num_style_feat=512,
  5. channel_multiplier=channel_multiplier,
  6. decoder_load_path=None,
  7. fix_decoder=False,
  8. num_mlp=8,
  9. input_is_latent=True,
  10. different_w=True,
  11. narrow=1,
  12. sft_half=True)
  13. else:
  14. self.gfpgan = GFPGANv1(
  15. out_size=512,
  16. num_style_feat=512,
  17. channel_multiplier=channel_multiplier,
  18. decoder_load_path=None,
  19. fix_decoder=True,
  20. num_mlp=8,
  21. input_is_latent=True,
  22. different_w=True,
  23. narrow=1,
  24. sft_half=True)

可以看到分别调用了GFPGANv1Clean与GFPGANv1进行初始化,之后我们会具体分析这两个类

(3)初始化face helper

这边就使用到了facexlib包中的face restoration helper

  1. self.face_helper = FaceRestoreHelper(
  2. upscale,
  3. face_size=512,
  4. crop_ratio=(1, 1),
  5. det_model='retinaface_resnet50',
  6. save_ext='png',
  7. device=self.device)

(4)增加了一个model路径是网址时的处理,然而需要的model已经下载到本地,并没有用到

if model_path.startswith('https://'):

(5)读取model并继续初始化

  1. loadnet = torch.load(model_path)
  2. if 'params_ema' in loadnet:
  3. keyname = 'params_ema'
  4. else:
  5. keyname = 'params'
  6. self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
  7. self.gfpgan.eval()
  8. self.gfpgan = self.gfpgan.to(self.device)

3.load_file_from_url( )

从指定url中下载文件并读取的一个函数,简单介绍下

在读取model时如果路径是网址,就会调用这个函数下载相应的model

参数:(url, model_dir=None, progress=True, file_name=None)

  1. def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
  2. """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
  3. """
  4. hub_dir = get_dir()
  5. model_dir = os.path.join(hub_dir, 'checkpoints')
  6. print('hub_dir',hub_dir)
  7. print('model_dir',model_dir)
  8. if model_dir is None:
  9. hub_dir = get_dir()
  10. model_dir = os.path.join(hub_dir, 'checkpoints')
  11. #做路径的拼接,并递归创建目录
  12. os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
  13. parts = urlparse(url)
  14. filename = os.path.basename(parts.path)
  15. if file_name is not None:
  16. filename = file_name
  17. cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
  18. if not os.path.exists(cached_file):
  19. print(f'Downloading: "{url}" to {cached_file}\n')
  20. download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
  21. return cached_file

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/401857
推荐阅读
相关标签
  

闽ICP备14008679号