赞
踩
.# 概率基础——EM算法
在概率统计领域,期望最大化(Expectation Maximization,EM)算法是一种常用的迭代优化方法,特别适用于含有隐变量的概率模型参数估计。EM算法的原理相对复杂,但其应用广泛,尤其在机器学习和数据挖掘领域中发挥了重要作用。本篇博客将介绍EM算法的理论基础、公式推导、应用案例以及Python实现,并通过例子解释其有效性。
在极大似然估计中,我们通常通过求解最大化似然函数的方法来估计模型参数。然而,当模型含有隐变量时,直接求解似然函数可能变得困难甚至不可行。EM算法通过引入隐变量的期望(Expectation)步骤和参数最大化(Maximization)步骤,以迭代的方式优化模型参数。
EM算法的核心思想在于通过迭代优化,逐步提高似然函数的值,从而得到模型参数的估计值。
在E步骤中,我们需要计算给定观测数据下隐变量的后验概率。设观测数据为 X X X,隐变量为 Z Z Z,模型参数为 θ θ θ,我们要求解的是:
Q ( θ ∣ θ ( t ) ) = E Z [ l o g P ( X , Z ∣ θ ) ∣ X , θ ( t ) ] Q(\theta|\theta^{(t)}) = E_Z[logP(X,Z|\theta)|X, \theta^{(t)}] Q(θ∣θ(t))=EZ[logP(X,Z∣θ)∣X,θ(t)]
其中, θ ( t ) \theta^{(t)} θ(t)表示第t次迭代后的参数估计值。利用贝叶斯公式,我们有:
Q ( θ ∣ θ ( t ) ) = ∑ Z l o g P ( X , Z ∣ θ ) ⋅ P ( Z ∣ X , θ ( t ) ) Q(\theta|\theta^{(t)}) = \sum_Z logP(X,Z|\theta) \cdot P(Z|X, \theta^{(t)}) Q(θ∣θ(t))=Z∑logP(X,Z∣θ)⋅P(Z∣X,θ(t))
在M步骤中,我们需要更新模型参数,使得似然函数达到最大化。我们的目标是求解:
θ ( t + 1 ) = a r g m a x θ Q ( θ ∣ θ ( t ) ) \theta^{(t+1)} = argmax_\theta Q(\theta|\theta^{(t)}) θ(t+1)=argmaxθQ(θ∣θ(t))
通过求解上述优化问题,得到新的参数估计值。
EM算法的有效性可以从多个角度解释:
假设我们观测到一组数据,但无法确定其真实的概率分布。我们假设这些数据由多个高斯分布混合而成,每个高斯分布对应一个隐变量,表示数据点来自于哪个分布。我们可以使用EM算法估计GMM的参数,包括每个分布的均值、方差和混合系数。
下面通过一个简单的例子,使用Python实现EM算法对GMM模型进行参数估计,并绘制出分布图像。
import numpy as np import matplotlib.pyplot as plt from scipy.stats import norm # 生成数据 np.random.seed(0) data = np.concatenate([np.random.normal(0, 1, 1000), np.random.normal(5, 1, 1000)]) # 初始化参数 mu1, sigma1, weight1 = -1, 1, 0.5 mu2, sigma2, weight2 = 1, 1, 0.5 # EM算法 def expectation(data, mu1, sigma1, weight1, mu2, sigma2, weight2): p1 = norm.pdf(data, mu1, sigma1) * weight1 p2 = norm.pdf(data, mu2, sigma2) * weight2 total = p1 + p2 gamma1 = p1 / total gamma2 = p2 / total return gamma1, gamma2 def maximization(data, gamma1, gamma2): mu1 = np.sum(gamma1 * data) / np.sum(gamma1) mu2 = np.sum(gamma2 * data) / np.sum(gamma2) sigma1 = np.sqrt(np.sum(gamma1 * (data - mu1) ** 2) / np.sum(gamma1)) sigma2 = np.sqrt(np.sum(gamma2 * (data - mu2) ** 2) / np.sum(gamma2)) weight1 = np.mean(gamma1) weight2 = np.mean(gamma2) return mu1, sigma1, weight1, mu2, sigma2, weight2 def EM(data, mu1, sigma1, weight1, mu2, sigma2, weight2, epochs=100, tol=1e-6): for _ in range(epochs): gamma1, gamma2 = expectation(data, mu1, sigma1, weight1, mu2, sigma2, weight2) mu1_new, sigma1_new, weight1_new, mu2_new, sigma2_new, weight2_new = maximization(data, gamma1, gamma2) if np.abs(mu1_new - mu1) < tol and np.abs(sigma1_new - sigma1) < tol and \ np.abs(weight1_new - weight1) < tol and np.abs(mu2_new - mu2) < tol and \ np.abs(sigma2_new - sigma2) < tol and np.abs(weight2_new - weight2) < tol: break mu1, sigma1, weight1, mu2, sigma2, weight2 = mu1_new, sigma1_new, weight1_new, mu2_new, sigma2_new, weight2_new return mu1, sigma1, weight1, mu2, sigma2, weight2 mu1, sigma1, weight1, mu2, sigma2, weight2 = EM(data, mu1, sigma1, weight1, mu2, sigma2, weight2) # 绘制数据分布及估计的高斯分布 x = np.linspace(-5, 10, 1000) plt.hist(data, bins=50, density=True, alpha=0.5, color='gray') plt.plot(x, norm.pdf(x, mu1, sigma1) * weight1, label='Gaussian 1', color='blue') plt.plot(x, norm.pdf(x, mu2, sigma2) * weight2, label='Gaussian 2', color='red') plt.legend() plt.title('Gaussian Mixture Model') plt.xlabel('x') plt.ylabel('Density') plt.show()
上述代码实现了一个简单的GMM模型的EM算法。首先,我们生成了一组模拟数据,然后通过EM算法估计数据分布的参数,并绘制出数据分布及估计的高斯分布图像。
EM算法作为一种经典的迭代优化算法,在概率统计领域有着重要的应用。通过本文的介绍,我们了解了EM算法的理论基础、公式推导、有效性以及应用案例,并通过Python代码实现了对GMM模型的参数估计。希望本文对您理解EM算法有所帮助。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。