当前位置:   article > 正文

Glow模型【图解版加代码】

Glow模型【图解版加代码】

论文:Glow: Generative Flow with Invertible 1x1 Convolutions

代码:pytorch版本:rosinality/glow-pytorch: PyTorch implementation of Glow (github.com)

正版是TensorFlow版本 openai的

参考csdn文章:Glow-pytorch复现github项目_pytorch glow-CSDN博客
(pytorch进阶之路)NormalizingFlow标准流_normalizing flow-CSDN博客
本文的阅读前提:需要先看一下b站的Flow的讲解Flow-based Generative Model_哔哩哔哩_bilibili P59

本csdn文的目标:跑通代码+理解原理(不包含论文结果部分解读)

目录

1 引言

2 背景:

3 Generative Flow

Glow模块的整体代码:

Block模块:

Flow模块: 

 3.1 Actnorm: scale and bias layer with data dependent initialization

3.2 Invertible 1 1 convolution 可逆1*1卷积

3.3 Affine Coupling Layers 仿射耦合层

train部分


1 引言

基于flow模型改进,提出Glow

2 背景:

之前是基于flow的生成模型,我们的目标是从z(一个普通的分布)拟合到x(真实的分布),理解为从图A变为图B,而且要求这个过程是可逆的。

模型为G(x),目标最大化极大似然(最大似然理解为当参数为变量时,X=x的概率最大化):


也就是最后的这个。即最小化:

其中,flow的意思就是多个G连起来:

最终最大化下面这个,即:


其中,z的分布的选取一般为正态分布,均值为0函数G为双摄可逆函数,,可逆回去。

在计算方面,最后可以等于雅可比行列式的对角线。

3 Generative Flow

flow的每一步都由actnorm(3.1)、一个可逆的1x1卷积(3.2)和一个耦合层(3.3)组成。flow的深度为K,层数为L,下图。

Glow模块的整体代码:

  1. class Glow(nn.Module):
  2. def __init__(self, in_channel, n_flow, n_block, affine=True, conv_lu=True): #n_flow为K,n_block为L
  3. super().__init__()
  4. self.blocks = nn.ModuleList() #blocks层为图b的堆叠
  5. n_channel = in_channel
  6. for i in range(n_block - 1):
  7. self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
  8. n_channel *= 2 #最后一个Block通道*2
  9. self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))
  10. def forward(self, input):
  11. log_p_sum = 0
  12. logdet = 0
  13. out = input
  14. z_outs = [] #中间z
  15. for block in self.blocks:
  16. out, det, log_p, z_new = block(out) #循环 out
  17. z_outs.append(z_new)
  18. logdet = logdet + det #logdet求和
  19. if log_p is not None:
  20. log_p_sum = log_p_sum + log_p #log_p求和
  21. return log_p_sum, logdet, z_outs # 输出log_p和logdet,以及最后的z序列
  22. def reverse(self, z_list, reconstruct=False):
  23. for i, block in enumerate(self.blocks[::-1]):#最后一个block去掉
  24. if i == 0:
  25. input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
  26. else:
  27. input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)
  28. return input

