赞
踩
前言:在正式讲述朴素贝叶斯分类器之前,先介绍清楚两个基本概念:判别学习方法(Discriminative Learning Algorithm)和生成学习方法(Generative Learning Algorithm)。
上篇博文我们使用Logistic回归解决二类分类问题,解决过程是在解空间中寻找一条直线(其实准确来说是曲线、曲面、超平面)直接把两个样本的类别区分开,作为决策边界。当我们有一个新的样本加入时,直接看它在边界的哪一侧便可。这种解决方法便是判别学习方法。而生成学习方法,则是分别对两个类别进行建模,当有新的样本加入时,我们可以看看这个样本对哪个模型匹配程度较高,便可判断它属于哪个类型。从数学上讲,就是计算这个新样本分别匹配两个模型的后验概率。更形式化来讲,判别学习方法是直接对P(y|x)进行建模或者直接学习输入空间到输出空间的映射关系,其中x是样本的特征向量,y是输出的分类类别,其本身不能反映训练数据本身的特性,反映的是异类数据之间的差异,直接面对预测,往往学习的准确率更高,简化学习问题。就好比我们区分绵羊和山羊的羊群时,给它们划一条边界;而生成学习方法则是先对P(x|y)(条件概率)和P(x)(先验概率)进行建模,然后按照贝叶斯公式计算P(y|x)。所以就可以从统计的角度表示数据的分布情况,能够反映同类数据本身的相似度,并且训练数据量较少时仍然适用。但它不关心到底划分各类的那个分类边界在哪。就好比我们区分山羊和绵羊时,先掌握它们各自的特征(有毛没毛,有角无角等),然后再进行区分(稍微想一下,如果出现一个羊羔,你很难简单地根据边界来判断它是属于绵羊或山羊,此时使用生成学习方法计算它分别属于绵羊和山羊的概率作出的决策可能会更加保险)。
常见的判别模型有:线性回归(包括局部线性)、Logistic回归、支持向量机、神经网络等。
常见的生成模型有:朴素贝叶斯模型、隐马尔科夫模型、高斯混合模型等。
一、问题引入
前言所述,本文所要介绍的算法是朴素贝叶斯分类器,分为输入的特征向量为连续和离散两种情况,属于生成模型,区别于Logistic回归的判别模型。实际上,Logistic回归有一个泛化(generalization)的模型:高斯判别模型(GDA),属于生成模型,它与Logistic回归的关系如图:
其中使用的类条件分布是多元高斯分布。
此外本文再介绍贝叶斯分类器的另一个最典型也是最常用的应用:垃圾邮件分类,此模型使用的类条件分布是多项式分布。这两种分类模型都基于贝叶斯公式,只是使用了不同的类条件分布,即条件概率的分布函数。
根据贝叶斯公式:
一般情况下,我们认为P(x)(先验概率)对于每个类别来说是一个常数,根据类别数而定,即
其中N是训练集的对象数目,Nc是属于C类的对象数量。P(y)是当前训练数据上y属于某一类的概率。这些变量都可以直接根据训练样本计算,不难求得。计算的重点和难点在于P(x|y),取决于它服从什么分布。
二、问题分析
我们首先分析基于高斯判别模型的贝叶斯分类器。
1.高斯判别模型
(1)在GDA中,假设P(x|y)服从多元高斯分布。多元高斯分布是高斯分布在多维变量下的扩展,它的参数是均值向量μ和协方差矩阵Σ,Σ是n阶对称正定矩阵。多元高斯分布的概率密度公式为:
其中|Σ|是矩阵的行列式的值。协方差矩阵可以由协方差函数Cov求得,协方差函数的计算公式是:
假定y有c个类别,即y∈{1, 2, … , c},那么条件概率服从以下公式:
之后就是对均值向量μ和协方差矩阵∑进行参数估计了,仍然是使用最大似然估计,似然函数以及计算过程已经介绍过多次,下面直接给出结果:
下面给出三类数据分类的Matlab代码实现。
(2)代码实现
1)导入数据并绘制散点图
% 导入数据
load('bc_data')
% 绘制散点图
cl = unique(t); % 类别标签去重
col = {
'ko','kd','ks'} % 指定散点形状
fcol = {[1 0 0],[0 1 0],[0 0 1]}; % 指定散点颜色
figure(1);
hold off
for c = 1:length(cl)
pos = find(t==cl(c)); % 分别查找类别1,2,3的元素位置下标
plot(X(pos,1),X(pos,2),col{c},... % 分别画出每种类别的散点图
'markersize',10,'linewidth',2,...
'markerfacecolor',fcol{c});
hold on
end
xlim([-3 7])
ylim([-6 6])
绘制效果如下:
2) 假设各类服从多元高斯分布
%% 计算各类的均值向量和协方差矩阵
class_var = [];
for c = 1:length(cl)
pos = find(t==cl(c));
class_mean(c,:) = mean(X(pos,:));
class_var
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。