赞
踩
这一期将介绍另一种生成模型—玻尔兹曼机,虽然它现在已经较少被提及和使用,但其对概率密度函数的处理方式能加深我们对生成模型的理解。
作者&编辑 | 小米粥
1 玻尔兹曼机
玻尔兹曼机属于另一种显式概率模型,它是一种基于能量的模型。训练玻尔兹曼机同样需要基于极大似然的思想,但在计算极大似然的梯度时,运用了一种不同于变分法的近似算法。玻尔兹曼机已经较少引起关注,故在此我们只简述。
在能量模型中,通常将样本的概率p(x)建模成如下形式:
其中,Z为配分函数。为了增强模型的表达能力,通常会在可见变量h的基础上增加隐变量v,以最简单的受限玻尔兹曼机RBM为例,RBM中的可见变量和隐变量均为二值离散随机变量(当然也可推广至实值)。它定义了一个无向概率图模型,并且为二分图,其中可见变量v组成一部分,隐藏变量h组成另一部分,可见变量之间不存在连接,隐藏变量之间也不存在连接(“受限”即来源于此),可见变量与隐藏变量之间实行全连接,结构如下图所示:
在RBM中,可见变量和隐藏变量的联合概率分布由能量函数给出,即
其中能量函数的表达式为
配分函数Z可写为
考虑到二分图的特殊结构,发现在隐藏变量已知时,可见变量之间彼此独立;当可见变量已知时,隐藏变量之间也彼此独立,即有
以及
进一步地,可得到离散概率的具体表达式:
为了使得RBM与能量模型有一致的表达式,定义可见变量v的自由能f(v)为
其中hi为第i个隐藏变量,此时可见变量的概率为
配分函数Z。使用极大似然法训练RBM模型时,需要计算似然函数的梯度,记模型的的参数为θ ,则
可以看出,RBM明确定义了可见变量的概率密度函数,但它并不易求解,因为计算配分函数 Z 需要对所有的可见变量v和隐藏变量h求积分,所以对数似然log p(v)也无法直接求解,故无法直接使用极大似然的思想训练模型。但是,若跳过对数似然函数的求解而直接求解对数似然函数的梯度,也可完成模型的训练。对于其中的权值、偏置参数有:
分析其梯度表达式,其中不易计算的部分在于对可见变量v的期望的计算。RBM通过采样的方法来对梯度进行近似,然后使用近似得到的梯度进行权值更新。为了采样得到可见变量v,可构建一个马尔科夫链并使其最终收敛到p(v),即马尔科夫链的平稳分布为p(v)。初始随机给定样本,迭代运行足够次数后达到平稳分布,这时可根据转移矩阵从模型分布p(v)连续采样得到样本。我们可使用吉布斯采样方法完成该过程,由于两部分变量的独立性,当固定可见变量(或隐藏变量)时,隐藏变量(可见变量)的分布分别为h(n+1) ~sigmoid(WTv(n)+c)和 v(n+1)~sigmoid(Wv(n+1)+b) ,即先采样得到隐藏变量,再采样得到可见变量,这样,我们便可以使用“随机最大似然”完成生成模型的训练了。
玻尔兹曼机依赖马尔可夫链来训练模型或者使用模型生成样本,但是这种技术现在已经很少被使用了,很可能是因为马尔可夫链近似技术不能被适用于像ImageNet的生成问题。并且,即便是马尔可夫链方法可以很好的用于训练,但是使用一个基于马尔可夫链的模型生成样本是需要花费很大计算代价。
2 玻尔兹曼机代码
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
batch_size = 64
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor()])),batch_size=batch_size)
test_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=False, transform=transforms.Compose([ transforms.ToTensor()])),batch_size=batch_size)
class RBM(nn.Module):
def __init__(self, n_vis=784, n_hin=500, k=5):
super(RBM, self).__init__()
self.W = nn.Parameter(torch.randn(n_hin, n_vis) * 1e-2)
self.v_bias = nn.Parameter(torch.zeros(n_vis))
self.h_bias = nn.Parameter(torch.zeros(n_hin))
self.k = k
def sample_from_p(self, p):
return F.relu(torch.sign(p - Variable(torch.rand(p.size()))))
def v_to_h(self, v):
p_h = F.sigmoid(F.linear(v, self.W, self.h_bias))
sample_h = self.sample_from_p(p_h)
return p_h, sample_h
def h_to_v(self, h):
p_v = F.sigmoid(F.linear(h, self.W.t(), self.v_bias))
sample_v = self.sample_from_p(p_v)
return p_v, sample_v
def forward(self, v):
pre_h1, h1 = self.v_to_h(v)
h_ = h1
for _ in range(self.k):
pre_v_, v_ = self.h_to_v(h_)
pre_h_, h_ = self.v_to_h(v_)
return v, v_
def free_energy(self, v):
vbias_term = v.mv(self.v_bias)
wx_b = F.linear(v, self.W, self.h_bias)
hidden_term = wx_b.exp().add(1).log().sum(1)
return (-hidden_term - vbias_term).mean()
rbm = RBM(k=1)
train_op = optim.SGD(rbm.parameters(),0.1)
for epoch in range(10):
loss_ = []
for _, (data, target) in enumerate(train_loader):
data = Variable(data.view(-1, 784))
sample_data = data.bernoulli()
print(sample_data[0])
v, v1 = rbm(sample_data)
loss = rbm.free_energy(v) - rbm.free_energy(v1)
#loss_.append(loss.data[0])
train_op.zero_grad()
loss.backward()
train_op.step()
print np.mean(loss_)
show_adn_save("real",make_grid(v.view(32,1,28,28).data))
show_adn_save("generate",make_grid(v1.view(32,1,28,28).data))
def show_adn_save(file_name,img):
npimg = np.transpose(img.numpy(),(1,2,0))
f = "./%s.png" % file_name
plt.imshow(npimg)
plt.imsave(f, npimg)
[1] 伊恩·古德费洛, 约书亚·本吉奥, 亚伦·库维尔. 深度学习
[2]李航. 统计机器学习
总结
本期带大家学习了玻尔兹曼机,至此几种显式生成模型都介绍完了,除了显式模型就是大家非常熟悉的隐式生成模型了,其主要的代表是GAN,我们生态已经介绍过许多内容,大家可以去学习。
个人知乎,欢迎关注
GAN群
有三AI建立了一个GAN群,便于有志者相互交流。感兴趣的同学也可以微信搜索xiaozhouguo94,备注“加入有三-GAN群”。
更多GAN的学习
知识星球是有三AI的付费内容社区,里面包超过100种经典GAN模型的解读,了解详细请阅读以下文章:
【杂谈】有三AI知识星球指导手册出炉!和公众号相比又有哪些内容?
有三AI秋季划GAN学习小组,可长期跟随有三学习GAN相关的内容,并获得及时指导,了解详细请阅读以下文章:
【杂谈】如何让2020年秋招CV项目能力更加硬核,可深入学习有三秋季划4大领域32个方向
转载文章请后台联系
侵权必究
往期精选
【GAN优化】GAN优化专栏上线,首谈生成模型与GAN基础
【GAN的优化】从KL和JS散度到fGAN
【GAN优化】详解对偶与WGAN
【GAN优化】详解SNGAN(频谱归一化GAN)
【GAN优化】一览IPM框架下的各种GAN
【GAN优化】GAN优化专栏栏主小米粥自述,脚踏实地,莫问前程
【GAN优化】GAN训练的几个问题
【GAN优化】GAN训练的小技巧
【GAN优化】从动力学视角看GAN是一种什么感觉?
【GAN优化】小批量判别器如何解决模式崩溃问题
【GAN优化】长文综述解读如何定量评价生成对抗网络(GAN)
【技术综述】有三说GANs(上)
【模型解读】历数GAN的5大基本结构
【百战GAN】如何使用GAN拯救你的低分辨率老照片
【百战GAN】二次元宅们,给自己做一个专属动漫头像可好!
【百战GAN】羡慕别人的美妆?那就用GAN复制粘贴过来
【百战GAN】GAN也可以拿来做图像分割,看起来效果还不错?
【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务
【百战GAN】自动增强图像对比度和颜色美感,GAN如何做?
【直播回放】80分钟剖析GAN如何从各个方向提升图像的质量
【直播回放】60分钟剖析GAN如何用于人脸的各种算法
【直播回放】60分钟了解各类图像和视频生成GAN结构
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。