Block模块:

  1. class Block(nn.Module):
  2. def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
  3. super().__init__()
  4. squeeze_dim = in_channel * 4 #扩大4倍
  5. self.flows = nn.ModuleList()
  6. for i in range(n_flow): #内部Flow块,一共n_flow块
  7. self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
  8. self.split = split
  9. if split:
  10. self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)
  11. else:
  12. self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)
  13. def forward(self, input):
  14. b_size, n_channel, height, width = input.shape
  15. squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2) #尺寸变小
  16. squeezed = squeezed.permute(0, 1, 3, 5, 2, 4) #[b,c,h,2,w,2]变成[b,c,2,2,h,w]
  17. out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2) #深拷贝重新创建out
  18. logdet = 0
  19. for flow in self.flows:
  20. out, det = flow(out)
  21. logdet = logdet + det
  22. if self.split:
  23. out, z_new = out.chunk(2, 1) #分块,dim=1分2块
  24. mean, log_sd = self.prior(out).chunk(2, 1)
  25. log_p = gaussian_log_p(z_new, mean, log_sd)
  26. log_p = log_p.view(b_size, -1).sum(1)
  27. else:
  28. zero = torch.zeros_like(out)
  29. mean, log_sd = self.prior(zero).chunk(2, 1)
  30. log_p = gaussian_log_p(out, mean, log_sd)
  31. log_p = log_p.view(b_size, -1).sum(1)
  32. z_new = out
  33. return out, logdet, log_p, z_new
  34. def reverse(self, output, eps=None, reconstruct=False):
  35. input = output
  36. if reconstruct:
  37. if self.split:
  38. input = torch.cat([output, eps], 1)
  39. else:
  40. input = eps
  41. else:
  42. if self.split:
  43. mean, log_sd = self.prior(input).chunk(2, 1)
  44. z = gaussian_sample(eps, mean, log_sd)
  45. input = torch.cat([output, z], 1)
  46. else:
  47. zero = torch.zeros_like(input)
  48. # zero = F.pad(zero, [1, 1, 1, 1], value=1)
  49. mean, log_sd = self.prior(zero).chunk(2, 1)
  50. z = gaussian_sample(eps, mean, log_sd)
  51. input = z
  52. for flow in self.flows[::-1]:
  53. input = flow.reverse(input)
  54. b_size, n_channel, height, width = input.shape
  55. unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
  56. unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
  57. unsqueezed = unsqueezed.contiguous().view(b_size, n_channel // 4, height * 2, width * 2)
  58. return unsqueezed

Flow模块: 

  1. class Flow(nn.Module):
  2. def __init__(self, in_channel, affine=True, conv_lu=True):
  3. super().__init__()
  4. self.actnorm = ActNorm(in_channel)
  5. if conv_lu:
  6. self.invconv = InvConv2dLU(in_channel)
  7. else:
  8. self.invconv = InvConv2d(in_channel)
  9. self.coupling = AffineCoupling(in_channel, affine=affine)
  10. def forward(self, input):
  11. out, logdet = self.actnorm(input)
  12. out, det1 = self.invconv(out)
  13. out, det2 = self.coupling(out)
  14. logdet = logdet + det1
  15. if det2 is not None:
  16. logdet = logdet + det2
  17. return out, logdet
  18. def reverse(self, output):
  19. input = self.coupling.reverse(output)
  20. input = self.invconv.reverse(input)
  21. input = self.actnorm.reverse(input)
  22. return input

 3.1 Actnorm: scale and bias layer with data dependent initialization

之前提出批归一化来缓解训练深度模型时遇到的问题。然而,由于批处理归一化(batch normalization)所增加的激活噪声的方差与每个GPU或其他处理单元(PU)的小批(minibatch)大小成反比,因此已知每个PU的小批大小会降低性能。因此,minibatch=1. 我们提出了一个actnorm层(用于激活归一化),它使用每个通道的尺度和偏置参数执行激活的仿射变换,类似于批量归一化。这些参数被初始化,使得每个通道的事后激活具有零均值和给定初始小批量数据的单位方差。这是数据依赖初始化的一种形式(Salimans and Kingma 2016)。初始化后,尺度和偏差被视为独立于数据的常规可训练参数。(没怎么懂,看代码吧)

在Flow模块中的第一层就是ActNorm。这一步其实就是一个标准化,对于input(经过squeezed)【batch,12,32,32 】进行每个通道的标准化,用每个通道,例如3通道计算batch*h*w的均值,【1,12,1,1】,标准差也同样,然后进行(x-均值)/(标准差+1e-6) 标准化。因为要可逆,需要计算det,为系数的log求和,其实就是1/(标准差+1e-6)的log求和。

  1. class ActNorm(nn.Module):
  2. def __init__(self, in_channel, logdet=True):
  3. super().__init__()
  4. self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) #每个通道有一个值 初始为全0
  5. self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1)) #初始scale为全1
  6. self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) #不被更新的参数
  7. self.logdet = logdet #是否计算logdet
  8. def initialize(self, input): #改变scale
  9. with torch.no_grad():
  10. flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)#深度拷贝,[12, 64*32*32]
  11. mean = (
  12. flatten.mean(1)
  13. .unsqueeze(1)
  14. .unsqueeze(2)
  15. .unsqueeze(3)
  16. .permute(1, 0, 2, 3)
  17. )#上面把input分为12通道,每个通道包含64张的图像的一个通道数据,求均值,并转化为[1,12,1,1]
  18. std = (
  19. flatten.std(1)
  20. .unsqueeze(1)
  21. .unsqueeze(2)
  22. .unsqueeze(3)
  23. .permute(1, 0, 2, 3)
  24. )#类似的,求std标准差,并转化为[1,12,1,1]
  25. self.loc.data.copy_(-mean)# loc为负的平均值
  26. self.scale.data.copy_(1 / (std + 1e-6)) #scale为1 / (std + 1e-6)
  27. def forward(self, input):#64, 12, 32, 32
  28. _, _, height, width = input.shape
  29. if self.initialized.item() == 0: #没操作,为0
  30. self.initialize(input) #initialized为一个操作,根据input,对loc和scale的赋值
  31. self.initialized.fill_(1) #操作完了,为1;哈哈哈哈要是我写的话,就是直接创建一个哨兵
  32. log_abs = logabs(self.scale)#均值的绝对值的log
  33. logdet = height * width * torch.sum(log_abs)#均值的logabs求和后乘以h*w det为系数log求和,一共h*w个点
  34. if self.logdet:
  35. return self.scale * (input + self.loc), logdet #对input每个点使用通道标准化 det为系数log求和
  36. else:
  37. return self.scale * (input + self.loc)
  38. def reverse(self, output):
  39. return output / self.scale - self.loc

