当前位置:   article > 正文

高斯混合模型(GMM)及EM算法---MATLAB程序_混合高斯gmm em程序matlab

混合高斯gmm em程序matlab

        大家喜欢的话记得关注、点赞、收藏哦~

        高斯混合模型(Gaussian Mixed Model,GMM)是由多个高斯分布函数组成的线性组合。理论上,GMM可以拟合出任意类型的分布,通常用于解决同一集合下的数据包含多个不同分布的情况。---转自:http://t.csdn.cn/SPEcN

        设随机变量 X服从混合高斯分布(Mixture Gaussian Distribution),其概率密度函数如下所示:

         其中,K为分量数,若用两个二维高斯分布来表示,则有分量数K = 2 ;\alpha _{k}为混合系数(mixture coefficient),可以看作每个分量N\left ( x,\mu _{k},\Sigma _{k}\right )的权重,也可以看作每个分量被选中的概率,混合系数\alpha _{k}满足如下关系:

         读到这里,想必各位心中都会自然而然地产生一个问题:各分量的分布参数和混合系数要怎么确定?这里就需要引入最大期望算法(Expectation-maximization algorithm,EM)对各参数进行估计!!!(核心)

        为了确保文章的完整性,GMM的参数估计过程引用了博客http://t.csdn.cn/SPEcN的部分内容:

        引入一个新的K维随机变量z(0-1变量),P\left ( z_{k}=1 \right )=\alpha _{k}表示第k个分量被选中的概率为\alpha _{k}z_{k}满足如下关系:

         假设z_{k}之间独立同分布,我们可以写出z的联合概率分布:

         被选中分量的概率密度函数可以写作如下形式:

         进而,

         上面分别给出了P\left ( z \right )P\left ( x\mid z \right ),根据全概率公式,可以求出P\left ( x \right )

         其中,\sum_{z}表示对z的K种情况进行累加。上式与GMM的定义式具有相同的形式,但上式引入了新的变量z,通常称为隐含变量(latent variable)『隐含』的意义是:我们知道数据可以分成两类,但是随机抽取一个数据点,我们不知道这个数据点属于第一类还是第二类,有多少概率属于第一类,有多少概率属于第二类,它的归属我们观察不到,因此引入一个隐含变量z来描述这个现象。(转自:http://t.csdn.cn/SPEcN,我感觉这段话写得特别好)

        EM算法即可用于求解具有隐含变量的参数估计问题。

         在贝叶斯的思想下,P\left ( z \right )是先验概率, P\left ( x\mid z \right )是似然概率,P\left ( z\mid x \right )是后验概率。为方便使用EM算法估计GMM模型的参数,需要推导后验概率P\left ( z\mid x \right )的表达式,如下所示:

      至此,我们完成了所有的前期准备工作,接下来引入EM算法以估计GMM模型的各个参数。

     EM算法是在概率模型中寻找参数最大似然估计或者最大后验估计的算法,其中概率模型依赖于无法观测的隐性变量。

        EM算法经过两个步骤交替进行计算:

        第一步:计算期望(E步),利用对隐藏变量的现有估计值,计算其最大似然估计值;
        第二步:最大化(M步),最大化在E步上求得的最大似然值来计算参数的值。M步上找到的参数估计值被用于下一个E步计算中,这个过程不断交替进行。(转自:http://t.csdn.cn/gTdx1

        GMM模型中待估计的参数有\alpha ,\mu ,\Sigma,将定义式改写为以下形式

         写出上式的极大似然函数:

         对各参数\alpha ,\mu ,\Sigma求偏导并令导数为0可得(推导过程详见 http://t.csdn.cn/iRa9n )

         其中,

         算法流程如下:

        1) 确定分量数K,利用k-means算法对样本进行划分,计算各分量对应样本的均值和协方差作为\mu _{k}\Sigma_{k}的初值,\alpha _{k}取1/K,即各分量的权重相等

        2) E步:根据当前的\alpha _{k},\mu _{k}\Sigma_{k}计算后验概率\gamma\left ( z_{nk}\right )

        3) M步:根据步骤2)中的 \gamma\left ( z_{nk}\right )重新计算 \alpha _{k},\mu _{k}\Sigma_{k} 

         其中,

        4) 计算极大似然函数

         5) 检查极大似然函数是否收敛,若收敛,输出当前的\alpha _{k},\mu _{k}\Sigma_{k} ,否则返回步骤2)。

        本文的MATLAB代码如下:

        一维数据:

  1. clear all
  2. clc
  3. close all
  4. samples = [normrnd(5,2,[1000,1]);normrnd(0,1,[1000,1])];%这里的sigma是标准差standard deviation
  5. D=size(samples,2);
  6. N=size(samples,1);
  7. K=2;%用K个正态分布拟合
  8. Pi=ones(1,K)/K;
  9. Miu=cell(K,1);
  10. Sigma2=cell(K,1);
  11. %% %K均值聚类确定初值
  12. [idx,center]=kmeans(samples,K);
  13. for i=1:K
  14. miu0=center(i,:);
  15. sigma0=var(samples(find(idx==i),:));
  16. Miu{i,1}=miu0;
  17. Sigma2{i,1}=sigma0;
  18. end
  19. beta=inf;
  20. likelihood_function_value=0;
  21. record=[];
  22. %% %EM算法
  23. while(1)
  24. %% %E步
  25. gama=zeros(N,K);
  26. samples_pd=zeros(N,K);
  27. for j=1:K
  28. samples_pd(:,j)=normpdf(samples,Miu{j,1},sqrt(Sigma2{j,1}));
  29. end
  30. for i=1:N
  31. for j=1:K
  32. gama(i,j)=Pi(j)*samples_pd(i,j)/(Pi*samples_pd(i,:)');
  33. end
  34. end
  35. likelihood_function_value_old=likelihood_function_value;
  36. likelihood_function_value=sum(log(sum(samples_pd.*repmat(Pi,N,1),1)));
  37. record=[record,likelihood_function_value];
  38. beta=abs(likelihood_function_value-likelihood_function_value_old);
  39. if beta<0.0001
  40. plot(1:length(record),record)
  41. break
  42. end
  43. %% %M步
  44. Nk=sum(gama,1);
  45. for j=1:K
  46. Miu{j,1}=zeros(1,D);
  47. Sigma2{j,1}=zeros(D,D);
  48. for i=1:N
  49. Miu{j,1}=Miu{j,1}+gama(i,j)*samples(i,:)/Nk(j);
  50. end
  51. for i=1:N
  52. Sigma2{j,1}=Sigma2{j,1}+(gama(i,j)*(samples(i,:)-Miu{j,1})'*(samples(i,:)-Miu{j,1}))/Nk(j);
  53. end
  54. end
  55. Pi=Nk/N;
  56. end
  57. X = [min(samples(:,1)):0.01:max(samples(:,1))]';
  58. YX=zeros(size(X,1),K);
  59. for j=1:K
  60. YX(:,j)=normpdf(X,Miu{j,1},sqrt(Sigma2{j,1}));
  61. end
  62. figure(2)
  63. [count,centers]=hist(samples,100);
  64. bar(centers,count,'FaceColor',[0.72,0.89,0.98])
  65. hold on
  66. plot(X, YX(:,1)/max(YX(:,1))*max(count),'color',[0.85,0.33,0.10],'linewidth',2)
  67. plot(X, YX(:,2)/max(YX(:,2))*max(count),'color',[0.93,0.69,0.13],'linewidth',2)
  68. YX = sum(YX.*repmat(Pi,size(X,1),1),2);
  69. YX=YX/max(YX)*max(count);
  70. plot(X,YX,'color',[0.00,0.45,0.74],'linewidth',2)

 多维数据:

  1. clear all
  2. clc
  3. close all
  4. mu1 = [5,2]; % 均值
  5. sigma1 = [1 , 0.9;
  6. 0.9 , 1]; % 协方差
  7. mu2 = [2,5]; % 均值
  8. sigma2 = [1 , 0.45;
  9. 0.45 , 1]; % 协方差
  10. samples = [mvnrnd(mu1,sigma1,1000);mvnrnd(mu2,sigma2,1000)];%这里的sigma是协方差covariances
  11. D=size(samples,2);
  12. N=size(samples,1);
  13. K=2;%用K个正态分布拟合
  14. Pi=ones(1,K)/K;
  15. Miu=cell(K,1);
  16. Sigma=cell(K,1);
  17. [idx,center]=kmeans(samples,K);
  18. for i=1:K
  19. miu0=center(i,:);
  20. sigma0=cov(samples(find(idx==i),:));
  21. Miu{i,1}=miu0;
  22. Sigma{i,1}=sigma0;
  23. end
  24. beta=inf;
  25. likelihood_function_value=0;
  26. record=[];
  27. %% %EM算法
  28. while(1)
  29. %% %E步
  30. gama=zeros(N,K);
  31. samples_pd=zeros(N,K);
  32. for j=1:K
  33. samples_pd(:,j)=mvnpdf(samples,Miu{j,1},Sigma{j,1});
  34. end
  35. for i=1:N
  36. for j=1:K
  37. gama(i,j)=Pi(j)*samples_pd(i,j)/(Pi*samples_pd(i,:)');
  38. end
  39. end
  40. likelihood_function_value_old=likelihood_function_value;
  41. likelihood_function_value=sum(log(sum(samples_pd.*repmat(Pi,N,1),1)));
  42. record=[record,likelihood_function_value];
  43. beta=abs(likelihood_function_value-likelihood_function_value_old);
  44. if beta<0.001
  45. plot(1:length(record),record)
  46. break
  47. end
  48. %% %M步
  49. Nk=sum(gama,1);
  50. for j=1:K
  51. Miu{j,1}=zeros(1,D);
  52. Sigma{j,1}=zeros(D,D);
  53. for i=1:N
  54. Miu{j,1}=Miu{j,1}+gama(i,j)*samples(i,:)/Nk(j);
  55. end
  56. for i=1:N
  57. Sigma{j,1}=Sigma{j,1}+(gama(i,j)*(samples(i,:)-Miu{j,1})'*(samples(i,:)-Miu{j,1}))/Nk(j);
  58. end
  59. end
  60. Pi=Nk/N;
  61. end
  62. X = gridsamp([min(samples(:,1)) min(samples(:,2));max(samples(:,1)) max(samples(:,2))], 40);
  63. YX0=zeros(size(X,1),K);
  64. for j=1:K
  65. YX0(:,j)=mvnpdf(X,Miu{j,1},Sigma{j,1});
  66. end
  67. X1 = reshape(X(:,1),40,40); X2 = reshape(X(:,2),40,40);
  68. YX = reshape(YX0(:,1), size(X1));
  69. figure(2), mesh(X1, X2, YX,'edgecolor',[0.85,0.33,0.10])
  70. YX = reshape(YX0(:,2), size(X1));
  71. hold on,mesh(X1, X2, YX,'edgecolor',[0.93,0.69,0.13])
  72. YX = reshape(sum(YX0.*repmat(Pi,size(X,1),1),2), size(X1));
  73. figure(3),mesh(X1, X2, YX,'edgecolor',[0.00,0.45,0.74])

  1. function S = gridsamp(range, q)
  2. %GRIDSAMP n-dimensional grid over given range
  3. %
  4. % Call: S = gridsamp(range, q)
  5. %
  6. % range : 2*n matrix with lower and upper limits
  7. % q : n-vector, q(j) is the number of points
  8. % in the j'th direction.
  9. % If q is a scalar, then all q(j) = q
  10. % S : m*n array with points, m = prod(q)
  11. % hbn@imm.dtu.dk
  12. % Last update June 25, 2002
  13. [mr n] = size(range); dr = diff(range);
  14. if mr ~= 2 | any(dr < 0)
  15. error('range must be an array with two rows and range(1,:) <= range(2,:)')
  16. end
  17. sq = size(q);
  18. if min(sq) > 1 | any(q <= 0)
  19. error('q must be a vector with non-negative elements')
  20. end
  21. p = length(q);
  22. if p == 1, q = repmat(q,1,n);
  23. elseif p ~= n
  24. error(sprintf('length of q must be either 1 or %d',n))
  25. end
  26. % Check for degenerate intervals
  27. i = find(dr == 0);
  28. if ~isempty(i), q(i) = 0*q(i); end
  29. % Recursive computation
  30. if n > 1
  31. A = gridsamp(range(:,2:end), q(2:end)); % Recursive call
  32. [m p] = size(A); q = q(1);
  33. S = [zeros(m*q,1) repmat(A,q,1)];
  34. y = linspace(range(1,1),range(2,1), q);
  35. k = 1:m;
  36. for i = 1 : q
  37. S(k,1) = repmat(y(i),m,1); k = k + m;
  38. end
  39. else
  40. S = linspace(range(1,1),range(2,1), q).';
  41. end

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/920124
推荐阅读
相关标签
  

闽ICP备14008679号