
# Python & PyTorch Image Harmonization Model: BargainNet



Preface

BargainNet is a project from bcmi; see the GitHub repository for the full project description. I needed to use BargainNet for various reasons, but I am not fond of launching training from the command line, so I extracted the default model and parameters it uses and boiled everything down to two simple files: one that loads the data and one that trains the model.


I. File structure

The training data is organized as shown in the screenshots of the original post (figures omitted here; a layout sketch inferred from the code follows below).

IHD_train.txt is trivial: it is just a list of image file names, one per line.

Beyond that, keep the data-loading code and the model code in the same folder, and change the dataset path inside the data-loading script so that it points at your own data.
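As a hedged sketch only (reconstructed from the path handling in HarmonyDataset.py below, not from the original screenshot), the expected layout is roughly:

/app/data/
    IHD_train.txt          # one composite file name per line, e.g. xxx_1_2.jpg
    com/                   # composite images listed in IHD_train.txt
    mask/                  # foreground masks, e.g. xxx_1.png
    gt/                    # ground-truth (real) images, e.g. xxx_1.png

The 'com' → 'mask' / 'gt' substitution and the renaming of the last '_<suffix>' to '.png' in get_sample() are what tie the three folders together.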

II. Data loading

The file is named HarmonyDataset.py so that the model script can import it directly.

1. Imports

import os.path
import random
from abc import ABC

import cv2  # plain `import cv2`; the old `import cv2.cv2 as cv2` breaks on recent opencv-python builds
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from albumentations import HorizontalFlip, RandomResizedCrop, Compose, DualTransform, ToGray

2. Loading the data

class HCompose(Compose):
    def __init__(self, transforms, *args, additional_targets=None, no_nearest_for_masks=True, **kwargs):
        if additional_targets is None:
            additional_targets = {
                'real': 'image',
                'mask': 'mask'
            }
        self.additional_targets = additional_targets
        super().__init__(transforms, *args, additional_targets=additional_targets, **kwargs)
        if no_nearest_for_masks:
            for t in transforms:
                if isinstance(t, DualTransform):
                    t._additional_targets['mask'] = 'image'


def get_transform(params=None, no_flip=True, grayscale=False):
    transform_list = []
    if grayscale:
        transform_list.append(ToGray())
    if params is None:
        transform_list.append(RandomResizedCrop(512, 512, scale=(0.5, 1.0)))

    if not no_flip:
        if params is None:
            transform_list.append(HorizontalFlip())

    return HCompose(transform_list)


class Iharmony4Dataset(data.Dataset, ABC):
    def __init__(self, dataset_root,):
        self.image_paths = []
        print('loading training file: ')
        self.keep_background_prob = 0.05
        self.file = dataset_root.replace("com", "") + 'IHD_train.txt'
        with open(self.file, 'r') as f:
            for line in f.readlines():
                self.image_paths.append(os.path.join(dataset_root, line.rstrip()))

        self.transform = get_transform()
        self.input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __getitem__(self, index):
        sample = self.get_sample(index)
        self.check_sample_types(sample)
        sample = self.augment_sample(sample)
        comp = self.input_transform(sample['image'])
        real = self.input_transform(sample['real'])
        mask = sample['mask'].astype(np.float32)
        mask = mask[np.newaxis, ...].astype(np.float32)
        output = {
            'comp': comp.unsqueeze(0),
            'mask': torch.from_numpy(mask).unsqueeze(0),
            'real': real.unsqueeze(0),
            'img_path': sample['img_path']
        }
        return output

    def check_sample_types(self, sample):
        assert sample['comp'].dtype == 'uint8'
        if 'real' in sample:
            assert sample['real'].dtype == 'uint8'

    def augment_sample(self, sample):
        if self.transform is None:
            return sample
        additional_targets = {target_name: sample[target_name]
                              for target_name in self.transform.additional_targets.keys()}

        valid_augmentation = False
        while not valid_augmentation:
            aug_output = self.transform(image=sample['comp'], **additional_targets)
            valid_augmentation = self.check_augmented_sample(aug_output)

        for target_name, transformed_target in aug_output.items():
            sample[target_name] = transformed_target

        return sample

    def check_augmented_sample(self, aug_output):
        if self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob:
            return True

        return aug_output['mask'].sum() > 1.0

    def get_sample(self, index):
        path = self.image_paths[index]
        name_parts = path.split('_')
        mask_path = self.image_paths[index].replace('com', 'mask')
        mask_path = mask_path.replace(('_' + name_parts[-1]), '.png')
        target_path = self.image_paths[index].replace('com', 'gt')
        target_path = target_path.replace(('_' + name_parts[-1]), '.png')

        comp = cv2.imread(path)
        comp = cv2.cvtColor(comp, cv2.COLOR_BGR2RGB)
        real = cv2.imread(target_path)
        real = cv2.cvtColor(real, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path)
        mask = mask[:, :, 0].astype(np.float32) / 255.
        mask = mask.astype(np.uint8)

        return {'comp': comp, 'mask': mask, 'real': real, 'img_path': path}

    def __len__(self):
        return len(self.image_paths)


comp is the composite image, mask marks the composited (foreground) region, and real is the ground truth.
The goal, of course, is to turn comp into real.
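As a quick sanity check of the loader (my own sketch, assuming the dataset root '/app/data/com/' used in the training script below), loading one sample looks like this. Every tensor already carries a leading batch dimension of 1 because __getitem__ calls unsqueeze(0), which is also why the training loop later iterates the dataset directly instead of wrapping it in a DataLoader:

from HarmonyDataset import Iharmony4Dataset

dataset = Iharmony4Dataset(dataset_root='/app/data/com/')
sample = dataset[0]
print(sample['comp'].shape)   # torch.Size([1, 3, 512, 512])
print(sample['mask'].shape)   # torch.Size([1, 1, 512, 512])
print(sample['real'].shape)   # torch.Size([1, 3, 512, 512])
print(sample['img_path'])     # path of the composite image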


III. Model components

Name this file whatever you like.

1. Imports

import functools

import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from torch.nn import init
from torch.optim import lr_scheduler

2. Model structure: G (generator)

class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 use_attention=False):
        super(UnetGenerator, self).__init__()
        # construct unet structure
        weight = torch.FloatTensor([0.1])
        self.weight = torch.nn.Parameter(weight, requires_grad=True)
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
                                             innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer, use_attention=use_attention)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer,
                                             use_attention=use_attention)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
                                             norm_layer=norm_layer)  # add the outermost layer

    def forward(self, inputs):
        ori_code_map = inputs[:, 4:, :, :]  # 16-channel background style-code map
        code_map_input = ori_code_map * torch.clamp(self.weight, min=0.001)
        new_inputs = torch.cat([inputs[:, :4, :, :], code_map_input], 1)
        return self.model(new_inputs)


class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 use_attention=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.use_attention = use_attention
        if use_attention:
            attention_conv = nn.Conv2d(outer_nc + input_nc, outer_nc + input_nc, kernel_size=1)
            attention_sigmoid = nn.Sigmoid()
            self.attention = nn.Sequential(*[attention_conv, attention_sigmoid])

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            ret = torch.cat([x, self.model(x)], 1)
            return self.attention(ret) * ret if self.use_attention else ret
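As a shape check (a sketch of my own, not part of the original script): the training section below constructs the generator with input_nc=20, i.e. 3 composite channels + 1 mask channel + a 16-channel background style-code map, and forward() splits those 20 channels back into the image/mask part and the code-map part before re-concatenating them:

import torch
from torch import nn

gen = UnetGenerator(20, 3, 8, 64, nn.BatchNorm2d, False, use_attention=True)
dummy = torch.randn(1, 20, 512, 512)   # comp (3) + mask (1) + style map (16)
out = gen(dummy)
print(out.shape)                       # torch.Size([1, 3, 512, 512])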

3. Model structure: E (style encoder)

class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False

        self.return_mask = True

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0],
                                                 self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \
                             self.weight_maskUpdater.shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask, self.mask_ratio = None, None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2],
                                          input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                else:
                    mask = mask_in

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
                                            padding=self.padding, dilation=self.dilation, groups=1)

                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        else:
            return output


class StyleEncoder(nn.Module):
    def __init__(self, style_dim, norm_layer=nn.BatchNorm2d):
        super(StyleEncoder, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        ndf = 64
        kw = 3
        padw = 0
        self.conv1f = PartialConv2d(3, ndf, kernel_size=kw, stride=2, padding=padw)
        self.relu1 = nn.ReLU(True)
        nf_mult = 1

        n = 1
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n, 8)
        self.conv2f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
                                    bias=use_bias)
        self.norm2f = norm_layer(ndf * nf_mult)
        self.relu2 = nn.ReLU(True)

        n = 2
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n, 8)
        self.conv3f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
                                    bias=use_bias)
        self.norm3f = norm_layer(ndf * nf_mult)
        self.relu3 = nn.ReLU(True)

        n = 3
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n, 8)
        self.conv4f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
                                    bias=use_bias)
        self.norm4f = norm_layer(ndf * nf_mult)
        self.relu4 = nn.ReLU(True)

        n = 4
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n, 8)
        self.conv5f = PartialConv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,
                                    bias=use_bias)
        self.avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.convs = nn.Conv2d(ndf * nf_mult, style_dim, kernel_size=1, stride=1)

    def forward(self, input, mask):
        """Standard forward."""
        xb = input
        mb = mask

        xb, mb = self.conv1f(xb, mb)
        xb = self.relu1(xb)
        xb, mb = self.conv2f(xb, mb)
        xb = self.norm2f(xb)
        xb = self.relu2(xb)
        xb, mb = self.conv3f(xb, mb)
        xb = self.norm3f(xb)
        xb = self.relu3(xb)
        xb, mb = self.conv4f(xb, mb)
        xb = self.norm4f(xb)
        xb = self.relu4(xb)
        xb, mb = self.conv5f(xb, mb)
        xb = self.avg_pooling(xb)
        s = self.convs(xb)
        return s
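A corresponding shape check for the style encoder (again just a sketch): it takes an RGB image plus a single-channel region mask and compresses the masked region into one style code of length style_dim per image:

import torch
from torch import nn

enc = StyleEncoder(16, norm_layer=nn.BatchNorm2d)
img = torch.randn(1, 3, 512, 512)
region = torch.ones(1, 1, 512, 512)    # 1 inside the region whose style we want
code = enc(img, region)
print(code.shape)                      # torch.Size([1, 16, 1, 1])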

4. Initializing the model and weights

def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)   -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)    -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find(
                'BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        gain (float)       -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """

    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
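Note that init_net is already called inside BargainNetModel.__init__ below for both netE and netG, so nothing in the training script needs to invoke these helpers directly.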

5. Building BargainNet

class BargainNetModel:
    def __init__(self, netE, netG, style_dim=16, img_size=512, init_type='normal', init_gain=0.02, gpu_ids=[]):
        self.gpu_ids = gpu_ids
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.lambda_tri = 0.01
        self.lambda_f2b = 1.0
        self.lambda_ff2 = 1.0
        self.loss_names = ['L1', 'tri']
        self.optimizers = []
        self.lr = 0.0002
        self.e_lr_ratio = 1.0
        self.g_lr_ratio = 1.0
        self.beta1 = 0.5

        self.style_dim = style_dim
        self.image_size = img_size
        self.netE = init_net(netE, init_type, init_gain, self.gpu_ids)
        self.netG = init_net(netG, init_type, init_gain, self.gpu_ids)
        self.relu = nn.ReLU()
        self.margin = 0.1
        self.tripletLoss = nn.TripletMarginLoss(margin=self.margin, p=2)
        self.criterionL1 = torch.nn.L1Loss()
        self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=self.lr * self.e_lr_ratio,
                                            betas=(self.beta1, 0.999))
        self.optimizers.append(self.optimizer_E)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.lr * self.g_lr_ratio,
                                            betas=(self.beta1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.schedulers = [
            lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0) for optimizer in self.optimizers
        ]

    def set_input(self, input):
        self.comp = input['comp'].to(self.device)
        self.real = input['real'].to(self.device)
        self.mask = input['mask'].to(self.device)
        self.inputs = torch.cat([self.comp, self.mask], 1).to(self.device)
        self.bg = 1.0 - self.mask
        self.real_f = self.real * self.mask

    def forward(self):
        self.bg_sty_vector = self.netE(self.real, self.bg)
        self.real_fg_sty_vector = self.netE(self.real, self.mask)
        self.bg_sty_map = self.bg_sty_vector.expand([1, self.style_dim, self.image_size, self.image_size])
        self.inputs_c2r = torch.cat([self.inputs, self.bg_sty_map], 1)
        self.harm = self.netG(self.inputs_c2r)

        self.harm_fg_sty_vector = self.netE(self.harm, self.mask)
        self.comp_fg_sty_vector = self.netE(self.comp, self.mask)
        self.fake_f = self.harm * self.mask

    def backward(self):
        self.loss_L1 = self.criterionL1(self.harm, self.real)
        self.loss_tri = (self.tripletLoss(self.real_fg_sty_vector, self.harm_fg_sty_vector,
                                          self.comp_fg_sty_vector) * self.lambda_ff2
                         + self.tripletLoss(self.harm_fg_sty_vector, self.bg_sty_vector,
                                            self.comp_fg_sty_vector) * self.lambda_f2b) * self.lambda_tri
        self.loss = self.loss_L1 + self.loss_tri
        self.loss.backward(retain_graph=True)

    def optimize_parameters(self):
        self.forward()
        self.optimizer_E.zero_grad()
        self.optimizer_G.zero_grad()
        self.backward()
        self.optimizer_E.step()
        self.optimizer_G.step()
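For orientation: set_input concatenates the composite (3 channels) with the mask (1 channel), and forward expands the background style vector from netE into a 16-channel map and appends it, giving exactly the 20-channel tensor the generator was built for (input_nc=20 in the training script below).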

6. Training the model

The main change is here: I wrapped the original training loop in a tqdm progress bar, so the running values of l1_loss, tri_loss, and l1_loss + tri_loss are shown directly on the bar, which is a bit more intuitive.

`from HarmonyDataset import Iharmony4Dataset` simply imports the data-loading file above; as long as your file name matches, the import works.

# Parameter choices follow the repository's default command-line invocation, which is:
"""
python train.py --name <experiment_name> --model bargainnet --dataset_mode iharmony4 --is_train 1 --norm batch --preprocess resize_and_crop --gpu_ids 0 --save_epoch_freq 1 --input_nc 20 --lr 1e-4 --beta1 0.9 --lr_policy step --lr_decay_iters 6574200 --netG s2ad
"""
G_net = UnetGenerator(20, 3, 8, 64, nn.BatchNorm2d, False, use_attention=True)
E_net = StyleEncoder(16, norm_layer=nn.BatchNorm2d)

if __name__ == "__main__":
    from HarmonyDataset import Iharmony4Dataset

    harmony_dataset = Iharmony4Dataset(dataset_root='/app/data/com/')
    datalen = len(harmony_dataset)
    model = BargainNetModel(E_net, G_net, gpu_ids=[])
    EPOCH = 20
    best_loss = 0.3  # best loss, default as 0.3
    for epoch in range(EPOCH):
        tqdm_bar = tqdm.tqdm(enumerate(harmony_dataset), total=datalen, desc='Epoch {}/{}'.format(epoch + 1, EPOCH))
        epoch_l1, epoch_tri = 0, 0
        for i, data in tqdm_bar:
            model.set_input(data)  # unpack data from a dataset and apply preprocessing
            model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
            epoch_l1 += model.loss_L1.item()
            epoch_tri += model.loss_tri.item()
            tqdm_bar.set_postfix(L1=epoch_l1 / (i + 1), tri=epoch_tri / (i + 1),
                                 total=(epoch_l1 + epoch_tri) / (i + 1), best_loss=best_loss)

        if best_loss > (epoch_l1 + epoch_tri) / datalen:  # save whenever the epoch-average loss improves
            print('best loss improved from {0} to {1}'.format(best_loss, (epoch_l1 + epoch_tri) / datalen))
            best_loss = (epoch_l1 + epoch_tri) / datalen
            # save the best weights seen so far
            torch.save(model.netG.state_dict(), 'best_netG.pth')
            torch.save(model.netE.state_dict(), 'best_netE.pth')

        # update learning rates at the end of every epoch.
        for scheduler in model.schedulers:
            scheduler.step()

    # Optional: export a graph of netG (commented out; requires the hiddenlayer package)
    # x = torch.zeros(1, 20, 512, 512, dtype=torch.float, requires_grad=False)
    # import hiddenlayer as h
    # myNetGraph = h.build_graph(model.netG, x)  # build the network graph
    # myNetGraph.save(path='./demoModel-G', format='pdf')  # save the graph; png, pdf, etc. are supported

else:
    # Note: the file names here must match whatever the training run saved
    # (the loop above writes 'best_netG.pth' / 'best_netE.pth' to the working directory).
    G_net.load_state_dict(torch.load('/app/checkpoints/best_net_G.pth'))
    E_net.load_state_dict(torch.load('/app/checkpoints/best_net_E.pth'))
    print('model load weights success')


Summary

The prediction result looks like the example in the original post: I grabbed a random medicine image from Baidu, removed its background, and pasted it onto another photo. At inference time, simply feed comp wherever real was used (see the sketch below).
(figures: prediction example and training results omitted)
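A minimal inference sketch along those lines (my own reading of set_input/forward, not a script from the BargainNet repository; comp and mask are assumed to be preprocessed exactly as in HarmonyDataset.py, i.e. normalized to [-1, 1] with a leading batch dimension of 1):

import torch

def harmonize(comp, mask, netE=E_net, netG=G_net, style_dim=16, img_size=512):
    """comp: (1, 3, 512, 512) composite in [-1, 1]; mask: (1, 1, 512, 512) foreground mask."""
    netE.eval()
    netG.eval()
    with torch.no_grad():
        bg = 1.0 - mask
        bg_sty_vector = netE(comp, bg)                           # background style code, taken from the composite
        bg_sty_map = bg_sty_vector.expand([1, style_dim, img_size, img_size])
        inputs = torch.cat([comp, mask, bg_sty_map], 1)          # 20-channel generator input
        harm = netG(inputs)                                      # harmonized output, still in [-1, 1]
    return harm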

For the full project, please see the paper's official GitHub implementation: https://github.com/bcmi/BargainNet-Image-Harmonization

That should be everything.

If you have questions about the model itself, please read the paper or ask its authors; I am only porting the code.

finish
