当前位置:   article > 正文

基于GAN对抗网进行图像修复_gan图像修复代码

gan图像修复代码

一、简介

使用PyTorch实现的生成对抗网络(GAN)模型,包括编码器(Encoder)、解码器(Decoder)、生成器(ResnetGenerator)和判别器(Discriminator)。其中,编码器和解码器用于将输入图像进行编码和解码,生成器用于生成新的图像,判别器用于判断输入图像是真实的还是生成的。在训练过程中,生成器和判别器分别使用不同的损失函数进行优化。

二、相关技术

2.1数据准备


image_paths = sorted([str(p) for p in glob('../input/celebahq-resized-256x256/celeba_hq_256' + '/*.jpg')])

# 定义数据预处理的transforms
image_size = 128

# 数据预处理的transforms,将图像大小调整为image_size,并进行标准化
transforms = T.Compose([
    T.Resize((image_size, image_size), Image.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # to scale [-1,1] with tanh activation
])

inverse_transforms = T.Compose([
    T.Normalize(-1, 2),
    T.ToPILImage()
])

# 划分训练集、验证集和测试集
train, valid = train_test_split(image_paths, test_size=5000, shuffle=True, random_state=seed)
valid, test = train_test_split(valid, test_size=1000, shuffle=True, random_state=seed)
# 输出数据集长度
print(f'Train size: {len(train)}, validation size: {len(valid)}, test size: {len(test)}.')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

2.2超参数的设置

配置了批次、学习率、迭代、遮盖图像的大小、指定GPU等等

epochs = 30
batch_size = 16
lr = 8e-5
mask_size = 64
path = r'painting_model.pth'
b1 = 0.5
b2 = 0.999
patch_h, patch_w = int(mask_size / 2 ** 3), int(mask_size / 2 ** 3)
patch = (1, patch_h, patch_w)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

2.3创建数据集

#创建数据集
其中apply_center_mask: 将掩码应用于图像的中心部分,遮挡中心部分。该方法接受一个图像作为输入,并返回应用了掩码的图像和掩码区域的索引。
apply_random_mask(self, image): 将掩码随机应用于图像的某个区域。该方法接受一个图像作为输入,并返回应用了掩码的图像和被遮挡的部分。

class CelebaDataset(Dataset):
    def __init__(self, images_paths, transforms=transforms, train=True):
        self.images_paths = images_paths
        self.transforms = transforms
        self.train = train
        
    def __len__(self):
        return len(self.images_paths)
    
    def apply_center_mask(self, image):
        # 将mask应用于图像的中心部分//遮挡中心部分
        idx = (image_size - mask_size) // 2
        masked_image = image.clone()
        masked_image[:, idx:idx+mask_size, idx:idx+mask_size] = 1
        masked_part = image[:, idx:idx+mask_size, idx:idx+mask_size]
        return masked_image, idx
    
    def apply_random_mask(self, image):
        # 将mask随机应用于图像的某个区域
        y1, x1 = np.random.randint(0, image_size-mask_size, 2)
        y2, x2 = y1 + mask_size, x1 + mask_size
        masked_part = image[:, y1:y2, x1:x2]
        masked_image = image.clone()
        masked_image[:, y1:y2, x1:x2] = 1
        return masked_image, masked_part
    
    def __getitem__(self, ix):
        path = self.images_paths[ix]
        image = Image.open(path)
        image = self.transforms(image)
        
        if self.train:
            masked_image, masked_part = self.apply_random_mask(image)
        else:
            masked_image, masked_part = self.apply_center_mask(image)
            
        return image, masked_image, masked_part
    
    def collate_fn(self, batch):
        images, masked_images, masked_parts = list(zip(*batch))
        images, masked_images, masked_parts = [[tensor[None].to(device) for tensor in ims] for ims in [images, masked_images, masked_parts]]
        images, masked_images, masked_parts = [torch.cat(ims) for ims in [images, masked_images, masked_parts]]
        return images, masked_images, masked_parts
        
 # 创建数据集和数据加载器
train_dataset = CelebaDataset(train)
valid_dataset = CelebaDataset(valid, train=True)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=train_dataset.collate_fn, drop_last=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=valid_dataset.collate_fn, drop_last=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

2.4 构建神经网络

2.4.1定义初始化函数

定义了初始化函数init_weights,用于初始化卷积层、反卷积层和批归一化层的权重。同时,还定义梯度更新函数set_params,用于设置模型参数是否需要梯度更新。

def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/587665
推荐阅读
相关标签
  

闽ICP备14008679号