赞
踩
- NAFNet-GoPro-width64.pdparams:主要用于去除运动模糊图像中的模糊
- NAFNet-REDS-width64.pdparams:主要用于恢复带有压缩损失的模糊图像

# 克隆仓库。由于外网限速,该步骤比较慢,可以直接使用已经克隆好的文件夹,不必执行本步骤
# !git clone https://github.com/PaddlePaddle/PaddleGAN
正克隆到 'PaddleGAN'...
remote: Enumerating objects: 5401, done.[K
remote: Counting objects: 100% (203/203), done.[K
remote: Compressing objects: 100% (159/159), done.[K
remote: Total 5401 (delta 101), reused 95 (delta 41), pack-reused 5198[K
接收对象中: 100% (5401/5401), 163.52 MiB | 10.71 MiB/s, 完成.
处理 delta 中: 100% (3499/3499), 完成.
检查连接... 完成。
# Install dependencies: enter the cloned PaddleGAN repository and install its requirements
%cd PaddleGAN/
!pip install -r requirements.txt
import cv2
from glob import glob
from natsort import natsorted
import numpy as np
import os
from tqdm import tqdm
import paddle
from ppgan.models.generators import NAFNetLocal
from ppgan.utils.download import get_path_from_url
from ppgan.apps.base_predictor import BasePredictor

# Model hyper-parameters per task. They must match the architecture of the
# checkpoint that is loaded (width-64 NAFNet weights in this notebook).
model_cfgs = {
    'Deblur': {
        'img_channel': 3,
        'width': 64,
        'enc_blk_nums': [1, 1, 1, 28],
        'middle_blk_num': 1,
        'dec_blk_nums': [1, 1, 1, 1],
    }
}

# Glob patterns used to collect images when a directory is given.
IMAGE_PATTERNS = ('*.jpeg', '*.JPEG', '*.jpg', '*.JPG', '*.png', '*.PNG')


class NAFNetDeblurer(BasePredictor):
    """Deblur images with a pretrained NAFNetLocal generator.

    Args:
        output_path: Root output directory; results are written to
            ``<output_path>/Deblur``.
        weight_path: Path of the ``.pdparams`` checkpoint to load.

    Raises:
        ValueError: if ``weight_path`` is not given.
    """

    def __init__(self, output_path='output_dir', weight_path=None):
        if weight_path is None:
            raise ValueError('weight_path must point to a .pdparams checkpoint')
        self.output_path = output_path
        self.task = 'Deblur'
        cfg = model_cfgs[self.task]
        checkpoint = paddle.load(weight_path)
        self.generator = NAFNetLocal(
            img_channel=cfg['img_channel'],
            width=cfg['width'],
            enc_blk_nums=cfg['enc_blk_nums'],
            middle_blk_num=cfg['middle_blk_num'],
            dec_blk_nums=cfg['dec_blk_nums'])
        self.generator.set_state_dict(checkpoint)
        self.generator.eval()

    def get_images(self, images_path):
        """Return the image files to process.

        A directory yields every file matching ``IMAGE_PATTERNS``, naturally
        sorted. The set removes duplicates that appear when globbing is
        case-insensitive (e.g. on Windows ``*.jpg`` and ``*.JPG`` match the
        same files). Any other path is returned as a one-element list.
        """
        if os.path.isdir(images_path):
            files = set()
            for pattern in IMAGE_PATTERNS:
                files.update(glob(os.path.join(images_path, pattern)))
            return natsorted(files)
        return [images_path]

    def imread_uint(self, path, n_channels=3):
        """Read an image file into a numpy array.

        Args:
            path: Image file path.
            n_channels: 1 for an HxWx1 grayscale array, 3 for an HxWx3 RGB
                array (grayscale inputs are replicated to GGG and an alpha
                channel, if present, is dropped).

        Raises:
            FileNotFoundError: if OpenCV cannot read ``path``
                (``cv2.imread`` returns ``None`` instead of raising).
            ValueError: if ``n_channels`` is neither 1 nor 3.
        """
        if n_channels not in (1, 3):
            raise ValueError(f'n_channels must be 1 or 3, got {n_channels}')
        flag = cv2.IMREAD_GRAYSCALE if n_channels == 1 else cv2.IMREAD_UNCHANGED
        img = cv2.imread(path, flag)
        if img is None:
            raise FileNotFoundError(f'Cannot read image: {path}')
        if n_channels == 1:
            return np.expand_dims(img, axis=2)  # HxWx1
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
        if img.shape[2] == 4:
            # IMREAD_UNCHANGED keeps the alpha channel; BGR2RGB would reject it.
            return cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # NOTE(review): 16-bit images keep their uint16 range here, while
        # uint2single assumes 8-bit data — confirm inputs are 8-bit.

    def uint2single(self, img):
        """Scale a uint8 image to float32 values in [0, 1]."""
        return np.float32(img / 255.)

    def single2tensor3(self, img):
        """Convert an HxWxC float array to a CxHxW float32 paddle tensor."""
        return paddle.to_tensor(
            np.ascontiguousarray(img, dtype=np.float32)).transpose([2, 0, 1])

    @staticmethod
    def _tensor2uint(output):
        """Convert the first item of an NCHW float output to an HxWxC uint8 array.

        Values are clipped to [0, 1] and rounded (not truncated) to 0..255.
        """
        restored = paddle.clip(output, 0, 1).numpy()[0].transpose(1, 2, 0)
        return np.round(restored * 255.0).astype(np.uint8)

    def run(self, images_path=None):
        """Deblur every image at ``images_path`` (a file or a directory).

        For each input, a copy of the input and ``<stem>_restoration<ext>``
        are written to ``<output_path>/Deblur``.
        """
        task_path = os.path.join(self.output_path, self.task)
        os.makedirs(task_path, exist_ok=True)

        for image_file in tqdm(self.get_images(images_path)):
            img_rgb = self.imread_uint(image_file, 3)
            image_name = os.path.basename(image_file)
            # Save a copy of the input next to the result for side-by-side viewing.
            cv2.imwrite(os.path.join(task_path, image_name),
                        cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR))
            # splitext keeps names such as "a.b.png" working (stem "a.b").
            stem, ext = os.path.splitext(image_name)
            restoration_save_path = os.path.join(
                task_path, f'{stem}_restoration{ext}')

            # HWC uint8 -> 1xCxHxW float tensor in [0, 1].
            img_l = self.single2tensor3(self.uint2single(img_rgb)).unsqueeze(0)
            with paddle.no_grad():
                output = self.generator(img_l)

            restored = self._tensor2uint(output)
            cv2.imwrite(restoration_save_path,
                        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
        print('Done, output path is:', task_path)
注:本项目示范所用的权重是基于 REDS 数据集训练得到的。
# Output directory; NAFNetDeblurer writes its results under <output_path>/Deblur
output_path = r"../work/output"
# Path of the pretrained NAFNet weights (trained on REDS, width 64)
weight_path = r"../data/data174576/NAFNet-REDS-width64.pdparams"
# Build the deblurring predictor (loads the checkpoint into the generator)
deblur_predictor = NAFNetDeblurer(output_path, weight_path)
W1030 19:54:46.673647 192 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1030 19:54:46.677815 192 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
# Input directory containing the blurry images
input_path = r"../work/inputs/"
# Run prediction on every image found in input_path
deblur_predictor.run(images_path=input_path)
100%|██████████| 3/3 [00:03<00:00, 1.42s/it]
Done, output path is: ../work/output/Deblur
# 展示预测的结果 import numpy as np import cv2 import matplotlib.pyplot as plt %matplotlib inline def imread(img_path): img = cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img def display(img1, img2): fig = plt.figure(figsize=(25, 10)) ax1 = fig.add_subplot(1, 2, 1) plt.title('Input image', fontsize=16) ax1.axis('off') ax2 = fig.add_subplot(1, 2, 2) plt.title('NAFNet output', fontsize=16) ax2.axis('off') ax1.imshow(img1) ax2.imshow(img2)
# Compare one input with its restored output.
# NOTE(review): this rebinds the notebook globals input_path/output_path
# (previously directories) to single file paths.
input_path = '../work/inputs/blurry-reds-1.jpg'
output_path = '../work/output/Deblur/blurry-reds-1_restoration.jpg'
img_input = imread(input_path)
img_output = imread(output_path)
display(img_input, img_output)
class CropPredictor(NAFNetDeblurer):
    """NAFNet deblurer that runs the generator on overlapping square tiles.

    The notebook uses it for large inputs. Tiling presumably keeps each
    forward pass small; overlapping regions are averaged.
    """

    def __init__(self, output_path='output_dir', weight_path=None):
        super().__init__(output_path, weight_path)
        # Defaults matching run_patches(); run_patches() overwrites them per call.
        # Setting them here also lets crop_predict() be called on its own.
        self.tile = 1024
        self.overlap = 128
        self.sf = 1  # output/input scale factor (1 for deblurring)

    @staticmethod
    def _check_tiling(tile, overlap):
        """Reject settings whose stride (tile - overlap) would not be positive."""
        if tile <= 0 or not 0 <= overlap < tile:
            raise ValueError(
                f'need tile > 0 and 0 <= overlap < tile, got tile={tile}, '
                f'overlap={overlap}')

    @staticmethod
    def _tile_starts(size, tile, stride):
        """Return window start offsets so windows of length ``tile`` cover [0, size).

        The last window is aligned to the end of the axis.
        """
        if size <= tile:
            return [0]
        starts = list(range(0, size - tile, stride))
        starts.append(size - tile)
        return starts

    def crop_predict(self, img_lq):
        """Run the generator tile by tile over an NCHW tensor and blend the results.

        Every output pixel is the mean of all tile predictions covering it.
        Windows are clamped to the image size, so inputs smaller than
        ``self.tile`` are processed whole instead of being sliced with
        negative offsets.

        Raises:
            ValueError: if ``self.overlap`` is not smaller than ``self.tile``.
        """
        sf = self.sf
        self._check_tiling(self.tile, self.overlap)
        stride = self.tile - self.overlap

        b, c, h, w = img_lq.shape
        tile_h = min(self.tile, h)
        tile_w = min(self.tile, w)

        accum = paddle.zeros([b, c, h * sf, w * sf], dtype=img_lq.dtype)
        weight = paddle.zeros_like(accum)
        for h_idx in self._tile_starts(h, tile_h, stride):
            for w_idx in self._tile_starts(w, tile_w, stride):
                in_patch = img_lq[:, :, h_idx:h_idx + tile_h,
                                  w_idx:w_idx + tile_w]
                out_patch = self.generator(in_patch)
                rows = slice(h_idx * sf, (h_idx + tile_h) * sf)
                cols = slice(w_idx * sf, (w_idx + tile_w) * sf)
                accum[:, :, rows, cols] += out_patch
                # Count how many tiles contributed to each pixel.
                weight[:, :, rows, cols] += paddle.ones_like(out_patch)
        return accum.divide(weight)

    def run_patches(self, images_path=None, tile=1024, overlap=128):
        """Deblur every image at ``images_path`` using tiled inference.

        Args:
            images_path: An image file or a directory of images.
            tile: Side length of the square inference window, in pixels.
            overlap: Overlap between neighbouring windows in pixels; must be
                smaller than ``tile``.

        For each input, a copy of the input and ``<stem>_restoration<ext>``
        are written to ``<output_path>/Deblur``.
        """
        # Validate before any file is written.
        self._check_tiling(tile, overlap)
        task_path = os.path.join(self.output_path, self.task)
        os.makedirs(task_path, exist_ok=True)
        self.tile = tile
        self.overlap = overlap
        self.sf = 1

        for image_file in tqdm(self.get_images(images_path)):
            img_rgb = self.imread_uint(image_file, 3)
            image_name = os.path.basename(image_file)
            # Save a copy of the input next to the result for comparison.
            cv2.imwrite(os.path.join(task_path, image_name),
                        cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR))
            # splitext keeps names with several dots working.
            stem, ext = os.path.splitext(image_name)
            restoration_save_path = os.path.join(
                task_path, f'{stem}_restoration{ext}')

            # HWC uint8 -> 1xCxHxW float tensor in [0, 1].
            img_l = self.single2tensor3(self.uint2single(img_rgb)).unsqueeze(0)
            with paddle.no_grad():
                output = self.crop_predict(img_l)

            restored = paddle.clip(output, 0, 1).numpy()[0].transpose(1, 2, 0)
            # Round rather than truncate when converting back to 8-bit.
            restored = np.round(restored * 255.0).astype(np.uint8)
            cv2.imwrite(restoration_save_path,
                        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
        print('Done, output path is:', task_path)
# Output directory for sliding-window (tiled) predictions
crop_output_path = r"../work/crop_output"
# Path of the pretrained NAFNet weights
weight_path = r"../data/data174576/NAFNet-REDS-width64.pdparams"
# Build the tiled deblurring predictor
croppredictor = CropPredictor(crop_output_path, weight_path=weight_path)
W1102 21:20:36.032223 212 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1102 21:20:36.036273 212 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
# Directory containing the large input images
bigimages_path = r"../work/big_inputs/"
# Start tiled prediction: 1024x1024 windows, neighbouring windows overlap by 128 pixels
croppredictor.run_patches(images_path=bigimages_path, tile=1024, overlap=128)
100%|██████████| 1/1 [00:08<00:00, 8.01s/it]
Done, output path is: ../work/crop_output/Deblur
# Show the tiled-prediction result next to its input
input_path = '../work/big_inputs/beautiful.png'
output_path = '../work/crop_output/Deblur/beautiful_restoration.png'
img_input = imread(input_path)
img_output = imread(output_path)
display(img_input, img_output)
请点击此处查看本环境基本用法.
Please click here for more detailed instructions.
此文章为搬运
原项目链接
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。