3.2 Invertible 1 1 convolution 可逆1*1卷积

在Flow模块中的第二层,根据是否LU,选择是否带LU操作的1*1可逆卷积:

  1. if conv_lu:
  2. self.invconv = InvConv2dLU(in_channel)
  3. else:
  4. self.invconv = InvConv2d(in_channel)
  1. class InvConv2dLU(nn.Module):
  2. def __init__(self, in_channel):
  3. super().__init__()
  4. weight = np.random.randn(in_channel, in_channel)#[12,12]
  5. q, _ = la.qr(weight) #qr分解,q为正交矩阵,r为上三角矩阵
  6. w_p, w_l, w_u = la.lu(q.astype(np.float32))#对于正交矩阵q进行LU分解,p为置换矩阵,l为下三角,u为上三角,PA=LU,P就是把最大元素放在第一行
  7. w_s = np.diag(w_u)#对角线
  8. w_u = np.triu(w_u, 1) #去掉对角线,只保留上三角
  9. u_mask = np.triu(np.ones_like(w_u), 1) #上三角单位阵,不包含对角线
  10. l_mask = u_mask.T #下三角 不包含对角线
  11. w_p = torch.from_numpy(w_p) #q置换矩阵p
  12. w_l = torch.from_numpy(w_l) #q的下三角l
  13. w_s = torch.from_numpy(w_s.copy()) #q的上三角u的对角线
  14. w_u = torch.from_numpy(w_u) #q的上三角u的上三角
  15. self.register_buffer("w_p", w_p)#p不更新
  16. self.register_buffer("u_mask", torch.from_numpy(u_mask))
  17. self.register_buffer("l_mask", torch.from_numpy(l_mask))
  18. self.register_buffer("s_sign", torch.sign(w_s))
  19. self.register_buffer("l_eye", torch.eye(l_mask.shape[0])) #对角线全1,其余全0
  20. self.w_l = nn.Parameter(w_l) #更新的
  21. self.w_s = nn.Parameter(logabs(w_s))
  22. self.w_u = nn.Parameter(w_u)
  23. def forward(self, input):
  24. _, _, height, width = input.shape
  25. weight = self.calc_weight()#[12,12,1,1] 这里就是1*1卷积了,12种12通道 对应下面的卷积操作
  26. out = F.conv2d(input, weight) #输出通道数为卷积种类为12
  27. logdet = height * width * torch.sum(self.w_s)
  28. return out, logdet
  29. def calc_weight(self):
  30. weight = (
  31. self.w_p
  32. @ (self.w_l * self.l_mask + self.l_eye) #@为矩阵乘法
  33. @ ((self.w_u * self.u_mask) + torch.diag(self.s_sign * torch.exp(self.w_s)))
  34. )
  35. return weight.unsqueeze(2).unsqueeze(3)
  36. def reverse(self, output):
  37. weight = self.calc_weight()#weight跟上面的weight是同一个 需要先训练上面的那个weight
  38. return F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))

