当前位置:   article > 正文

EM算法的一个实例(java)_javaem算法模型

javaem算法模型

      本文在一维情况下实现EM算法(一维高斯混合模型,默认分为 k=2 类),具体实现见下面的代码。目的是帮助学习和掌握EM算法。代码并不一定能让你掌握算法的原理,但通过代码你可以大体了解算法的流程,之后再去学习原理也就更容易理解和掌握了。





  1. /**
  2. * 一維情況下的EM算法實現
  3. * @author aturbo
  4. *1、求期望(e-step)
  5. *2、期望最大化(估值)(M-step)
  6. *3、循環以上兩部直到收斂
  7. */
  8. public class MyEM {
  9. private static final double[] points={1.0,1.3,2.2,2.6,2.8,5.0,7.3,7.4,7.5,7.7,7.9};
  10. private static double[][] w;
  11. private static double[] means = {7.7,2.3};//均值
  12. private static double[] variances= {1,1};//方差
  13. private static double[] probs = {0.5,0.5};//每个类的概率;这里默认选择k=2了;
  14. /**
  15. * 高斯分布计算公式,也就是先验概率
  16. * @param point
  17. * @param mean
  18. * @param variance
  19. * @return
  20. */
  21. private static double gaussianPro(double point,double mean,double variance){
  22. double prob = 0.0;
  23. prob = (1/(Math.sqrt(2*Math.PI)*Math.sqrt(variance)))*Math.exp(-(point-mean)*(point-mean)/(2*variance));
  24. return prob;
  25. }
  26. /**
  27. * E-step的主要逻辑
  28. * @param means
  29. * @param variances
  30. * @param points
  31. * @param probs
  32. * @return
  33. */
  34. private static double[][] countPostprob(double[] means,double[] variances,double[] points,double[] probs){
  35. int clusterNum = means.length;
  36. int pointNum = points.length;
  37. double[][] postProbs = new double[clusterNum][pointNum];
  38. double[] denominator = new double[pointNum];
  39. for(int m = 0;m <pointNum;m++){
  40. denominator[m] = 0.0;
  41. for(int n = 0;n<clusterNum;n++){
  42. denominator[m]+=(gaussianPro(points[m], means[n], variances[n])*probs[n]);
  43. }
  44. }
  45. for(int i = 0;i<clusterNum;i++){
  46. for(int j = 0;j<pointNum;j++){
  47. postProbs[i][j]=(gaussianPro(points[j], means[i], variances[i])*probs[i])/(denominator[j]);
  48. }
  49. }
  50. return postProbs;
  51. }
  52. private static void eStep(){
  53. w = countPostprob(means, variances, points, probs);
  54. }
  55. /**
  56. * M-step的主要逻辑之一:由E-step得到的期望,重新计算均值
  57. * @param w
  58. * @param points
  59. * @return
  60. */
  61. private static double[] guessMean(double[][] w,double[] points){
  62. int wLength = w.length;
  63. double[] means = new double[w.length];
  64. double[] wi = new double[wLength];
  65. for (int m = 0; m < wLength; m++) {
  66. wi[m] = 0.0;
  67. for(int n = 0; n<points.length;n++){
  68. wi[m] += w[m][n];
  69. }
  70. }
  71. for(int i = 0;i<w.length;i++){
  72. means[i] = 0.0;
  73. for(int j = 0;j<points.length;j++){
  74. means[i]+=(w[i][j]*points[j]);
  75. }
  76. means[i] /= wi[i];
  77. }
  78. return means;
  79. }
  80. /**
  81. * M-step的主要逻辑之一:由E-step得到的期望,重新计算方差
  82. * @param w
  83. * @param points
  84. * @return
  85. */
  86. private static double[] guessVariance(double[][] w,double[] points){
  87. int wLength = w.length;
  88. double[] means = new double[w.length];
  89. double[] variances = new double[wLength];
  90. double[] wi = new double[wLength];
  91. for (int m = 0; m < wLength; m++) {
  92. wi[m] = 0.0;
  93. for(int n = 0; n<points.length;n++){
  94. wi[m] += w[m][n];
  95. }
  96. }
  97. means = guessMean(w, points);
  98. for(int i = 0;i<wLength;i++){
  99. variances[i] = 0.0;
  100. for(int j = 0;j<points.length;j++){
  101. variances[i] +=(w[i][j]*(points[j]-means[i])*(points[j]-means[i]));
  102. }
  103. variances[i] /= wi[i];
  104. }
  105. return variances;
  106. }
  107. /**
  108. * M-step的主要逻辑之一:由E-step得到的期望,重新计算概率
  109. * @param w
  110. * @return
  111. */
  112. private static double[] guessProb(double[][] w){
  113. int wLength = w.length;
  114. double[] probs = new double[wLength];
  115. for(int i = 0;i<wLength;i++){
  116. probs[i] = 0.0;
  117. for(int j = 0;j<w[i].length;j++){
  118. probs[i]+=w[i][j];
  119. }
  120. probs[i] /=w[i].length;
  121. }
  122. return probs;
  123. }
  124. private static void mStep(){
  125. means = guessMean(w, points);
  126. variances = guessVariance(w, points);
  127. probs = guessProb(w);
  128. }
  129. /**
  130. * 计算前后两次迭代的参数的差值
  131. * @param bef_values
  132. * @param values
  133. * @return
  134. */
  135. private static double threshold(double[] bef_values,double[] values){
  136. double diff = 0.0;
  137. for(int i = 0 ; i < values.length;i++){
  138. diff += (values[i]-bef_values[i]);
  139. }
  140. return Math.abs(diff);
  141. }
  142. public static void main(String[] args)throws Exception{
  143. int k = 2;
  144. w = new double[k][points.length];
  145. double[] bef_means;
  146. double[] bef_var;
  147. do{
  148. bef_means = means;
  149. bef_var = variances;
  150. eStep();
  151. mStep();
  152. }while(threshold(bef_means, means)<0.01&&threshold(bef_var, variances)<0.01);
  153. for(double prob:probs)
  154. System.out.println(prob);
  155. }
  156. }

参考文献:

 《Data Mining and Analysis: Fundamental Concepts and Algorithms》

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/煮酒与君饮/article/detail/785517
推荐阅读
相关标签
  

闽ICP备14008679号