BargainNet is a project from bcmi; see the GitHub link below for the full project introduction. I needed to use BargainNet for various reasons, and since I'm not fond of launching training from the command line, I pulled out the default model and parameters it uses and boiled everything down to two simple files: one for loading data and one for training the model.
The file structure of the training data is as follows (the watermark on the screenshot can't be removed, which annoys me too):
The structure of IHD_train.txt is very simple: it is just a list of file names:
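Since the screenshots don't reproduce here, this is a rough sketch of the layout the loader below assumes (the image file names are invented; what actually matters is the com / mask / gt folders and the IHD_train.txt sitting next to com/):

data/
├── IHD_train.txt    # plain list of composite file names, one per line
├── com/             # composite images,     e.g. img001_1.jpg
├── mask/            # foreground masks,     e.g. img001.png
└── gt/              # ground-truth images,  e.g. img001.png

and IHD_train.txt itself would just contain lines like:

img001_1.jpg
img002_1.jpg
img003_1.jpg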
Beyond that, just put the data-loading code and the model code in the same folder and change the dataset path in the data-loading code.
The data-loading file is named HarmonyDataset.py so the model code can import it easily.
import os.path
import random
from abc import ABC
import cv2
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from albumentations import HorizontalFlip, RandomResizedCrop, Compose, DualTransform, ToGray
class HCompose(Compose):
    def __init__(self, transforms, *args, additional_targets=None, no_nearest_for_masks=True, **kwargs):
        if additional_targets is None:
            additional_targets = {
                'real': 'image',
                'mask': 'mask'
            }
        self.additional_targets = additional_targets
        super().__init__(transforms, *args, additional_targets=additional_targets, **kwargs)
        if no_nearest_for_masks:
            for t in transforms:
                if isinstance(t, DualTransform):
                    t._additional_targets['mask'] = 'image'


def get_transform(params=None, no_flip=True, grayscale=False):
    transform_list = []
    if grayscale:
        transform_list.append(ToGray())
    if params is None:
        transform_list.append(RandomResizedCrop(512, 512, scale=(0.5, 1.0)))
    if not no_flip:
        if params is None:
            transform_list.append(HorizontalFlip())
    return HCompose(transform_list)


class Iharmony4Dataset(data.Dataset, ABC):
    def __init__(self, dataset_root):
        self.image_paths = []
        print('loading training file: ')
        self.keep_background_prob = 0.05
        # IHD_train.txt is expected to sit next to the com/ folder
        self.file = dataset_root.replace("com", "") + 'IHD_train.txt'
        with open(self.file, 'r') as f:
            for line in f.readlines():
                self.image_paths.append(os.path.join(dataset_root, line.rstrip()))
        self.transform = get_transform()
        self.input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __getitem__(self, index):
        sample = self.get_sample(index)
        self.check_sample_types(sample)
        sample = self.augment_sample(sample)
        comp = self.input_transform(sample['image'])
        real = self.input_transform(sample['real'])
        mask = sample['mask'].astype(np.float32)
        mask = mask[np.newaxis, ...].astype(np.float32)
        # unsqueeze(0) adds a batch dimension because the training loop below
        # iterates over the dataset directly instead of using a DataLoader
        output = {
            'comp': comp.unsqueeze(0),
            'mask': torch.from_numpy(mask).unsqueeze(0),
            'real': real.unsqueeze(0),
            'img_path': sample['img_path']
        }
        return output

    def check_sample_types(self, sample):
        assert sample['comp'].dtype == 'uint8'
        if 'real' in sample:
            assert sample['real'].dtype == 'uint8'

    def augment_sample(self, sample):
        if self.transform is None:
            return sample
        additional_targets = {target_name: sample[target_name]
                              for target_name in self.transform.additional_targets.keys()}
        valid_augmentation = False
        while not valid_augmentation:
            aug_output = self.transform(image=sample['comp'], **additional_targets)
            valid_augmentation = self.check_augmented_sample(aug_output)
        for target_name, transformed_target in aug_output.items():
            sample[target_name] = transformed_target
        return sample

    def check_augmented_sample(self, aug_output):
        if self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob:
            return True
        return aug_output['mask'].sum() > 1.0

    def get_sample(self, index):
        path = self.image_paths[index]
        name_parts = path.split('_')
        mask_path = self.image_paths[index].replace('com', 'mask')
        mask_path = mask_path.replace(('_' + name_parts[-1]), '.png')
        target_path = self.image_paths[index].replace('com', 'gt')
        target_path = target_path.replace(('_' + name_parts[-1]), '.png')

        comp = cv2.imread(path)
        comp = cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)
        real = cv2.imread(target_path)
        real = cv2.cvtColor(real, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path)
        mask = mask[:, :, 0].astype(np.float32) / 255.
        mask = mask.astype(np.uint8)

        return {'comp': comp, 'mask': mask, 'real': real, 'img_path': path}

    def __len__(self):
        return len(self.image_paths)
Here, comp is the composite image, mask is the mask of the composited foreground region, and real is the ground truth.
The goal, naturally, is to learn the mapping comp → real.
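To make the naming convention concrete, this is how get_sample would resolve a hypothetical entry (the file name is invented; the replace logic is exactly the one in the loader above):

# entry in IHD_train.txt: img001_1.jpg
path        = '/app/data/com/img001_1.jpg'    # comp
# name_parts[-1] == '1.jpg', so '_1.jpg' is stripped and the folder name swapped
mask_path   = '/app/data/mask/img001.png'     # mask
target_path = '/app/data/gt/img001.png'       # real (ground truth)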
The second file, with the model and the training loop, can be named anything you like:
import functools
import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from torch.nn import init
from torch.optim import lr_scheduler
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, use_attention=False):
        super(UnetGenerator, self).__init__()
        # construct unet structure
        weight = torch.FloatTensor([0.1])
        self.weight = torch.nn.Parameter(weight, requires_grad=True)
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, use_attention=use_attention)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, inputs):
        # channels 0-3 are the composite image + mask, channels 4+ are the style code map
        ori_code_map = inputs[:, 4:, :, :]
        code_map_input = ori_code_map * torch.clamp(self.weight, min=0.001)
        new_inputs = torch.cat([inputs[:, :4, :, :], code_map_input], 1)
        return self.model(new_inputs)


class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, innermost=False,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, use_attention=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.use_attention = use_attention
        if use_attention:
            attention_conv = nn.Conv2d(outer_nc + input_nc, outer_nc + input_nc, kernel_size=1)
            attention_sigmoid = nn.Sigmoid()
            self.attention = nn.Sequential(*[attention_conv, attention_sigmoid])
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            ret = torch.cat([x, self.model(x)], 1)
            return self.attention(ret) * ret if self.use_attention else ret
class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False
        self.return_mask = True
        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels,
                                                 self.kernel_size[0], self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \
                             self.weight_maskUpdater.shape[3]
        self.last_size = (None, None, None, None)
        self.update_mask, self.mask_ratio = None, None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)
            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1],
                                          input.data.shape[2], input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                else:
                    mask = mask_in
                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
                                            padding=self.padding, dilation=self.dilation, groups=1)
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        else:
            return output


class StyleEncoder(nn.Module):
    def __init__(self, style_dim, norm_layer=nn.BatchNorm2d):
        super(StyleEncoder, self).__init__()
        if type(norm_layer) == functools.partial:
            # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        ndf = 64
        kw = 3
        padw = 0
        self.conv1f = PartialConv2d(3, ndf, kernel_size=kw, stride=2, padding=padw)
        self.relu1 = nn.ReLU(True)
        nf_mult = 1

        n = 1
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n, 8)
        self.conv2f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2,
                                    padding=padw, bias=use_bias)
        self.norm2f = norm_layer(ndf * nf_mult)
        self.relu2 = nn.ReLU(True)

        n = 2
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n, 8)
        self.conv3f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2,
                                    padding=padw, bias=use_bias)
        self.norm3f = norm_layer(ndf * nf_mult)
        self.relu3 = nn.ReLU(True)

        n = 3
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n, 8)
        self.conv4f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2,
                                    padding=padw, bias=use_bias)
        self.norm4f = norm_layer(ndf * nf_mult)
        self.relu4 = nn.ReLU(True)

        n = 4
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n, 8)
        self.conv5f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2,
                                    padding=padw, bias=use_bias)
        self.avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.convs = nn.Conv2d(ndf * nf_mult, style_dim, kernel_size=1, stride=1)

    def forward(self, input, mask):
        """Standard forward."""
        xb = input
        mb = mask
        xb, mb = self.conv1f(xb, mb)
        xb = self.relu1(xb)
        xb, mb = self.conv2f(xb, mb)
        xb = self.norm2f(xb)
        xb = self.relu2(xb)
        xb, mb = self.conv3f(xb, mb)
        xb = self.norm3f(xb)
        xb = self.relu3(xb)
        xb, mb = self.conv4f(xb, mb)
        xb = self.norm4f(xb)
        xb = self.relu4(xb)
        xb, mb = self.conv5f(xb, mb)
        xb = self.avg_pooling(xb)
        s = self.convs(xb)
        return s
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        gain (float)       -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
class BargainNetModel:
    def __init__(self, netE, netG, style_dim=16, img_size=512, init_type='normal', init_gain=0.02, gpu_ids=[]):
        self.gpu_ids = gpu_ids
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.lambda_tri = 0.01
        self.lambda_f2b = 1.0
        self.lambda_ff2 = 1.0
        self.loss_names = ['L1', 'tri']
        self.optimizers = []
        self.lr = 0.0002
        self.e_lr_ratio = 1.0
        self.g_lr_ratio = 1.0
        self.beta1 = 0.5
        self.style_dim = style_dim
        self.image_size = img_size
        self.netE = init_net(netE, init_type, init_gain, self.gpu_ids)
        self.netG = init_net(netG, init_type, init_gain, self.gpu_ids)
        self.relu = nn.ReLU()
        self.margin = 0.1
        self.tripletLoss = nn.TripletMarginLoss(margin=self.margin, p=2)
        self.criterionL1 = torch.nn.L1Loss()
        self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=self.lr * self.e_lr_ratio,
                                            betas=(self.beta1, 0.999))
        self.optimizers.append(self.optimizer_E)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.lr * self.g_lr_ratio,
                                            betas=(self.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.schedulers = [
            lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0)
            for optimizer in self.optimizers
        ]

    def set_input(self, input):
        self.comp = input['comp'].to(self.device)
        self.real = input['real'].to(self.device)
        self.mask = input['mask'].to(self.device)
        self.inputs = torch.cat([self.comp, self.mask], 1).to(self.device)
        self.bg = 1.0 - self.mask
        self.real_f = self.real * self.mask

    def forward(self):
        # background style code extracted from the real image's background region
        self.bg_sty_vector = self.netE(self.real, self.bg)
        self.real_fg_sty_vector = self.netE(self.real, self.mask)
        # broadcast the 16-dim code to a spatial map and concatenate it with comp + mask for the generator
        self.bg_sty_map = self.bg_sty_vector.expand([1, self.style_dim, self.image_size, self.image_size])
        self.inputs_c2r = torch.cat([self.inputs, self.bg_sty_map], 1)
        self.harm = self.netG(self.inputs_c2r)
        self.harm_fg_sty_vector = self.netE(self.harm, self.mask)
        self.comp_fg_sty_vector = self.netE(self.comp, self.mask)
        self.fake_f = self.harm * self.mask

    def backward(self):
        self.loss_L1 = self.criterionL1(self.harm, self.real)
        self.loss_tri = (self.tripletLoss(self.real_fg_sty_vector, self.harm_fg_sty_vector,
                                          self.comp_fg_sty_vector) * self.lambda_ff2 +
                         self.tripletLoss(self.harm_fg_sty_vector, self.bg_sty_vector,
                                          self.comp_fg_sty_vector) * self.lambda_f2b) * self.lambda_tri
        self.loss = self.loss_L1 + self.loss_tri
        self.loss.backward(retain_graph=True)

    def optimize_parameters(self):
        self.forward()
        self.optimizer_E.zero_grad()
        self.optimizer_G.zero_grad()
        self.backward()
        self.optimizer_E.step()
        self.optimizer_G.step()
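Before the training script, a quick shape sanity check may help (this is my own standalone sketch, not part of the original repo): it shows where the generator's 20 input channels come from, namely 3 (composite) + 1 (mask) + 16 (background style code expanded to a spatial map):

netE = StyleEncoder(16)
netG = UnetGenerator(20, 3, 8, 64, nn.BatchNorm2d, False, use_attention=True)
comp = torch.rand(1, 3, 512, 512)                     # composite image
mask = torch.rand(1, 1, 512, 512)                     # foreground mask
style = netE(comp, 1.0 - mask)                        # background style code, shape (1, 16, 1, 1)
style_map = style.expand(1, 16, 512, 512)             # broadcast to a spatial map
inputs = torch.cat([comp, mask, style_map], dim=1)    # shape (1, 20, 512, 512)
print(netG(inputs).shape)                             # torch.Size([1, 3, 512, 512])

During training, BargainNetModel.forward does the same thing, except the style code is extracted from the background of the real image.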
The main change is here: I added a tqdm progress bar to the original training loop, so you can now watch l1_loss, tri_loss, and l1_loss + tri_loss update on the progress bar, which is more intuitive.
from HarmonyDataset import Iharmony4Dataset
This is just the name of the data-loading file; as long as the import matches whatever you named that file, it works.
# Parameters follow the official default invocation; the official training command is:
"""
python train.py --name <experiment_name> --model bargainnet --dataset_mode iharmony4 --is_train 1
    --norm batch --preprocess resize_and_crop --gpu_ids 0 --save_epoch_freq 1 --input_nc 20
    --lr 1e-4 --beta1 0.9 --lr_policy step --lr_decay_iters 6574200 --netG s2ad
"""
G_net = UnetGenerator(20, 3, 8, 64, nn.BatchNorm2d, False, use_attention=True)
E_net = StyleEncoder(16, norm_layer=nn.BatchNorm2d)

if __name__ == "__main__":
    from HarmonyDataset import Iharmony4Dataset

    harmony_dataset = Iharmony4Dataset(dataset_root='/app/data/com/')
    datalen = len(harmony_dataset)
    model = BargainNetModel(E_net, G_net, gpu_ids=[])
    EPOCH = 20
    best_loss = 0.3  # best loss, default as 0.3

    for epoch in range(EPOCH):
        tqdm_bar = tqdm.tqdm(enumerate(harmony_dataset), total=datalen,
                             desc='Epoch {}/{}'.format(epoch + 1, EPOCH))
        epoch_l1, epoch_tri = 0, 0
        for i, data in tqdm_bar:
            model.set_input(data)        # unpack data from the dataset and apply preprocessing
            model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
            epoch_l1 += model.loss_L1.item()
            epoch_tri += model.loss_tri.item()
            tqdm_bar.set_postfix(L1=epoch_l1 / (i + 1), tri=epoch_tri / (i + 1),
                                 total=(epoch_l1 + epoch_tri) / (i + 1), best_loss=best_loss)

        if best_loss > (epoch_l1 + epoch_tri) / datalen:
            # save the weights whenever the mean epoch loss improves
            print('the best model improve loss from {0} to {1}'.format(best_loss, (epoch_l1 + epoch_tri) / datalen))
            best_loss = (epoch_l1 + epoch_tri) / datalen
            torch.save(model.netG.state_dict(), 'best_netG.pth')
            torch.save(model.netE.state_dict(), 'best_netE.pth')

        # update learning rates at the end of every epoch
        for scheduler in model.schedulers:
            scheduler.step()

    # optionally export a graph of netG once training is complete
    # x = torch.zeros(1, 20, 512, 512, dtype=torch.float, requires_grad=False)
    # import hiddenlayer as h
    # myNetGraph = h.build_graph(netG, x)                      # build the network graph
    # myNetGraph.save(path='./demoModel-G', format='pdf')      # save the graph; png, PDF, etc. are supported
else:
    G_net.load_state_dict(torch.load('/app/checkpoints/best_net_G.pth'))
    E_net.load_state_dict(torch.load('/app/checkpoints/best_net_E.pth'))
    print('model load weights success')
The prediction result looks like this: I grabbed a random medicine image from Baidu, cut out its background, and pasted it onto another photo. For inference, simply feed comp in wherever real is used, since the background style code then has to come from the composite itself.
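For reference, a minimal inference sketch under those assumptions (the image and checkpoint paths are placeholders; the pre/post-processing simply mirrors the Normalize((0.5, ...), (0.5, ...)) used in HarmonyDataset.py, and UnetGenerator / StyleEncoder are assumed to be importable or defined in the same file):

import cv2
import numpy as np
import torch
from torch import nn

netG = UnetGenerator(20, 3, 8, 64, nn.BatchNorm2d, False, use_attention=True)
netE = StyleEncoder(16, norm_layer=nn.BatchNorm2d)
netG.load_state_dict(torch.load('best_netG.pth', map_location='cpu'))
netE.load_state_dict(torch.load('best_netE.pth', map_location='cpu'))
netG.eval()
netE.eval()

comp = cv2.cvtColor(cv2.imread('composite.jpg'), cv2.COLOR_BGR2RGB)   # placeholder path
comp = cv2.resize(comp, (512, 512)).astype(np.float32) / 255.0
comp = torch.from_numpy(comp).permute(2, 0, 1).unsqueeze(0) * 2 - 1   # to [-1, 1], like Normalize(0.5, 0.5)
mask = cv2.imread('mask.png')[:, :, 0]                                # placeholder path
mask = cv2.resize(mask, (512, 512)).astype(np.float32) / 255.0
mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)

with torch.no_grad():
    bg_sty = netE(comp, 1.0 - mask)                       # background style from the composite itself
    sty_map = bg_sty.expand(1, 16, 512, 512)
    harm = netG(torch.cat([comp, mask, sty_map], dim=1))  # harmonized output in [-1, 1]

out = ((harm[0].permute(1, 2, 0).numpy() + 1) / 2 * 255).astype(np.uint8)
cv2.imwrite('harmonized.jpg', cv2.cvtColor(out, cv2.COLOR_RGB2BGR))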
For the full project, please see the paper's GitHub implementation: https://github.com/bcmi/BargainNet-Image-Harmonization
That should be everything.
If you have questions about the model itself, I'd suggest reading the paper or asking the authors; I'm just porting the code.
finish