This post implements the EM algorithm in the one-dimensional case; the full implementation is given below. The goal is to help with learning and mastering the EM algorithm. Code alone will not necessarily teach you the theory behind the algorithm, but it does show the overall flow, and once the flow is clear, the underlying theory becomes much easier to understand and absorb.
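As a quick reference for reading the code, here is the model being fitted and the standard EM updates that the code implements. The notation is chosen to match the code's variables: $w_{ij}$ corresponds to the array w, $\mu_i$ to means, $\sigma_i^2$ to variances, $\phi_i$ to probs, and $x_1,\dots,x_N$ to points. The data are modeled as a mixture of $k$ one-dimensional Gaussians:

$$p(x)=\sum_{i=1}^{k}\phi_i\,\mathcal{N}(x\mid\mu_i,\sigma_i^2),\qquad \mathcal{N}(x\mid\mu,\sigma^2)=\frac{1}{\sqrt{2\pi\sigma^2}}\exp\!\Big(-\frac{(x-\mu)^2}{2\sigma^2}\Big)$$

E-step (countPostprob): compute the responsibility of component $i$ for point $x_j$:

$$w_{ij}=\frac{\phi_i\,\mathcal{N}(x_j\mid\mu_i,\sigma_i^2)}{\sum_{n=1}^{k}\phi_n\,\mathcal{N}(x_j\mid\mu_n,\sigma_n^2)}$$

M-step (guessMean, guessVariance, guessProb): re-estimate the parameters from the responsibilities:

$$\mu_i=\frac{\sum_{j=1}^{N}w_{ij}\,x_j}{\sum_{j=1}^{N}w_{ij}},\qquad \sigma_i^2=\frac{\sum_{j=1}^{N}w_{ij}\,(x_j-\mu_i)^2}{\sum_{j=1}^{N}w_{ij}},\qquad \phi_i=\frac{1}{N}\sum_{j=1}^{N}w_{ij}$$

The two steps are repeated until the parameters stop changing.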
/**
 * EM algorithm for a one-dimensional Gaussian mixture.
 * @author aturbo
 * 1. Expectation step (E-step): compute the posterior responsibilities.
 * 2. Maximization step (M-step): re-estimate means, variances and mixing proportions.
 * 3. Repeat the two steps until the parameters converge.
 */
public class MyEM {
    private static final double[] points = {1.0, 1.3, 2.2, 2.6, 2.8, 5.0, 7.3, 7.4, 7.5, 7.7, 7.9};
    private static double[][] w;                      // responsibilities: w[cluster][point]
    private static double[] means = {7.7, 2.3};       // component means
    private static double[] variances = {1, 1};       // component variances
    private static double[] probs = {0.5, 0.5};       // mixing proportions; k = 2 components is hard-coded here

    /**
     * Gaussian density: the likelihood of a point under one component
     * (not a prior probability).
     * @param point    the data point
     * @param mean     the component mean
     * @param variance the component variance
     * @return the density value N(point | mean, variance)
     */
    private static double gaussianPro(double point, double mean, double variance) {
        return (1 / (Math.sqrt(2 * Math.PI) * Math.sqrt(variance)))
                * Math.exp(-(point - mean) * (point - mean) / (2 * variance));
    }

    /**
     * Core of the E-step: the posterior probability (responsibility) of each
     * component for each point.
     * @param means     current component means
     * @param variances current component variances
     * @param points    the data
     * @param probs     current mixing proportions
     * @return a [cluster][point] matrix of responsibilities
     */
    private static double[][] countPostprob(double[] means, double[] variances,
                                            double[] points, double[] probs) {
        int clusterNum = means.length;
        int pointNum = points.length;
        double[][] postProbs = new double[clusterNum][pointNum];
        double[] denominator = new double[pointNum];
        // denominator[m] is the total mixture density at points[m]
        for (int m = 0; m < pointNum; m++) {
            denominator[m] = 0.0;
            for (int n = 0; n < clusterNum; n++) {
                denominator[m] += gaussianPro(points[m], means[n], variances[n]) * probs[n];
            }
        }
        for (int i = 0; i < clusterNum; i++) {
            for (int j = 0; j < pointNum; j++) {
                postProbs[i][j] = gaussianPro(points[j], means[i], variances[i]) * probs[i] / denominator[j];
            }
        }
        return postProbs;
    }

    private static void eStep() {
        w = countPostprob(means, variances, points, probs);
    }

    /**
     * Part of the M-step: re-estimate the means as responsibility-weighted
     * averages of the data.
     * @param w      responsibilities from the E-step
     * @param points the data
     * @return the updated means
     */
    private static double[] guessMean(double[][] w, double[] points) {
        int wLength = w.length;
        double[] means = new double[wLength];
        double[] wi = new double[wLength];            // total responsibility of each component
        for (int m = 0; m < wLength; m++) {
            wi[m] = 0.0;
            for (int n = 0; n < points.length; n++) {
                wi[m] += w[m][n];
            }
        }
        for (int i = 0; i < wLength; i++) {
            means[i] = 0.0;
            for (int j = 0; j < points.length; j++) {
                means[i] += w[i][j] * points[j];
            }
            means[i] /= wi[i];
        }
        return means;
    }

    /**
     * Part of the M-step: re-estimate the variances as responsibility-weighted
     * squared deviations from the updated means.
     * @param w      responsibilities from the E-step
     * @param points the data
     * @return the updated variances
     */
    private static double[] guessVariance(double[][] w, double[] points) {
        int wLength = w.length;
        double[] variances = new double[wLength];
        double[] wi = new double[wLength];
        for (int m = 0; m < wLength; m++) {
            wi[m] = 0.0;
            for (int n = 0; n < points.length; n++) {
                wi[m] += w[m][n];
            }
        }
        double[] means = guessMean(w, points);        // use the freshly updated means
        for (int i = 0; i < wLength; i++) {
            variances[i] = 0.0;
            for (int j = 0; j < points.length; j++) {
                variances[i] += w[i][j] * (points[j] - means[i]) * (points[j] - means[i]);
            }
            variances[i] /= wi[i];
        }
        return variances;
    }

    /**
     * Part of the M-step: re-estimate the mixing proportions as the average
     * responsibility of each component.
     * @param w responsibilities from the E-step
     * @return the updated mixing proportions
     */
    private static double[] guessProb(double[][] w) {
        int wLength = w.length;
        double[] probs = new double[wLength];
        for (int i = 0; i < wLength; i++) {
            probs[i] = 0.0;
            for (int j = 0; j < w[i].length; j++) {
                probs[i] += w[i][j];
            }
            probs[i] /= w[i].length;
        }
        return probs;
    }

    private static void mStep() {
        means = guessMean(w, points);
        variances = guessVariance(w, points);
        probs = guessProb(w);
    }

    /**
     * Total absolute change of the parameters between two iterations.
     * Summing signed differences would let changes cancel out, so each
     * difference is taken in absolute value before summing.
     * @param bef_values parameters from the previous iteration
     * @param values     parameters from the current iteration
     * @return the sum of absolute changes
     */
    private static double threshold(double[] bef_values, double[] values) {
        double diff = 0.0;
        for (int i = 0; i < values.length; i++) {
            diff += Math.abs(values[i] - bef_values[i]);
        }
        return diff;
    }

    public static void main(String[] args) throws Exception {
        int k = 2;
        w = new double[k][points.length];
        double[] bef_means;
        double[] bef_var;
        // iterate E-step and M-step until both the means and the variances
        // change by less than 0.01 between consecutive iterations
        do {
            bef_means = means.clone();
            bef_var = variances.clone();
            eStep();
            mStep();
        } while (threshold(bef_means, means) > 0.01 || threshold(bef_var, variances) > 0.01);
        System.out.println("mixing proportions:");
        for (double prob : probs)
            System.out.println(prob);
        System.out.println("means:");
        for (double mean : means)
            System.out.println(mean);
        System.out.println("variances:");
        for (double variance : variances)
            System.out.println(variance);
    }
}
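The listing stops iterating once the total change in both the means and the variances drops below 0.01. A common alternative stopping rule is to monitor the data log-likelihood, which EM is guaranteed never to decrease. The helper below is a minimal sketch of that idea and is not part of the original listing; the name logLikelihood is mine, and the method assumes it is placed inside the MyEM class so that it can use points, means, variances, probs and gaussianPro.

    // Minimal sketch (not in the original listing): the log-likelihood of the
    // data under the current parameters. EM never decreases this quantity,
    // so iteration can stop once its improvement becomes negligibly small.
    private static double logLikelihood() {
        double ll = 0.0;
        for (double point : points) {
            // total mixture density at this point; this is the same quantity
            // as the E-step denominator computed in countPostprob
            double mixture = 0.0;
            for (int i = 0; i < means.length; i++) {
                mixture += probs[i] * gaussianPro(point, means[i], variances[i]);
            }
            ll += Math.log(mixture);
        }
        return ll;
    }

With such a helper, the do/while loop could record the log-likelihood before each iteration and stop when the gain falls below a small tolerance such as 1e-6, instead of comparing the raw parameter values.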
