当前位置:   article > 正文

Deep3DFaceRecon_pytorch-master项目学习-util.py

deep3dfacerecon
"""This script contains basic utilities for Deep3DFaceRecon_pytorch
"""
from __future__ import print_function
import numpy as np
import torch
from PIL import Image
import os
import importlib
import argparse
from argparse import Namespace
import torchvision

#字符串转布尔
def str2bool(v):
	#isinstance用于检查给定对象是否是指定类的实例
    if isinstance(v, bool):
        return v
        
    #lower() 的应用。这个方法返回原始字符串的小写版本
    #然后检查它是否属于一个包含表示布尔值 "True" 的字符串的元组。
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
    	#抛出自定义异常
        raise argparse.ArgumentTypeError('Boolean value expected.')

#**kwargs任意数量的关键字参数
def copyconf(default_opt, **kwargs):
    conf = Namespace(**vars(default_opt))
    for key in kwargs:
        setattr(conf, key, kwargs[key])
    return conf

def genvalconf(train_opt, **kwargs):
    conf = Namespace(**vars(train_opt))
    attr_dict = train_opt.__dict__
    for key, value in attr_dict.items():
        if 'val' in key and key.split('_')[0] in attr_dict:
            setattr(conf, key.split('_')[0], value)

    for key in kwargs:
        setattr(conf, key, kwargs[key])

    return conf
        
def find_class_in_module(target_cls_name, module):
	#这一行将目标类名中的所有下划线(_)替换为空字符串,并将整个字符串转换为小写。
    target_cls_name = target_cls_name.replace('_', '').lower()
    #importlib 库的 import_module 函数来动态导入指定的模块。clslib 现在是一个指向该模块的引用
    clslib = importlib.import_module(module)
    cls = None
    #这一行遍历模块 clslib 的字典属性(__dict__),这个字典包含了模块中定义的所有变量、函数和类。name 是字典中的键(名称),clsobj 是相应的值(对象)。
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj

    assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)

    return cls


def tensor2im(input_image, imtype=np.uint8):

    #这一行检查输入的图像是否为 NumPy 数组。如果不是 NumPy 数组。
    if not isinstance(input_image, np.ndarray):
    	#如果输入图像是一个 PyTorch 张量,
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        #将 image_tensor 裁剪到 [0, 1] 范围,将其从 GPU(如果在 GPU 上)移动到 CPU,将其转换为 float 类型,然后将其从张量转换为 NumPy 数组,赋值给 image_numpy。
        image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy()

		#如果图像是灰度图,使用 np.tile() 函数将其复制成一个 3 通道的 RGB 图像
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        #使用 np.transpose() 函数将图像数组的维度顺序从 (通道, 高度, 宽度) 转换为 (高度, 宽度, 通道),并将像素值的范围从 [0, 1] 缩放到 [0, 255]。
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0  
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)

#计算并打印神经网络中所有参数梯度的平均绝对值
def diagnose_network(net, name='network'):
    mean = 0.0
    count = 0
    #遍历网络中的每个参数。
    for param in net.parameters():
    	#如果当前参数的梯度非空,则计算梯度的绝对值,然后计算其平均值,并将其累加到 mean 变量中。
        if param.grad is not None:
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)

#将输入的 NumPy 图像数组保存到磁盘上的指定路径
#aspect_ratio(宽高比,默认为 1.0,表示不改变宽高比)。
def save_image(image_numpy, image_path, aspect_ratio=1.0):
	#将 NumPy 图像数组转换为 PIL 图像对象:使用 PIL 提供的方法来操作和保存图像。
    image_pil = Image.fromarray(image_numpy)
    h, w, _ = image_numpy.shape

    if aspect_ratio is None:
        pass
    #如果宽高比大于 1.0,将图像宽度乘以宽高比,然后调整图像大小
    #使用 PIL 图像对象的 resize() 方法调整图像大小,新的宽度为 int(w * aspect_ratio)。
    #这里使用 BICUBIC 插值方法进行调整。
    elif aspect_ratio > 1.0:
        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
    elif aspect_ratio < 1.0:
        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
    image_pil.save(image_path)

#用于打印输入 NumPy 数组的一些统计信息
def print_numpy(x, val=True, shp=False):
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))

#用于创建一个或多个不存在的空目录。让我们逐句解释这段代码:
def mkdirs(paths):

    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)

#其目的是在指定路径不存在时创建一个空目录
def mkdir(path):
	#检查传入的 path 是否已经存在
    if not os.path.exists(path):
    	#创建目录
        os.makedirs(path)

#调整张量中每个图像的大小,并返回一个调整后的张量。
def correct_resize_label(t, size):
    device = t.device#获取输入张量 t 的设备(CPU 或 GPU)
    t = t.detach().cpu()
    resized = []
    for i in range(t.size(0)):
        one_t = t[i, :1]
        one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
        one_np = one_np[:, :, 0]
        one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
        #将调整大小后的 PIL.Image 对象转换回 NumPy 数组,然后转换为 Torch 张量,并将数据类型设置为 long
        resized_t = torch.from_numpy(np.array(one_image)).long()
        
        resized.append(resized_t)
    return torch.stack(resized, dim=0).to(device)

#它用于将输入的 PyTorch 张量(图像)调整到指定的尺寸
def correct_resize(t, size, mode=Image.BICUBIC):
    device = t.device
    t = t.detach().cpu()
    resized = []
    for i in range(t.size(0)):
        one_t = t[i:i + 1]
        one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
        #将张量的值范围从 [0, 1] 调整为 [-1, 1]
        resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
        resized.append(resized_t)
    return torch.stack(resized, dim=0).to(device)


#用于在图像上绘制人脸关键点,landmark(关键点坐标)、step(用于定义关键点绘制时的像素步长,默认为 2)。
def draw_landmarks(img, landmark, color='r', step=2):

    if color =='r':
        c = np.array([255., 0, 0])
    else:
        c = np.array([0, 0, 255.])

    _, H, W, _ = img.shape
    img, landmark = img.copy(), landmark.copy()
    #landmark[..., 1] 选取了所有关键点的 y 坐标
    landmark[..., 1] = H - 1 - landmark[..., 1]
    #将关键点坐标四舍五入并转换为 int32 类型。
    landmark = np.round(landmark).astype(np.int32)
    #遍历所有关键点。
    for i in range(landmark.shape[1]):
        x, y = landmark[:, i, 0], landmark[:, i, 1]
        #遍历关键点周围的像素,从 -step 到 step(不包括 step)。
        for j in range(-step, step):
        	#遍历关键点周围的像素
            for k in range(-step, step):
            	#计算新的 x 坐标 u,将其限制在图像宽度范围内
                u = np.clip(x + j, 0, W - 1)
                #计算新的 y 坐标 v,将其限制在图像高度范围内。
                v = np.clip(y + k, 0, H - 1)
                for m in range(landmark.shape[0]):
                    img[m, v[m], u[m]] = c
    return img

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号