自己定义的权重W,(cxc),与输入的tensor h (h x w x c)之间进行卷积计算,因此,log_det的计算为:

但是,detW的计算复杂,为了简化计算复杂度,提出使用LU分解,把W参数化:

P为置换矩阵(不参与更新),L为下三角矩阵(对角线为0),U为上三角矩阵(对角线为0),diag(s)为分解时候的上三角矩阵plu的u的对角线,U仅仅只是u的对角线变为0,这样才符合plu分解,即,W=p*l*u。这样,log_det可以简化为:

对于较大的通道数c,可以大大节省。并且,除了P不参与更新外,L、U、s都参与更新。

也提供了不进行PLU分解的版本:

  1. class InvConv2d(nn.Module):
  2. def __init__(self, in_channel):
  3. super().__init__()
  4. weight = torch.randn(in_channel, in_channel)
  5. q, _ = torch.qr(weight)
  6. weight = q.unsqueeze(2).unsqueeze(3)
  7. self.weight = nn.Parameter(weight)
  8. def forward(self, input):
  9. _, _, height, width = input.shape
  10. out = F.conv2d(input, self.weight)
  11. logdet = (
  12. height * width * torch.slogdet(self.weight.squeeze().double())[1].float()
  13. )
  14. return out, logdet
  15. def reverse(self, output):
  16. return F.conv2d(
  17. output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3)
  18. )

3.3 Affine Coupling Layers 仿射耦合层

这一层在flow模块中的第三层

仿射耦合层是一种强大的可逆变换,其中正向函数、逆函数和对数行列式的计算效率很高。加性耦合层是s=1和log_det=0的特殊情况。

还是看代码吧。

Zero initialization:零初始化最后一个卷积。这样每个仿射耦合层最初执行一个恒等函数,这有助于训练非常深的网络。也就是说,网络一开始输入等于输出,因为F作为乘积为1和H为0。

  1. class ZeroConv2d(nn.Module):
  2. def __init__(self, in_channel, out_channel, padding=1):
  3. super().__init__()
  4. self.conv = nn.Conv2d(in_channel, out_channel, 3, padding=0)
  5. self.conv.weight.data.zero_()
  6. self.conv.bias.data.zero_()
  7. self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) # scale变成可以训练的 [1,12,1,1]
  8. def forward(self, input):
  9. out = F.pad(input, [1, 1, 1, 1], value=1) # 填充数值为1 从[64,512,32,32]变为[64,512,34,34]
  10. out = self.conv(out) #通道数变回 从512变回12 初始输出全为0,因为权重为0
  11. out = out * torch.exp(self.scale * 3) #0乘以1还是0
  12. return out

  1. class AffineCoupling(nn.Module):
  2. def __init__(self, in_channel, filter_size=512, affine=True):
  3. super().__init__()
  4. self.affine = affine
  5. self.net = nn.Sequential(
  6. nn.Conv2d(in_channel // 2, filter_size, 3, padding=1),
  7. nn.ReLU(inplace=True),
  8. nn.Conv2d(filter_size, filter_size, 1),
  9. nn.ReLU(inplace=True),
  10. ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2),#如果仿射,输出通道数为12,否则为6
  11. )
  12. self.net[0].weight.data.normal_(0, 0.05)#初始化权重,对于第一个Conv2d
  13. self.net[0].bias.data.zero_()
  14. self.net[2].weight.data.normal_(0, 0.05)#初始化权重,对于第二个Conv2d
  15. self.net[2].bias.data.zero_()
  16. def forward(self, input):
  17. in_a, in_b = input.chunk(2, 1)#分块,对于dim=1,通道分为2块,这应该就是上下两块 [6,6]
  18. if self.affine:
  19. log_s, t = self.net(in_a).chunk(2, 1)#6 6 输出初始都为0
  20. # s = torch.exp(log_s)
  21. s = torch.sigmoid(log_s + 2) #图中的F
  22. # out_a = s * in_a + t
  23. out_b = (in_b + t) * s #生成下面那个 t为图中的H,有所不同的是计算顺序
  24. logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1)
  25. else: #不生成F
  26. net_out = self.net(in_a) #图中的H 通道数为6
  27. out_b = in_b + net_out #直接生成下面 通道数相同
  28. logdet = None
  29. return torch.cat([in_a, out_b], 1), logdet #上面的那块不变,
  30. def reverse(self, output):
  31. out_a, out_b = output.chunk(2, 1) #上面的out拆分,第一个其实没有变
  32. if self.affine:
  33. log_s, t = self.net(out_a).chunk(2, 1) #由于第一个没有变,生成的这两个块与上面是一样的
  34. # s = torch.exp(log_s)
  35. s = torch.sigmoid(log_s + 2)#生成的F与上面也是一样的
  36. # in_a = (out_a - t) / s
  37. in_b = out_b / s - t #先除以F后减t
  38. else:
  39. net_out = self.net(out_a) #由于第一个没有变 生成的F没有变
  40. in_b = out_b - net_out #直接减掉就好
  41. return torch.cat([out_a, in_b], 1)

