赞
踩
https://arxiv.org/pdf/1611.01144.pdf%20http://arxiv.org/abs/1611.01144.pdf
考虑离散变量 x x x ,如果已知其分布向量 π = { π 1 , … , π N } \pi=\{\pi_1,\dots,\pi_N\} π={π1,…,πN} ,则得到 x x x 的采样的一个简单采样方法是:
x π = o n e _ h o t ( arg max i ( π i ) ) x_{\pi}=one\_hot (\argmax_i(\pi_i)) xπ=one_hot(iargmax(πi))
根据 s o f t m a x softmax softmax 函数,其中 arg max \arg \max argmax 这一操作取得 π n \pi_n πn 的概率为:
P ( π n ) = e π n ∑ j = 1 N e π j P(\pi_n)=\frac{e^{\pi_{n} }}{\sum_{j=1}^{N} e^{\pi_{j} }} P(πn)=∑j=1Neπjeπn
当我们在BP神经网络中,需要让采样结果可导的时候,这样简单的采样就行不通了。
原因在于 arg max \arg \max argmax 这一操作是不可导的,即并没有一个表达式可以映射 π \pi π 到 z z z 上。
该技巧应用甚广,如深度学习中的各种 GAN、强化学习中的 A2C 和 MADDPG 算法等等。
只要涉及在离散分布上运用重参数技巧时(re-parameterization),都可以试试 Gumbel-Softmax Trick。
一般来说,对于 N N N 维概率向量 π \pi π,我们可以通过添加随机 Gumbel 噪声 G i G_i Gi 再取样:
x π = arg max i ( ln ( π i ) + G i ) x_{\pi}=\argmax_i \left(\ln \left(\pi_{i}\right)+G_{i}\right) xπ=iargmax(ln(πi)+Gi)
其中 G i G_i Gi 是独立同分布的标准 Gumbel 分布的随机变量。
我们重新看一下 Gumbel 分布,Gumbel 分布是一种极值型分布,它的概率密度函数(PDF)为:
f ( x ; μ , β ) = e − z − e − z , z = x − μ β f(x ; \mu, \beta)=e^{-z-e^{-z}}, z=\frac{x-\mu}{\beta} f(x;μ,β)=e−z−e−z,z=βx−μ
公式中, μ \mu μ 是位置系数, β \beta β 是尺度系数,标准 Gumbel 分布中有: μ = 0 \mu=0 μ=0, β = 1 \beta=1 β=1。
相应的,Gmubel 分布的累积密度函数(CDF)为:
F ( x ; μ , β ) = e − e − ( x − μ ) / β F(x ; \mu, \beta)=e^{-e^{-(x-\mu) / \beta}} F(x;μ,β)=e−e−(x−μ)/β
并且我们易得它的反函数:
F − 1 ( y ; μ , β ) = μ − β ln ( − ln ( y ) ) F^{-1}(y ; \mu, \beta)=\mu-\beta \ln (-\ln (y)) F−1(y;μ,β)=μ−βln(−ln(y))
这样我们就可以通过从均匀分布中求逆得到 G i G_i Gi:
G i = − ln ( − ln ( U i ) ) , U i ∼ U ( 0 , 1 ) G_{i}=-\ln \left(-\ln \left(U_{i}\right)\right), U_{i} \sim U(0,1) Gi=−ln(−ln(Ui)),Ui∼U(0,1)
这就是 Gmubel-Max trick。
由于上述算法中 arg max \arg \max argmax 这一操作仍是不可导的,因此我们可以用两种方式来让该操作可导。
一种方式是使用Straight-Through Estimator 思想(例如 VQ-VAE 中使用的),重新设计采样为:
t π = s o f t m a x ( ln ( π i ) + G i ) t_\pi=softmax(\ln \left(\pi_{i}\right)+G_{i}) tπ=softmax(ln(πi)+Gi)
z π = o n e _ h o t ( arg max ( t π ) ) z_{\pi}=one\_hot(\arg \max \left(t_\pi\right)) zπ=one_hot(argmax(tπ))
x π = t π + s g [ z π − t π ] x_\pi = t_\pi + sg[z_\pi-t_\pi] xπ=tπ+sg[zπ−tπ]
其中 s g sg sg 是 stop gradient 的意思,这样向前传播的时候使用的是采样的 z π z_{\pi} zπ,而反向传播则是使用 t π t_\pi tπ。
但是 Gumbel Softmax 的作者指出并证明,直接用 s o f t m a x softmax softmax 函数来代替量化过程也是可行的,即:
x π = s o f t m a x ( ln ( π i ) + G i ) x_{\pi}=softmax \left(\ln \left(\pi_{i}\right)+G_{i}\right) xπ=softmax(ln(πi)+Gi)
具体操作为:
x i = e z i / τ ∑ j = 1 N e z j / τ x_i=\frac{e^{z_{i} / \tau}}{\sum_{j=1}^{N} e^{z_{j} / \tau}} xi=∑j=1Nezj/τezi/τ
前三步的目标是让新的随机变量 z z z 与原随机变量 π \pi π 相同,只需要证明取到 z n z_n zn 的概率跟取到 π n \pi_n πn 的概率相同,第四步则是使用温度参数 τ \tau τ 来控制采样结果的分布倾向:
下面我们来证明。
证明取到 z n z_n zn 的概率跟取到 π n \pi_n πn 的概率相同可以写为:
P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = P ( π n ) P\left(z_{n} \geq z_{n^{\prime}} ; \forall n^{\prime} \neq n \mid\left\{\pi_{n^{\prime}}\right\}_{n^{\prime}=1}^{N}\right)=P(\pi_n) P(zn≥zn′;∀n′=n∣{πn′}n′=1N)=P(πn)
也就是 z n z_{n} zn 比其他所有 z n ′ z_{n^{\prime}} zn′ 都大的概率为 P ( π n ) P(\pi_n) P(πn)。
根据条件累积概率分布函数,我们可以得到:
P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = ∏ n ′ ≠ n P ( z n ≥ z n ′ ) P\left(z_{n} \geq z_{n^{\prime}} ; \forall n^{\prime} \neq n \mid\left\{\pi_{n^{\prime}}\right\}_{n^{\prime}=1}^{N}\right)=\prod_{n^{\prime} \neq n}P(z_n\geq z_n^{\prime}) P(zn≥zn′;∀n′=n∣{πn′}n′=1N)=n′=n∏P(zn≥zn′)
注意到, z n = π n + G n z_n=\pi_n+G_n zn=πn+Gn,并且 G n G_n Gn 服从 μ = 0 \mu=0 μ=0, β = 1 \beta=1 β=1 的标准 Gumbel 分布,那么将 π n \pi_n πn 看作常数时, z n z_n zn 服从 μ = π n \mu=\pi_n μ=πn, β = 1 \beta=1 β=1 的标准 Gumbel 分布,它的 CDF 为:
F z n ( x ) = e − e − ( x − π n ) F_{z_n}(x)=e^{-e^{-(x-\pi_n)}} Fzn(x)=e−e−(x−πn)
也就是:
F z n ′ ( x ) = e − e − ( x − π n ′ ) F_{z_n^{\prime}}(x)=e^{-e^{-(x-\pi_{n^{\prime}})}} Fzn′(x)=e−e−(x−πn′)
那么根据 CDF 的定义,我们可得:
P ( z n ≥ z n ′ ) = P ( z n ′ ≤ z n ) = F z n ′ ( z n ) = e − e − ( z n − π n ′ ) P(z_n\geq z_n^{\prime})=P(z_n^{\prime}\leq z_n)=F_{z_n^{\prime}}(z_n)=e^{-e^{-(z_n-\pi_{n^{\prime}})}} P(zn≥zn′)=P(zn′≤zn)=Fzn′(zn)=e−e−(zn−πn′)
即:
P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = ∏ n ′ ≠ n e − e − ( z n − π n ′ ) P\left(z_{n} \geq z_{n^{\prime}} ; \forall n^{\prime} \neq n \mid\left\{\pi_{n^{\prime}}\right\}_{n^{\prime}=1}^{N}\right)=\prod_{n^{\prime} \neq n}e^{-e^{-(z_n-\pi_{n^{\prime}})}} P(zn≥zn′;∀n′=n∣{πn′}n′=1N)=n′=n∏e−e−(zn−πn′)
同时我们可得 z n z_n zn 分布的 CDF 为:
f z n ( x ) = e − ( x − π n ) − e − ( x − π n ) f_{z_n}(x)=e^{-(x-\pi_n)-e^{-(x-\pi_n)}} fzn(x)=e−(x−πn)−e−(x−πn)
对 z n z_n zn 求积分可得边缘累积概率分布函数:
P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = ∫ P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) ⋅ f z n ( z n ) d z n P(zn≥zn′;∀n′≠n∣{πn′}Nn′=1)=∫P(zn≥zn′;∀n′≠n∣{πn′}Nn′=1)⋅fzn(zn)dzn P(zn≥zn′=;∀n′=n∣{πn′}n′=1N)∫P(zn≥zn′;∀n′=n∣{πn′}n′=1N)⋅fzn(zn)dzn
带入 CDF 可得:
P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = ∫ ∏ n ′ ≠ n e − e − ( z n − π n ′ ) ⋅ e − ( z n − π n ) − e − ( z n − π n ) d z n P(zn≥zn′;∀n′≠n∣{πn′}Nn′=1)=∫∏n′≠ne−e−(zn−πn′)⋅e−(zn−πn)−e−(zn−πn)dzn P(zn≥zn′=;∀n′=n∣{πn′}n′=1N)∫n′=n∏e−e−(zn−πn′)⋅e−(zn−πn)−e−(zn−πn)dzn
化简可得:
P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = ∫ ∏ n ′ ≠ n e − e − ( z n − π n ′ ) ⋅ e − ( z n − π n ) − e − ( z n − π n ) d z n = ∫ e − ∑ n ′ ≠ n e − ( z n − π n ) − ( z n − π n ) − e − ( z n − π n ) d z n = ∫ e − ∑ n ′ = 1 N e − ( z n − π n ′ ) − ( z n − π n ) d z n = ∫ e − ( ∑ n ′ = 1 N e π n ′ ) e − z n − z n + π n d z n = ∫ e − e − z n + ln ( ∑ n ′ = 1 N e π n ′ ) − z n + π n d z n = ∫ e − e − ( z n − ln ( ∑ n ′ = 1 N e π n ′ ) ) − ( z n − ln ( ∑ n ′ = 1 N e π n ′ ) ) − ln ( ∑ n ′ = 1 N e π n ′ ) + π n d z n = e − ln ( ∑ n ′ = 1 N e π n ′ ) + π n ∫ e − e − ( z n − ln ( ∑ n ′ = 1 N e π n ′ ) ) − ( z n − ln ( ∑ n ′ = 1 N e π n ′ ) ) d z n = e π n ∑ n ′ = 1 N e π n ′ ∫ e − e − ( z n − ln ( ∑ n ′ = 1 N e π n ′ ) ) − ( z n − ln ( ∑ n ′ = 1 N e π n ′ ) ) d z n = e π n ∑ n ′ = 1 N e π n ′ ∫ e − ( z n − ln ( ∑ n ′ = 1 N e π n ′ ) ) − e − ( z n − ln ( ∑ n ′ = 1 N e π n ′ ) ) d z n P(zn≥zn′;∀n′≠n∣{πn′}Nn′=1)=∫∏n′≠ne−e−(zn−πn′)⋅e−(zn−πn)−e−(zn−πn)dzn=∫e−∑n′≠ne−(zn−πn)−(zn−πn)−e−(zn−πn)dzn=∫e−∑Nn′=1e−(zn−πn′)−(zn−πn)dzn=∫e−(∑Nn′=1eπn′)e−zn−zn+πndzn=∫e−e−zn+ln(∑Nn′=1eπn′)−zn+πndzn=∫e−e−(zn−ln(∑Nn′=1eπn′))−(zn−ln(∑Nn′=1eπn′))−ln(∑Nn′=1eπn′)+πndzn=e−ln(∑Nn′=1eπn′)+πn∫e−e−(zn−ln(∑Nn′=1eπn′))−(zn−ln(∑Nn′=1eπn′))dzn=eπn∑Nn′=1eπn′∫e−e−(zn−ln(∑Nn′=1eπn′))−(zn−ln(∑Nn′=1eπn′))dzn=eπn∑Nn′=1eπn′∫e−(zn−ln(∑Nn′=1eπn′))−e−(zn−ln(∑Nn′=1eπn′))dzn P(zn≥zn′;∀n′=n∣{πn′}n′=1N)=∫∏n′=ne−e−(zn−πn′)⋅e−(zn−πn)−e−(zn−πn)dzn=∫e−∑n′=ne−(zn−πn)−(zn−πn)−e−(zn−πn)dzn=∫e−∑n′=1Ne−(zn−πn′)−(zn−πn)dzn=∫e−(∑n′=1Neπn′)e−zn−zn+πndzn=∫e−e−zn+ln(∑n′=1Neπn′)−zn+πndzn=∫e−e−(zn−ln(∑n′=1Neπn′))−(zn−ln(∑n′=1Neπn′))−ln(∑n′=1Neπn′)+πndzn=e−ln(∑n′=1Neπn′)+πn∫e−e−(zn−ln(∑n′=1Neπn′))−(zn−ln(∑n′=1Neπn′))dzn=∑n′=1Neπn′eπn∫e−e−(zn−ln(∑n′=1Neπn′))−(zn−ln(∑n′=1Neπn′))dzn=∑n′=1Neπn′eπn∫e−(zn−ln(∑n′=1Neπn′))−e−(zn−ln(∑n′=1Neπn′))dzn
注意到积分内为符合 μ = ln ( ∑ k ′ = 1 K e x k ′ ) \mu=\ln \left(\sum_{k^{\prime}=1}^{K} e^{x_{k^{\prime}}}\right) μ=ln(∑k′=1Kexk′) 的 Gumbel 分布,所以积分的结果为 1 1 1,即:
P ( z n ≥ z n ′ ; ∀ n ′ ≠ n ∣ { π n ′ } n ′ = 1 N ) = e π n ∑ n ′ = 1 N e π n ′ = P ( π n ) P(z_{n} \geq z_{n^{\prime}} ; \forall n^{\prime} \neq n \mid\{\pi_{n^{\prime}}\}_{n^{\prime}=1}^{N})=\frac{e^{\pi_{n} }}{\sum_{n^{\prime}=1}^{N} e^{\pi_{n^{\prime}} }}=P(\pi_n) P(zn≥zn′;∀n′=n∣{πn′}n′=1N)=∑n′=1Neπn′eπn=P(πn)
在 pytorch 中已经给出其实现:
def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor: if has_torch_function_unary(logits): return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) if eps != 1e-10: warnings.warn("`eps` parameter is deprecated and has no effect.") gumbels = ( -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() ) # ~Gumbel(0,1) gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) y_soft = gumbels.softmax(dim) if hard: # Straight through. index = y_soft.max(dim, keepdim=True)[1] y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) ret = y_hard - y_soft.detach() + y_soft else: # Reparametrization trick. ret = y_soft return ret
注意这里生成的 G i G_i Gi (对应代码中 gumbels)并没有采用论文中的方式,而是用了其等价方式,即直接从指数分布中采样(tensor.exponential_()),然后再取负对数。
这是因为 − ln U i , U ∼ U ( 0 , 1 ) -\ln U_i,U\sim U(0,1) −lnUi,U∼U(0,1) 的分布符合指数分布,证明如下:
设 Y = − l n ( X ) Y=-ln(X) Y=−ln(X),且 X ∼ U ( 0 , 1 ) X\sim U(0,1) X∼U(0,1),有:
F y ( Y ) = P ( Y < y ) = P ( − ln x < y ) = P ( x > e − y ) = 1 − P ( x ≤ e − y ) = 1 − F x ( e − y ) = 1 − e − y Fy(Y)=P(Y<y)=P(−lnx<y)=P(x>e−y)=1−P(x≤e−y)=1−Fx(e−y)=1−e−y Fy(Y)=P(Y<y)=P(−lnx<y)=P(x>e−y)=1−P(x≤e−y)=1−Fx(e−y)=1−e−y
对其求导可得,其概率密度函数为:
f Y ( y ) = e − y f_Y(y)=e^{-y} fY(y)=e−y
刚好为指数分布。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。