
Image Deblurring with NAFNet in PaddleGAN


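1. Project Overview

1.1 Background

  • NAFNet is an image restoration model proposed by Megvii Research. It performs well on both image deblurring and denoising, is computationally efficient, and outperforms previous SOTA methods; sample results are shown in the figure below. On stereo super-resolution, the NAFNet-based model NAFSSR won the stereo super-resolution track of NTIRE 2022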

[Figure: NAFNet restoration results]

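1.2 Project Goals

  • The code merged into PaddleGAN covers training, testing, and prediction for denoising only, but the NAFNet network itself is already in the repo, so a few small changes are enough to try it on deblurring
  • This project involves no training. It converts the two best NAFNet checkpoints, trained on GoPro and on REDS, to Paddle weights and runs image deblurring with the NAFNet implementation in PaddleGAN
  • The torch-to-Paddle weight conversion is not covered here, since the code is much the same each time; see the projects "BSRGAN reproduction: blind image super-resolution for real-world degradation" and "PPSIG: PMBANet depth map super-resolution reconstruction reproduction"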
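For readers who have not done such a conversion before, the general shape of it can be sketched as follows. This is a minimal sketch under stated assumptions, not the conversion script used for this project: `convert_torch_to_paddle` and `linear_weight_keys` are hypothetical names, and the state dict is assumed to already hold NumPy arrays (in practice each value would come from a PyTorch checkpoint via `tensor.cpu().numpy()`). Convolution weights share the same layout in both frameworks and are copied unchanged, while fully connected weights are transposed and BatchNorm statistics are renamed.

```python
import numpy as np


def convert_torch_to_paddle(torch_state, linear_weight_keys=()):
    """Convert a PyTorch-style state dict (name -> ndarray) to Paddle naming/layout.

    `linear_weight_keys` (hypothetical parameter) lists the names of
    fully connected weights that need a transpose.
    """
    paddle_state = {}
    for name, value in torch_state.items():
        # PyTorch BatchNorm bookkeeping entry with no Paddle counterpart.
        if name.endswith("num_batches_tracked"):
            continue
        array = np.asarray(value)
        # torch.nn.Linear stores weights as [out, in]; paddle.nn.Linear uses [in, out].
        if name in linear_weight_keys:
            array = array.T
        # BatchNorm running statistics use different names in Paddle.
        new_name = name.replace("running_mean", "_mean").replace(
            "running_var", "_variance")
        paddle_state[new_name] = array
    return paddle_state


# Toy state dict standing in for a real checkpoint.
toy_state = {
    "conv.weight": np.zeros((8, 3, 1, 1), dtype=np.float32),
    "fc.weight": np.zeros((4, 3), dtype=np.float32),
    "bn.running_mean": np.zeros(8, dtype=np.float32),
    "bn.num_batches_tracked": np.array(0),
}
converted = convert_torch_to_paddle(toy_state, linear_weight_keys={"fc.weight"})
print(sorted(converted))
```

Whether any transposes or renames are needed depends on which layers the network actually contains, so the key list should be checked against the model definition before the result is loaded.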

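2. How to Use

  • First, I converted the NAFNet deblur weights to Paddle format and mounted them in the project's dataset. There are two weight files:
    • NAFNet-GoPro-width64.pdparams, trained on the GoPro dataset, mainly for removing motion blur
    • NAFNet-REDS-width64.pdparams, trained on the REDS dataset, mainly for restoring blurry images with compression artifacts
  • Next, we call these weights through PaddleGAN to deblur our own images. Follow me!
# Clone the repository. This step is slow because access to external sites is throttled, so you can use the already-cloned folder and skip it
# !git clone https://github.com/PaddlePaddle/PaddleGAN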
Cloning into 'PaddleGAN'...
remote: Enumerating objects: 5401, done.
remote: Counting objects: 100% (203/203), done.
remote: Compressing objects: 100% (159/159), done.
remote: Total 5401 (delta 101), reused 95 (delta 41), pack-reused 5198
Receiving objects: 100% (5401/5401), 163.52 MiB | 10.71 MiB/s, done.
Resolving deltas: 100% (3499/3499), done.
Checking connectivity... done.
# Install dependencies
%cd PaddleGAN/
!pip install -r requirements.txt
import cv2
from glob import glob
from natsort import natsorted
import numpy as np
import os
from tqdm import tqdm

import paddle

from ppgan.models.generators import NAFNetLocal
from ppgan.utils.download import get_path_from_url
from ppgan.apps.base_predictor import BasePredictor

# Model hyperparameters
model_cfgs = {
    'Deblur': {
        'img_channel': 3,
        'width': 64,
        'enc_blk_nums': [1, 1, 1, 28],
        'middle_blk_num': 1,
        'dec_blk_nums': [1, 1, 1, 1]
    }
}

# Predictor class for deblurring
class NAFNetDeblurer(BasePredictor):

    def __init__(self,
                 output_path='output_dir',
                 weight_path=None):
        self.output_path = output_path
        task = 'Deblur'
        self.task = task

        checkpoint = paddle.load(weight_path)

        self.generator = NAFNetLocal(
            img_channel=model_cfgs[task]['img_channel'],
            width=model_cfgs[task]['width'],
            enc_blk_nums=model_cfgs[task]['enc_blk_nums'],
            middle_blk_num=model_cfgs[task]['middle_blk_num'],
            dec_blk_nums=model_cfgs[task]['dec_blk_nums'])

        self.generator.set_state_dict(checkpoint)
        self.generator.eval()

    def get_images(self, images_path):
        if os.path.isdir(images_path):
            return natsorted(
                glob(os.path.join(images_path, '*.jpeg')) +
                glob(os.path.join(images_path, '*.jpg')) +
                glob(os.path.join(images_path, '*.JPG')) +
                glob(os.path.join(images_path, '*.png')) +
                glob(os.path.join(images_path, '*.PNG')))
        else:
            return [images_path]

    def imread_uint(self, path, n_channels=3):
        #  input: path
        # output: HxWx3(RGB or GGG), or HxWx1 (G)
        if n_channels == 1:
            img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
            img = np.expand_dims(img, axis=2)  # HxWx1
        elif n_channels == 3:
            img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
            if img.ndim == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
            else:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB

        return img

    def uint2single(self, img):

        return np.float32(img / 255.)

    # convert single (HxWxC) to 3-dimensional paddle tensor
    def single2tensor3(self, img):
        # paddle.to_tensor is the documented way to build a tensor from a numpy array
        return paddle.to_tensor(np.ascontiguousarray(
            img, dtype=np.float32)).transpose([2, 0, 1])

    def run(self, images_path=None):
        os.makedirs(self.output_path, exist_ok=True)
        task_path = os.path.join(self.output_path, self.task)
        os.makedirs(task_path, exist_ok=True)
        image_files = self.get_images(images_path)
        for image_file in tqdm(image_files):
            img_L = self.imread_uint(image_file, 3)

            image_name = os.path.basename(image_file)
            img = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(task_path, image_name), img)

            # Split off only the final extension so names containing extra dots still work
            stem, ext = os.path.splitext(image_name)
            restoration_save_path = os.path.join(
                task_path, f'{stem}_restoration{ext}')

            img_L = self.uint2single(img_L)

            # HWC to CHW, numpy to tensor
            img_L = self.single2tensor3(img_L)
            img_L = img_L.unsqueeze(0)
            with paddle.no_grad():
                output = self.generator(img_L)

            restored = paddle.clip(output, 0, 1)

            restored = restored.numpy()
            restored = restored.transpose(0, 2, 3, 1)
            restored = restored[0]
            # Round before the uint8 cast; astype alone truncates toward zero
            restored = np.round(restored * 255)
            restored = restored.astype(np.uint8)

            cv2.imwrite(restoration_save_path,
                        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))

        print('Done, output path is:', task_path)
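The pre- and post-processing around the network call is easy to get subtly wrong, so it can help to check it in isolation. This is a minimal NumPy-only sketch of the same uint8 HxWxC to float32 1xCxHxW round trip; `to_model_input` and `to_image` are hypothetical helper names, and an identity step stands in for the generator. It rounds before casting back to uint8, since a bare cast truncates and can lose one intensity level.

```python
import numpy as np


def to_model_input(img_uint8):
    """HxWxC uint8 image -> 1xCxHxW float32 array scaled to [0, 1]."""
    img = img_uint8.astype(np.float32) / 255.0
    chw = np.ascontiguousarray(img.transpose(2, 0, 1))
    return chw[np.newaxis]  # add the batch dimension


def to_image(output):
    """1xCxHxW float array -> HxWxC uint8 image."""
    img = np.clip(output[0], 0.0, 1.0).transpose(1, 2, 0)
    # Round first: astype(np.uint8) on its own truncates toward zero.
    return np.round(img * 255.0).astype(np.uint8)


rgb = np.random.randint(0, 256, size=(4, 6, 3), dtype=np.uint8)
batch = to_model_input(rgb)
restored = to_image(batch)  # identity "model": the output equals the input
print(batch.shape, batch.dtype)
print(np.array_equal(restored, rgb))
```

If the round trip does not reproduce the input exactly with an identity model, the layout or scaling is off and any real prediction will inherit the error.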


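2.1 Deblurring Regular Images

  • Images are usually no larger than 4K, so they can be fed to the network directly. Just run the following

Note: this demo uses the weights trained on the REDS dataset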
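Whether an image fits in a single pass depends on its pixel count and the available GPU memory. One simple way to decide between whole-image and tiled prediction is a pixel-count check like the sketch below. The helper name `needs_tiling` and the default budget of one 4K frame (3840 x 2160) are assumptions for illustration; the real limit depends on the GPU and should be tuned.

```python
def needs_tiling(height, width, max_pixels=3840 * 2160):
    """Return True when the image exceeds the single-pass pixel budget.

    `max_pixels` is an assumed budget, not a value measured for NAFNet.
    """
    return height * width > max_pixels


print(needs_tiling(1080, 1920))  # False: a Full HD frame fits the budget
print(needs_tiling(4320, 7680))  # True: an 8K frame does not
```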

# Output path
output_path = r"../work/output"
# Path to the weights
weight_path = r"../data/data174576/NAFNet-REDS-width64.pdparams"
# Create the deblurring predictor
deblur_predictor = NAFNetDeblurer(output_path, weight_path)
W1030 19:54:46.673647   192 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1030 19:54:46.677815   192 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
# Input path
input_path = r"../work/inputs/"
# Run prediction
deblur_predictor.run(images_path=input_path)
100%|██████████| 3/3 [00:03<00:00,  1.42s/it]

Done, output path is: ../work/output/Deblur
  • Display the prediction results
# Show the prediction results
import numpy as np
import cv2
import matplotlib.pyplot as plt
%matplotlib inline

def imread(img_path):
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img

def display(img1, img2):
    fig = plt.figure(figsize=(25, 10))
    ax1 = fig.add_subplot(1, 2, 1) 
    plt.title('Input image', fontsize=16)
    ax1.axis('off')
    ax2 = fig.add_subplot(1, 2, 2)
    plt.title('NAFNet output', fontsize=16)
    ax2.axis('off')
    ax1.imshow(img1)
    ax2.imshow(img2)
input_path = '../work/inputs/blurry-reds-1.jpg'
output_path = '../work/output/Deblur/blurry-reds-1_restoration.jpg'

img_input = imread(input_path)
img_output = imread(output_path)
display(img_input, img_output)

[Figure: input image (left) and NAFNet output (right)]

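2.2 Deblurring Large Images Over 4K

  • Large real-world images cause out-of-memory (OOM) errors on the GPU when predicted in one pass, so they need to be predicted tile by tile
  • Inherit from the deblurring predictor defined above and add a tiled prediction class as follows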
class CropPredictor(NAFNetDeblurer):
    def __init__(self,
                 output_path='output_dir',
                 weight_path=None):
        super(CropPredictor, self).__init__(output_path, weight_path)

    def crop_predict(self, img_lq):
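        sf = self.sf
        b, c, h, w = img_lq.shape
        # Clamp the tile to the image size so small inputs do not yield negative offsets
        tile = min(self.tile, h, w)
        # Keep the overlap below the tile size so the stride stays positive
        tile_overlap = min(self.overlap, tile // 2)
        stride = tile - tile_overlap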
        h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
        w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
        E = paddle.zeros([b, c, h*sf, w*sf], dtype=img_lq.dtype)
        W = paddle.zeros_like(E)

        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                h_idx = int(h_idx)
                w_idx = int(w_idx)
                in_patch = img_lq[:, :,h_idx:h_idx+tile, w_idx:w_idx+tile]
                out_patch = self.generator(in_patch)
                out_patch_mask = paddle.ones_like(out_patch)

                E[:, :, h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf] += out_patch
                W[:, :, h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf] += out_patch_mask

        output = E.divide(W)
        return output

    def run_patches(self, images_path=None, tile=1024, overlap=128):
        os.makedirs(self.output_path, exist_ok=True)
        task_path = os.path.join(self.output_path, self.task)
        os.makedirs(task_path, exist_ok=True)
        image_files = self.get_images(images_path)
        self.tile = tile
        self.overlap = overlap
        self.sf = 1

        for image_file in tqdm(image_files):
            img_L = self.imread_uint(image_file, 3)

            image_name = os.path.basename(image_file)
            img = cv2.cvtColor(img_L, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(task_path, image_name), img)

            # Split off only the final extension so names containing extra dots still work
            stem, ext = os.path.splitext(image_name)
            restoration_save_path = os.path.join(
                task_path, f'{stem}_restoration{ext}')

            img_L = self.uint2single(img_L)

            # HWC to CHW, numpy to tensor
            img_L = self.single2tensor3(img_L)
            img_L = img_L.unsqueeze(0)
            with paddle.no_grad():
                output = self.crop_predict(img_L)

            restored = paddle.clip(output, 0, 1)

            restored = restored.numpy()
            restored = restored.transpose(0, 2, 3, 1)
            restored = restored[0]
            # Round before the uint8 cast; astype alone truncates toward zero
            restored = np.round(restored * 255)
            restored = restored.astype(np.uint8)

            cv2.imwrite(restoration_save_path,
                        cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))

        print('Done, output path is:', task_path)
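The overlap-averaging idea behind tiled prediction can be checked without Paddle or a GPU. This is a minimal NumPy sketch, not the PaddleGAN implementation: `tile_starts` and `tiled_apply` are hypothetical names, the tile sizes are tiny for illustration, and an identity function stands in for the network. Each tile's output is accumulated into a sum buffer, a second buffer counts how many tiles touched each pixel, and the final result is their ratio.

```python
import numpy as np


def tile_starts(length, tile, stride):
    """Start offsets such that tiles of size `tile` cover [0, length)."""
    if length <= tile:
        return [0]
    # Regular steps, plus one final tile aligned to the far edge.
    return list(range(0, length - tile, stride)) + [length - tile]


def tiled_apply(img, fn, tile=8, overlap=2):
    """Apply `fn` tile by tile to a BxCxHxW array and average the overlaps."""
    if overlap >= tile:
        raise ValueError("overlap must be smaller than tile")
    _, _, h, w = img.shape
    tile_h, tile_w = min(tile, h), min(tile, w)
    acc = np.zeros(img.shape, dtype=np.float64)   # sum of tile outputs
    hits = np.zeros(img.shape, dtype=np.float64)  # tiles covering each pixel
    for top in tile_starts(h, tile_h, tile - overlap):
        for left in tile_starts(w, tile_w, tile - overlap):
            patch = img[:, :, top:top + tile_h, left:left + tile_w]
            acc[:, :, top:top + tile_h, left:left + tile_w] += fn(patch)
            hits[:, :, top:top + tile_h, left:left + tile_w] += 1.0
    return acc / hits


img = np.random.rand(1, 3, 20, 13)
out = tiled_apply(img, lambda patch: patch)  # identity stands in for the model
print(np.allclose(out, img))  # True
```

With a real network the tiles no longer agree exactly in the overlap regions, and averaging them is what hides the seams between neighbouring tiles.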

# Output path for sliding-window prediction
crop_output_path = r"../work/crop_output"
# Path to the weights
weight_path = r"../data/data174576/NAFNet-REDS-width64.pdparams"
# Create the sliding-window deblurring predictor
croppredictor = CropPredictor(crop_output_path, weight_path=weight_path)
W1102 21:20:36.032223   212 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1102 21:20:36.036273   212 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
# Path to the large images
bigimages_path = r"../work/big_inputs/"

# Start prediction
croppredictor.run_patches(images_path=bigimages_path, tile=1024, overlap=128)
100%|██████████| 1/1 [00:08<00:00,  8.01s/it]

Done, output path is: ../work/crop_output/Deblur
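The `tile` and `overlap` arguments trade seam quality against run time, because every tile is a separate forward pass. A rough way to see the cost is to count the tiles for a given image size. The sketch below is an illustration only: `count_tiles` is a hypothetical helper that follows the same start-offset rule as `crop_predict` (regular strides plus one edge-aligned tile), and the 3840 x 2160 size is just an example.

```python
def count_tiles(height, width, tile=1024, overlap=128):
    """Number of forward passes needed to cover a height x width image."""
    stride = tile - overlap

    def along(length):
        # One tile is enough when the image is no longer than the tile.
        if length <= tile:
            return 1
        # Regular strides, plus one final tile aligned to the far edge.
        return len(range(0, length - tile, stride)) + 1

    return along(height) * along(width)


print(count_tiles(2160, 3840))               # 15
print(count_tiles(2160, 3840, overlap=512))  # 28
```

A larger overlap smooths tile borders more but nearly doubles the number of forward passes in this example, so it is worth starting small and increasing the overlap only if seams are visible.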
# Show the result
input_path = '../work/big_inputs/beautiful.png'
output_path = '../work/crop_output/Deblur/beautiful_restoration.png'

img_input = imread(input_path)
img_output = imread(output_path)
display(img_input, img_output)

[Figure: input image (left) and NAFNet output (right)]

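3. Summary

  • This project showed how to use NAFNet, already merged into PaddleGAN, to restore blurry images when a deblurring task comes up; it is quite handy in everyday life
  • NAFNet can also do stereo super-resolution. I'll build that too when I get the chance; consider it a to-do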


This article is a repost of the original project.