代码实现部分,关于s的生成注释掉的部分与视频中讲解的一致,属于标准形式,后面用sigmoid生成openai代码中也是如此。

至此,Flow模块已经完成。论文方法部分也结束了。

在Flow模块外部还有squeezed操作,把图像切分为4块后,拼起来,通道变为12后再送入Flow块。后面还有一个split操作。这形成一个Block块。如果需要split操作,则输出的一半作为z,另一半作为out送到下游。看代码

首先,高斯分布的概率密度函数:

对于此概率密度函数取对数log,以e为底:注意下面的输入,log_sd是对标准差取对数,其中mean和log_sd都是可以训练的。 

  1. def gaussian_log_p(x, mean, log_sd):
  2. return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd)
  1. class Block(nn.Module):
  2. def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
  3. super().__init__()
  4. squeeze_dim = in_channel * 4 #扩大4倍
  5. self.flows = nn.ModuleList()
  6. for i in range(n_flow): #内部Flow块,一共n_flow块
  7. self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
  8. self.split = split
  9. if split:#对于split,输入,输出的通道数不同
  10. self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)
  11. else:
  12. self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)
  13. def forward(self, input):
  14. b_size, n_channel, height, width = input.shape
  15. squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2) #尺寸变小
  16. squeezed = squeezed.permute(0, 1, 3, 5, 2, 4) #[b,c,h,2,w,2]变成[b,c,2,2,h,w]
  17. out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2) #深拷贝重新创建out [b, c*4, h//2, w//2]
  18. logdet = 0
  19. for flow in self.flows:
  20. out, det = flow(out)
  21. logdet = logdet + det
  22. if self.split:#如果split的话,flow的out一半是z,另一半用来生成log_p指标,
  23. out, z_new = out.chunk(2, 1) #通道分块,dim=1分2块 6,6
  24. mean, log_sd = self.prior(out).chunk(2, 1) #6,6 mean, log_sd都是可学习的
  25. log_p = gaussian_log_p(z_new, mean, log_sd) #z_new的分布为高斯分布的概率的log 这就是z是高斯分布的关键
  26. log_p = log_p.view(b_size, -1).sum(1)#求和
  27. else:
  28. zero = torch.zeros_like(out)
  29. mean, log_sd = self.prior(zero).chunk(2, 1)
  30. log_p = gaussian_log_p(out, mean, log_sd)#out的分布为高斯分布的概率的log
  31. log_p = log_p.view(b_size, -1).sum(1)
  32. z_new = out
  33. return out, logdet, log_p, z_new
  34. def reverse(self, output, eps=None, reconstruct=False): #reverse的输入,如果是最后一层,output和eps都是z_list,其他层的话output为out,eps为z
  35. input = output
  36. if reconstruct: #是否重建
  37. if self.split:
  38. input = torch.cat([output, eps], 1) #如果split了,【out,z】
  39. else:
  40. input = eps #z
  41. else: #如果不需要重建
  42. if self.split:
  43. mean, log_sd = self.prior(input).chunk(2, 1)
  44. z = gaussian_sample(eps, mean, log_sd)
  45. input = torch.cat([output, z], 1)
  46. else:
  47. zero = torch.zeros_like(input)
  48. # zero = F.pad(zero, [1, 1, 1, 1], value=1)
  49. mean, log_sd = self.prior(zero).chunk(2, 1)
  50. z = gaussian_sample(eps, mean, log_sd)
  51. input = z
  52. for flow in self.flows[::-1]:
  53. input = flow.reverse(input)
  54. b_size, n_channel, height, width = input.shape
  55. unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
  56. unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
  57. unsqueezed = unsqueezed.contiguous().view(b_size, n_channel // 4, height * 2, width * 2)
  58. return unsqueezed

最后Glow模型:

  1. class Glow(nn.Module):
  2. def __init__(self, in_channel, n_flow, n_block, affine=True, conv_lu=True): #n_flow为K,n_block为L
  3. super().__init__()
  4. self.blocks = nn.ModuleList() #blocks层为图b的堆叠
  5. n_channel = in_channel
  6. for i in range(n_block - 1):
  7. self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
  8. n_channel *= 2 #最后一个Block通道*2
  9. self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))
  10. def forward(self, input):
  11. log_p_sum = 0
  12. logdet = 0
  13. out = input
  14. z_outs = [] #中间z
  15. for block in self.blocks:
  16. out, det, log_p, z_new = block(out) #循环 out
  17. z_outs.append(z_new)
  18. logdet = logdet + det #logdet求和
  19. if log_p is not None:
  20. log_p_sum = log_p_sum + log_p #log_p求和
  21. return log_p_sum, logdet, z_outs # 输出log_p和logdet,以及最后的z序列
  22. def reverse(self, z_list, reconstruct=False):
  23. for i, block in enumerate(self.blocks[::-1]):#最后一个block去掉
  24. if i == 0:
  25. input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
  26. else:
  27. input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)
  28. return input

