赞
踩
昇腾AI应用,探索人工智能的无限可能,使能千行百业
以Generative Adversarial Networks(GAN)为基础,其架构包括一个生成器(Generator)和一个判别器(Discriminator),通过迭代地训练两个网络(即生成器和判别器),由判别器提供的对抗性损失可以对修复的图像进行真假判别。系统整体结构如下图。
系统将生成器替换为两阶的自编码结构。第一阶的自编码器是一个粗糙的自编码器,用来生成待修复部分的图像的大体轮廓,自编码器在训练时记录了大量的图像信息,即使图像部分缺失,也具有重建图像的能力。但自编码器生成的图像会模糊,这是自编码器的固有缺陷。假定擦除区域的图像是Mask内的图像,此时自编码器修复出来的Mask内的图像会非常的模糊,之后再将该图像送入到第二阶的自编码器。
第二阶自编码器是一个精细的自编码器,会对上阶生成的修复的Mask内的图像进行精细加工,使得该区域图像变得清晰。该阶自编码器的原理是将图像切成指定数量的Patch,比如对大小为512*512的原始图像切成32*32个相同大小的Patch,那么每个Patch大小是16*16。在该阶的编码器高层将特征图也切成32*32个Patch,比如在高层的特征图大小是96*96,Channel数为256,此时在高层用3*3的卷积核对每个Patch内进行特征提取,这样每个Patch会生成一个256维的向量,这256维的向量记录了原图对应Patch的特征。当两个Patch的特征相似时它们对应的256维向量的余弦相似度就会较大,用这个原理可以生成Patch之间的相关性,这个两两之间的相关性可以用注意力矩阵来表示,反应了图片内所有Patch之间的相似程度,相似度越大注意力分数越高。
修复Mask内的图像就是要利用相似度从Mask外找到与Mask内待修复Patch相似性高的Patch,在编码器的低层将Mask外的Patch与以对应的注意力分数作为权值相乘后再相加得到待修补Patch内的特征值。
对于生成器,希望它尽可能生成真实清晰的图像,而判别器则希望尽可能的区分真实样本与生成样本,以促使生成器尽可能生成更真实清晰的图像。
由于处理的是高分辨率的图像,修补的图像可能还不那么清晰,接下来要用图像锐化的原理将图像修补的区域进行增强,锐化需要获取图片的细节,细节是由原图减去下采样再上采样生成的模糊图片得到的,此时得到的细节图片的Mask区域需要再进一步增强,增强的原理是把细节图片也分成32*32个Patch再利用上面得到的注意力矩阵,加权求和Mask外的细节特征到Mask内,对增强的Mask内细节区域的图片再加到修复后图像对应的Mask内的区域,使得修复后的图像更加清晰自然。
基于GAN的对抗性训练方式,我们列出生成器和判别器的损失函数。如下是判别器损失函数,就是要让生成的图片分布的期望尽可能小,真实图片分布的期望尽可能大。
如下是生成器的损失函数,由两部分组成,一部分是重构损失,重构损失也由两部分组成分别是Mask内和Mask外的重构损失,生成器另一部分是生成器损失,就是要让生成器生成的图片分布期望尽可能大。
如下图所示,昇腾软件栈中存在一个ATC模型转换工具,针对本应用,我们需要使用该工具将原始模型转换成系统支持的om模型。
本应用采用了下图所示的模块化设计,通过各模块之间的协调配合完成一张图片的推理输出。
其中各个模块的主要功能点如下所示:
图像预处理
预处理部分首先使用OpenCV读取原图和Mask图,将原图和Mask图进行大小缩放,缩放至3072*3072大小,之后再将原图像和Mask图像缩小到512*512用于送入模型进行推理,之所以将图片缩小后送入模型推理是为了节省算力和内存空间,加速推理时间;将读取到的图像数据拷贝至设备侧申请的内存空间中,为接下来构建模型输入数据做好准备。最后分别得到3072*3072和512*512的原图和Mask图。
上述功能函数原型为:
1. def pre_process(raw_img, raw_Mask):
参数说明:
raw_img[in]: 原始图像
raw_Mask[in] : Mask图像
原图像和Mask图像函数定义及相关源码注释如下所示:
- def pre_process(raw_img, raw_Mask):
- # normalization Mask
- raw_Mask = raw_Mask.astype(NPTYPE_FLOAT32) / 255.
- raw_img = raw_img.astype(NPTYPE_FLOAT32)
-
- # resize raw image & Mask to desinated size
- large_img = cv2.resize(raw_img, (MULTIPLE * INPUT_SIZE, MULTIPLE * INPUT_SIZE), interpolation = cv2. INTER_LINEAR)
- large_Mask = cv2.resize(raw_Mask, (MULTIPLE * INPUT_SIZE, MULTIPLE * INPUT_SIZE), interpolation = cv2.INTER_NEAREST)
-
- # down-sample large image & Mask to 512x512
- small_img = resize_ave(large_img, MULTIPLE)
- small_Mask = cv2.resize(raw_Mask, (INPUT_SIZE, INPUT_SIZE), interpolation = cv2.INTER_NEAREST)
-
- # set hole region to 1. and backgroun to 0.
- small_Mask = 1. - small_Mask
- return large_img, large_Mask, small_img, small_Mask
由于模型接受输入的图像是NHWC所以这里要做个转换
- #input to om model should be NHWC
- img_512_hwc = np.ascontiguousarray(img_512)
- Mask_512_hwc = Mask_512[:,:,0:1]
- Mask_512_hwc = Mask_512_hwc.transpose(2,0,1).copy()
模型推理结果后处理
模型的推理结果后处理函数原型为:
def post_process(model,raw_img, large_img, large_Mask, inpainted_512, img_512, Mask_512, attention):
后处理是整个实验中最复杂的部分,主要是对模型生成器修复的图像对修复区域进行增强,增强方式是采用了图像锐化的原理,将修复图像的修复区域与原图细节抽取的图片对应的区域相加,另外本实验中利用注意力矩阵对原图细节图片也做了进一步加工,使得细节图片包含更丰富的信息,以此得到锐化后的修复区域图像更清晰。这部分代码实现使用了很多的技巧,接下来会对该段代码作重点讲解。
这段代码首先将修复后的图片和原始的512*512的图片分别放大到3076*3076,之后再将放大后的两张图片相减得到3076*3076的细节图片,再将细节图片乘以放大的3076*3076的Mask图片之后得到去除了Mask区域的细节图片,将处理后的细节图片再使用注意力矩阵利用Mask外的细节图片信息对Mask内的细节图片进行修补,生成细节图片Mask内区域的图片。再将细节图片与修复后放大的图片相加,以达到对修复区域的图片锐化的效果,使得修复的区域变得更清晰,之后再分别将锐化后的图片和Mask图片缩放到与原图尺寸大小一致,最后再将缩放后锐化的图片的Mask内区域与原图Mask外的区域合并组成一张完整的图像。
后处理代码如下:
- def post_process(model,raw_img, large_img, large_Mask, inpainted_512, img_512, Mask_512, attention):
- # compute the raw residual map
- h, w, c = raw_img.shape
- low_base = cv2.resize(inpainted_512.astype(NPTYPE_FLOAT32), (INPUT_SIZE * MULTIPLE, INPUT_SIZE * MULTIPLE), interpolation = cv2.INTER_LINEAR)
- low_large = cv2.resize(img_512.astype(NPTYPE_FLOAT32), (INPUT_SIZE * MULTIPLE, INPUT_SIZE * MULTIPLE), interpolation = cv2.INTER_LINEAR)
- residual = (large_img - low_large) * large_Mask
-
- # reconstruct residual map using residual aggregation module
- residual = residual_aggregate(model,residual, attention)
-
- # compute large inpainted result
- res_large = low_base + residual
- res_large = np.clip(res_large, 0., 255.)
-
- # resize large inpainted result to raw size
- res_raw = cv2.resize(res_large, (w, h), interpolation = cv2.INTER_LINEAR)
-
- # paste the hole region to the original raw image
- Mask = cv2.resize(Mask_512.astype(NPTYPE_FLOAT32), (w, h), interpolation = cv2.INTER_LINEAR)
- Mask = np.expand_dims(Mask, axis=2)
-
- res_raw = res_raw * Mask + raw_img * (1. - Mask)
- return res_raw.astype(np.uint8)
在上面的后处理中用到矩阵乘法加速功能,由于细节图片有3072*3072大小,而且计算过程复杂,如果在CPU中对这么大的图片做矩阵运算会非常的耗时,所以这里使用了ACL对外提供的矩阵乘法单算子接口,利用NPU强大的算力来加速大矩阵的相乘运算。
首先在extract_image_Patches中将细节图像按宽高等距切割成32*32个Patch,这样每个Patch大小为96*96(3072/32 = 96),考虑到一张细节图有3个channel, 所以每个Patch有96*96*3=27648个像素, 再将这32*32个Patch按顺序排成一列一共有1024列,再把每个Patch所有像素按序拉成一行,共有27648行,由此组成了一个1024*27648的矩阵。由于attention矩阵大小为1024*1024,如果直接让attention矩阵(1024*1024)与reshape后的细节图矩阵(1024*27648)相乘有可能会把NPU的内存撑爆,所以这里就把reshape后的细节图矩阵平均分成了9份,每份大小为1024*3072(27648/9=3072),这样每次在NPU中实现1024*1024和1024*3072两个矩阵相乘,具体见matmul_ex接口。
reconstruct_residual_from_Patches是将上面reshape的细节图矩阵还原成之前的模样,是之前操作的一个逆过程,最终得到3072*3072*3的细节图像。
GenOutputImage函数定义及相关源码注释如下所示:
- # residual aggregation module
- def residual_aggregate(model,residual, attention):
- residual = extract_image_Patches(residual, MULTIPLE * INPUT_SIZE//ATTENTION_SIZE)
- residual = np.reshape(residual, [1, residual.shape[0] * residual.shape[1], -1])
- residual = matmul_om(model,attention,residual)
- #residual = np.matmul(attention, residual)
- residual = reconstruct_residual_from_Patches(residual, MULTIPLE * INPUT_SIZE//ATTENTION_SIZE)
- return residual
-
- # extract image Patches
- def extract_image_Patches(img, multiple):
- h, w, c = img.shape
- img = np.reshape(img, [h//multiple, multiple, w//multiple, multiple, c])
- img = np.transpose(img, [0,2,1,3,4])
- return img
-
- # reconstruct residual from Patches
- def reconstruct_residual_from_Patches(residual, multiple):
- residual = np.reshape(residual, [ATTENTION_SIZE, ATTENTION_SIZE, multiple, multiple, 3])
- residual = np.transpose(residual, [0,2,1,3,4])
- return np.reshape(residual, [ATTENTION_SIZE * multiple, ATTENTION_SIZE * multiple, 3])
效果展示
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。