赞
踩
本文使用PyTorch实现了一个生成对抗网络(GAN)模型,包括编码器(Encoder)、解码器(Decoder)、生成器(ResnetGenerator)和判别器(Discriminator)。其中,编码器和解码器分别用于对输入图像进行编码和解码,生成器用于生成新的图像,判别器用于判断输入图像是真实的还是生成的。在训练过程中,生成器和判别器分别使用不同的损失函数进行优化。
# Collect every JPEG in the dataset folder; sorting makes the order (and
# therefore the seeded split below) reproducible across runs.
image_paths = sorted(
    map(str, glob('../input/celebahq-resized-256x256/celeba_hq_256' + '/*.jpg'))
)

# Side length (in pixels) that every image is resized to.
image_size = 128

# Pre-processing pipeline: resize to image_size x image_size, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
transforms = T.Compose(
    [
        T.Resize((image_size, image_size), Image.BICUBIC),
        T.ToTensor(),
        # to scale [-1,1] with tanh activation
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Inverse pipeline: (x - (-1)) / 2 undoes the normalisation above, then the
# tensor is turned back into a PIL image.
inverse_transforms = T.Compose(
    [
        T.Normalize(-1, 2),
        T.ToPILImage(),
    ]
)

# Split into train / validation / test sets: first hold out 5000 images,
# then carve 1000 of those out as the test set.
train, valid = train_test_split(
    image_paths,
    test_size=5000,
    shuffle=True,
    random_state=seed,
)
valid, test = train_test_split(
    valid,
    test_size=1000,
    shuffle=True,
    random_state=seed,
)

# Report the size of each split.
print(f'Train size: {len(train)}, validation size: {len(valid)}, test size: {len(test)}.')
以下代码配置了批次大小、学习率、迭代轮数、遮盖区域的大小以及运行设备(GPU或CPU)等参数。
# --- Training schedule -------------------------------------------------
epochs = 30
batch_size = 16

# --- Optimiser settings ------------------------------------------------
lr = 8e-5
# NOTE(review): b1 / b2 are presumably the Adam beta coefficients; confirm
# where the optimisers are constructed.
b1 = 0.5
b2 = 0.999

# --- Masking -----------------------------------------------------------
# Side length (in pixels) of the square region that gets masked out.
mask_size = 64

# File name used for the model weights.
# NOTE(review): presumably the checkpoint save/load path; confirm at the call site.
path = r'painting_model.pth'

# The mask side divided by 2 ** 3, truncated to int.
# NOTE(review): presumably the spatial size of the discriminator output after
# three stride-2 downsamplings; confirm against the Discriminator definition.
patch_h = patch_w = int(mask_size / 2 ** 3)
patch = (1, patch_h, patch_w)

# --- Device ------------------------------------------------------------
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
#创建数据集
其中apply_center_mask: 将掩码应用于图像的中心部分,遮挡中心部分。该方法接受一个图像作为输入,并返回应用了掩码的图像和掩码区域的索引。
apply_random_mask(self, image): 将掩码随机应用于图像的某个区域。该方法接受一个图像作为输入,并返回应用了掩码的图像和被遮挡的部分。
class CelebaDataset(Dataset):
    """Inpainting dataset that pairs every image with a masked copy of itself.

    Each sample is a 3-tuple ``(image, masked_image, masked_part)``:

    * ``train=True``  -- a square of side ``mask_size`` at a random position
      is set to 1 in ``masked_image``; ``masked_part`` is the original
      content of that square.
    * ``train=False`` -- the square is centred; the third element is the
      integer offset of the square's top-left corner (same for both axes)
      rather than the pixels themselves.

    The masking helpers read the module-level ``image_size`` and
    ``mask_size``.
    NOTE(review): they assume ``transforms`` yields a (C, image_size,
    image_size) tensor -- confirm if a custom pipeline is passed in.

    Args:
        images_paths: Sequence of image file paths.
        transforms: Callable applied to the opened PIL image.
        train: Selects random (True) or centred (False) masking.
    """

    def __init__(self, images_paths, transforms=transforms, train=True):
        self.images_paths = images_paths
        self.transforms = transforms
        self.train = train

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.images_paths)

    def apply_center_mask(self, image):
        """Mask the central ``mask_size`` x ``mask_size`` square of ``image``.

        Returns:
            ``(masked_image, idx)`` where ``masked_image`` is a copy of
            ``image`` with the central square set to 1 and ``idx`` is the
            offset of that square from the top/left edge.
        """
        idx = (image_size - mask_size) // 2
        masked_image = image.clone()
        masked_image[:, idx:idx + mask_size, idx:idx + mask_size] = 1
        return masked_image, idx

    def apply_random_mask(self, image):
        """Mask a randomly placed ``mask_size`` x ``mask_size`` square.

        Returns:
            ``(masked_image, masked_part)`` where ``masked_image`` is a copy
            of ``image`` with the square set to 1 and ``masked_part`` is the
            original content of that square (a slice of ``image``).
        """
        # Top-left corner of the square; the upper bound is exclusive.
        y1, x1 = np.random.randint(0, image_size - mask_size, 2)
        y2, x2 = y1 + mask_size, x1 + mask_size
        masked_part = image[:, y1:y2, x1:x2]
        masked_image = image.clone()
        masked_image[:, y1:y2, x1:x2] = 1
        return masked_image, masked_part

    def __getitem__(self, ix):
        """Load, transform and mask the image at index ``ix``.

        Returns:
            ``(image, masked_image, masked_part)``; in evaluation mode
            (``train=False``) the third element is the integer mask offset
            returned by :meth:`apply_center_mask`.
        """
        path = self.images_paths[ix]
        image = Image.open(path)
        image = self.transforms(image)
        if self.train:
            masked_image, masked_part = self.apply_random_mask(image)
        else:
            masked_image, masked_part = self.apply_center_mask(image)
        return image, masked_image, masked_part

    def collate_fn(self, batch):
        """Stack a list of samples into batched tensors on ``device``.

        Every field is passed through ``torch.as_tensor`` before stacking so
        that the plain-int mask offsets produced when ``train=False`` are
        collated too (indexing an int with ``[None]`` raised ``TypeError``).
        Tensor fields are stacked along a new leading batch dimension.

        Returns:
            ``(images, masked_images, masked_parts)`` as a tuple of tensors.
        """
        images, masked_images, masked_parts = zip(*batch)
        images, masked_images, masked_parts = (
            torch.stack([torch.as_tensor(item) for item in group]).to(device)
            for group in (images, masked_images, masked_parts)
        )
        return images, masked_images, masked_parts


# Create the datasets and data loaders.
train_dataset = CelebaDataset(train)
# NOTE: the validation set is also built with train=True, i.e. it uses random
# masks and yields the masked pixels rather than the centre offset.
valid_dataset = CelebaDataset(valid, train=True)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
    drop_last=True,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=valid_dataset.collate_fn,
    drop_last=True,
)
定义了初始化函数init_weights,用于初始化卷积层、反卷积层和批归一化层的权重。同时,还定义梯度更新函数set_params,用于设置模型参数是否需要梯度更新。
def init_weights(m):
if isinstance(m, nn.Conv2d) or isinstance
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。