train部分

  1. from tqdm import tqdm
  2. import numpy as np
  3. from PIL import Image
  4. from math import log, sqrt, pi
  5. import argparse
  6. import torch
  7. from torch import nn, optim
  8. from torch.autograd import Variable, grad
  9. from torch.utils.data import DataLoader
  10. import torch.utils.data
  11. from torchvision.datasets import CIFAR10
  12. from torchvision import datasets, transforms, utils
  13. from model import Glow
  14. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  15. parser = argparse.ArgumentParser(description="Glow trainer")
  16. parser.add_argument("--iter1", default=200000, type=int, help="maximum iterations") # 迭代周期
  17. parser.add_argument("--n_flow", default=32, type=int, help="number of flows in each block")
  18. parser.add_argument("--n_block", default=4, type=int, help="number of blocks")
  19. parser.add_argument("--no_lu", action="store_true", help="use plain convolution instead of LU decomposed version")
  20. parser.add_argument("--affine", action="store_true", help="use affine coupling instead of additive")
  21. parser.add_argument("--n_bits", default=5, type=int, help="number of bits")
  22. parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
  23. parser.add_argument("--temp", default=0.7, type=float, help="temperature of sampling")
  24. parser.add_argument("--n_sample", default=20, type=int, help="number of samples")
  25. def data_tr_1(x):
  26. x = x.resize((64, 64))
  27. x = np.array(x, dtype='float32') / 255
  28. x = (x - 0.5) / 0.5
  29. x = x.transpose((2, 0, 1))
  30. x = torch.from_numpy(x)
  31. return x
  32. def sample_data():
  33. transform = transforms.Compose(
  34. [
  35. transforms.Resize(64),
  36. transforms.CenterCrop(64),
  37. transforms.RandomHorizontalFlip(),
  38. transforms.ToTensor(),
  39. ]
  40. )
  41. dataset = CIFAR10('./data', train=True, transform=transform, download=True)
  42. loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
  43. #test_set = CIFAR10('./data', train=False, transform=transform, download=True)
  44. #test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
  45. #dataset = datasets.ImageFolder(path, transform=transform)
  46. #loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4)
  47. loader = iter(loader)
  48. while True:
  49. try:
  50. yield next(loader)
  51. except StopIteration:
  52. loader = DataLoader(
  53. dataset, shuffle=True, batch_size=64, num_workers=4
  54. )
  55. loader = iter(loader)
  56. yield next(loader)
  57. def calc_z_shapes(n_channel, input_size, n_block):
  58. '''
  59. 每一个block之后输出的z_shape
  60. input:(3,64,64)
  61. [(6, 32, 32), (12, 16, 16), (48, 8, 8)]
  62. '''
  63. z_shapes = []
  64. for i in range(n_block - 1):
  65. input_size //= 2 #size 两倍变小
  66. n_channel *= 2 # 通道两倍变大
  67. z_shapes.append((n_channel, input_size, input_size))
  68. input_size //= 2
  69. z_shapes.append((n_channel * 4, input_size, input_size))
  70. return z_shapes
  71. def calc_loss(log_p, logdet, image_size, n_bins):
  72. # log_p = calc_log_p([z_list])
  73. n_pixel = image_size * image_size * 3
  74. loss = -log(n_bins) * n_pixel
  75. loss = loss + logdet + log_p
  76. return (
  77. (-loss / (log(2) * n_pixel)).mean(),
  78. (log_p / (log(2) * n_pixel)).mean(),
  79. (logdet / (log(2) * n_pixel)).mean(),
  80. )
  81. def train(args, model, optimizer):
  82. dataset = iter(sample_data())
  83. n_bins = 2.0 ** args.n_bits # 2^5bit
  84. z_sample = [] #中间初始值z?
  85. z_shapes = calc_z_shapes(3, image_size, n_block)
  86. for z in z_shapes:
  87. z_new = torch.randn(n_sample, *z) * temp # n_sample为batch
  88. z_sample.append(z_new.to(device)) #[-2, 3]左右
  89. with tqdm(range(iter1)) as pbar:
  90. for i in pbar:
  91. image, _ = next(dataset)
  92. image = image.to(device)
  93. image = image * 255 # [0, 255]
  94. if args.n_bits < 8: #5
  95. image = torch.floor(image / 2 ** (8 - args.n_bits)) #[0,31]
  96. image = image / n_bins - 0.5 #[-0.5, 2.6]
  97. if i == 0:
  98. with torch.no_grad():
  99. log_p, logdet, _ = model.module(image + torch.rand_like(image) / n_bins)
  100. continue
  101. else:
  102. log_p, logdet, _ = model(image + torch.rand_like(image) / n_bins) #加噪声
  103. logdet = logdet.mean()
  104. loss, log_p, log_det = calc_loss(log_p, logdet, image_size, n_bins)
  105. model.zero_grad()
  106. loss.backward()
  107. # warmup_lr = args.lr * min(1, i * batch_size / (50000 * 10))
  108. warmup_lr = args.lr
  109. optimizer.param_groups[0]["lr"] = warmup_lr
  110. optimizer.step()
  111. pbar.set_description(
  112. f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; logdet: {log_det.item():.5f}; lr: {warmup_lr:.7f}"
  113. )
  114. if i % 100 == 0:
  115. with torch.no_grad():
  116. utils.save_image(
  117. model_single.reverse(z_sample).cpu().data,
  118. f"sample/{str(i + 1).zfill(6)}.png",
  119. normalize=True,
  120. nrow=10,
  121. range=(-0.5, 0.5),
  122. )
  123. if i % 10000 == 0:
  124. torch.save(
  125. model.state_dict(), f"checkpoint/model_{str(i + 1).zfill(6)}.pt"
  126. )
  127. torch.save(
  128. optimizer.state_dict(), f"checkpoint/optim_{str(i + 1).zfill(6)}.pt"
  129. )
  130. if __name__ == "__main__":
  131. args = parser.parse_args()
  132. print(args)
  133. image_size = 64
  134. n_flow = args.n_flow
  135. n_block = args.n_block
  136. n_sample = args.n_sample
  137. temp = args.temp
  138. iter1 = args.iter1
  139. model_single = Glow(3, n_flow, n_block, affine=args.affine, conv_lu=not args.no_lu)
  140. model = nn.DataParallel(model_single)
  141. # model = model_single
  142. model = model.to(device)
  143. optimizer = optim.Adam(model.parameters(), lr=1e-4)
  144. train(args, model, optimizer)

数据集我选用的是cifar10.batch size设置为64,其余都是原本的默认值。loss为log_p与logdet相加后取负,也就是目标为最大化log_p, 使输出逐渐为高斯分布,logdet使得可逆后回去Image。

最后生成的结果

由于我的数据并没有分类存放,导致学习到的特征比较混乱,而且我也只是跑通代码理解原理而已。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/680662
推荐阅读
相关标签
  

闽ICP备14008679号