当前位置:   article > 正文

菜鸟进阶: C++实现KNN文本分类算法_c++ knn 库

c++ knn 库

转载自:http://www.cnblogs.com/finallyliuyu/archive/2010/09/26/1836285.html

原始作者:finallyliuyu(转载请注明原作者和出处)

(代码暂不发布源码下载版,以后会发布)

    KNN文本分类算法又称为(k nearest neighhor)。它是一种基于事例的学习方法,也称懒惰式学习方法。

    它的大概思路是:对于某个待分类的样本点,在训练集中找离它最近的k个样本点,并观察这k个样本点所属类别。看这k个样本点中,那个类别出现的次数多,则将这类别标签赋予该待分类的样本点。

   通过上面的描述,可以看出KNN算法在算法实现上是很简单的,并不十分困难。

  1. 给出代码之前,先给出实验条件。

1。语料库格式:

语料库存放在MSSQLSERVER2000的数据库的表单中,表单格式如下:

TIDJ01S{T~M_[N{]5)`8%PK

(fig 1)

2。如何获得该形式的语料库?

你可以从搜狗lab下载2008年的数据,并且用我的程序对这批数据进行处理,抽取出新闻。处理程序见《菜鸟学习C++练笔之整理搜狗2008版语料库--获取分类语料库》或者去下载我上传到博客园的语料资源见《献给热衷于自然语言处理的业余爱好者的中文新闻分类语料库之二

3。分割出训练语料库与测试语料库(训练语料库和测试语料库也是MSSQL表单,格式同fig1)。关于MSSQLSERVER的一些表复制的技巧见:《MSSQL语句备份

  1. 下面开始给出C++代码:

如果一些函数代码没有给出,请您参阅《菜鸟进阶:C++实现Chi-square 特征词选择算法》以及K-means文本聚类系列(已经完成)

建立VSM模型(考虑到效率问题对训练样本集合与测试样本集采用不同的函数建立VSM模型)

1。对训练集建立VSM模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
*****************以下函数辅助完成聚类功能*********************************************************************8**********************/
/************************************************************************/
/* 建立文档向量模型                                                                     */
/************************************************************************/
map< int ,vector< double > > Preprocess::VSMConstruction(map<string,vector<pair< int , int >>> &mymap)
{  
     clock_t start,finish;
     double totaltime;
     start= clock ();
     int corpus_N=endIndex-beginIndex+1;
     map< int ,vector< double >> vsmMatrix;
     vector<string> myKeys=GetFinalKeyWords();
     vector<pair< int , int > >maxTFandDF=GetfinalKeysMaxTFDF(mymap);
     for ( int i=beginIndex;i<=endIndex;i++)
     {  
         vector<pair< int , double > >tempVSM;
         vector< double >tempVSM2;
         for (vector<string>::size_type j=0;j<myKeys.size();j++)
         {
             //vector<pair<int,int> >::iterator findit=find_if(mymap[myKeys[j]].begin(),mymap[myKeys[j]].end(),PredTFclass(i));
             double TF=( double )count_if(mymap[myKeys[j]].begin(),mymap[myKeys[j]].end(),PredTFclass(i));
             TF=0.5+( double )TF/(maxTFandDF[j].first);
             TF*= log (( double )corpus_N/maxTFandDF[j].second);
             tempVSM.push_back(make_pair(j,TF));
 
         }
         if (!tempVSM.empty())
         {
             tempVSM=NormalizationVSM(tempVSM);
             //
             for (vector<pair< int , double > >::iterator it=tempVSM.begin();it!=tempVSM.end();it++)
             {
                 tempVSM2.push_back(it->second);
             }
             vsmMatrix[i]=tempVSM2;
 
 
 
         }
         tempVSM.clear();
         tempVSM2.clear();
 
 
 
     }
     finish= clock ();
     totaltime=( double )(finish-start)/CLOCKS_PER_SEC;
     cout<< "为训练语料库集合建立VSM模型共用了" <<totaltime<<endl;
 
 
     return vsmMatrix;
 
}

2。对测试集建立VSM模型。

这里值得一提的是在tf-idf计算特征VSM模型特征词权重的时候,tf:计算的是该词在该篇文章中出现的次数。idf:用的是训练集计算出的idf值。原因在于:在一个分类系统,我们假设代分类的文档是一篇一篇进入分类系统中来的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/************************************************************************/
/* 获得待分类文档集合的VSM模型                                            */
/************************************************************************/
map< int ,vector< double >> Preprocess::GetManyVSM( int begin, int end,map<string,vector<pair< int , int >>> &mymap)
{
     map< int ,vector< double > > testingVSMMatrix;
     
     vector<string>keywords=GetFinalKeyWords();
     char * selectbySpecificId= new char [1000];
     memset (selectbySpecificId,0,1000);
     sprintf_s(selectbySpecificId,1000, "select ArticleId,CAbstract from Article where ArticleId between %d and %d" ,begin,end);
     set<string>stopwords=MakeStopSet();
     if (!ICTCLAS_Init())
     {
         printf ( "ICTCLAS INIT FAILED!\n" );
         string strerr( "there is a error" );
 
     }
     ICTCLAS_SetPOSmap(ICT_POS_MAP_SECOND);
     //导入用户词典后
     printf ( "\n导入用户词典后:\n" );
     int nCount = ICTCLAS_ImportUserDict( "dict.txt" ); //覆盖以前的用户词典
     //保存用户词典
     ICTCLAS_SaveTheUsrDic();
     printf ( "导入%d个用户词。\n" , nCount);
     CoInitialize(NULL);
     _ConnectionPtr pConn(__uuidof(Connection));
     _RecordsetPtr pRst(__uuidof(Recordset));
     pConn->ConnectionString= "Provider=SQLOLEDB.1;Password=xxxx;Persist Security Info=True; User ID=sa;Initial Catalog=ArticleCollection" ;
     pConn->Open( "" , "" , "" ,adConnectUnspecified);
     pRst=pConn->Execute(selectbySpecificId,NULL,adCmdText);
     while (!pRst->rsEOF)
     {
         string rawtext=(_bstr_t)pRst->GetCollect( "CAbstract" );
         if (rawtext!= "" )
         {
             string tempid=(_bstr_t)pRst->GetCollect( "ArticleId" );
             int articleid= atoi (tempid.c_str());
             vector<string>wordcollection=goodWordsinPieceArticle(rawtext,stopwords); //表示这篇文章的词
             vector<pair< int , int > >maxTFandDF=GetfinalKeysMaxTFDF(mymap);
             int corpus_N=endIndex-beginIndex+1;
             vector<pair< int , double > >tempVSM;
             vector< double >vsm;
             for (vector<string>::size_type j=0;j<keywords.size();j++)
             {
                 double TF=( double )count_if(wordcollection.begin(),wordcollection.end(),GT_cls(keywords[j]));
                 TF=0.5+( double )TF/(maxTFandDF[j].first);
                 TF*= log (( double )corpus_N/maxTFandDF[j].second);
                 tempVSM.push_back(make_pair(j,TF));
 
             }
             if (!tempVSM.empty())
             {
                 tempVSM=NormalizationVSM(tempVSM);
                 for (vector<pair< int , double > >::iterator it=tempVSM.begin();it!=tempVSM.end();it++)
                 {
                     vsm.push_back(it->second);
                 }
                 testingVSMMatrix[articleid]=vsm;
 
 
 
             }
 
 
 
         }
         
         pRst->MoveNext();
     }
     
     pRst->Close();
     pConn->Close();
     pRst.Release();
     pConn.Release();
     CoUninitialize();
     delete []selectbySpecificId;
     ICTCLAS_Exit();
     return testingVSMMatrix;
     
 
     
 
 
}

 

对VSM序列化和反序列化的操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/************************************************************************/
/*  将VSM模型序列化到本地硬盘                                                                    */
/************************************************************************/
void Preprocess::SaveVSM(map< int ,vector< double > >&VSMmatrix, char * dest)
{   clock_t start,finish;
     double totaltime;
     start= clock ();
     ofstream ofile(dest,ios::binary);
     for (map< int ,vector< double > >::iterator it=VSMmatrix.begin();it!=VSMmatrix.end();++it)
     {
         ofile<<it->first<<endl;
         vector< double >::iterator subit;
         ofile<<it->second.size()<<endl;
         for (subit=(it->second).begin();subit!=(it->second).end();++subit)
         {
             ofile<<*subit<< " " ;
         }
         ofile<<endl;
     
 
 
     }
     ofile.close();
     finish= clock ();
     totaltime=( double )(finish-start)/CLOCKS_PER_SEC;
     cout<< "将语料库集合的VSM模型为序列化到硬盘的时间为" <<totaltime<<endl;
 
 
}
 
/************************************************************************/
/* 加载VSM模型到内存                                                                     */
/************************************************************************/
void Preprocess::LoadVSM(map< int ,vector< double > >&VSMmatrix, char * dest)
{  
     clock_t start,finish;
     double totaltime;
     start= clock ();
     ifstream  ifile(dest,ios::binary);
     int articleId; //文章id;
     int lenVec; //id对应的vsm的长度
     double val; //暂存数据
     vector< double >vsm;
     while (!ifile.eof())
     {
         ifile>>articleId;
         ifile>>lenVec;
         for ( int i=0;i<lenVec;i++)
         {
             ifile>>val;
             vsm.push_back(val);
         }
         VSMmatrix[articleId]=vsm;
         vsm.clear();
 
     }
     ifile.close();
     finish= clock ();
     totaltime=( double )(finish-start)/CLOCKS_PER_SEC;
     cout<< "加载VSM模型到内存的时间为" <<totaltime<<endl;
 
 
 
}

 

对一篇文章用KNN方法进行分类的函数(这里距离的定义采用余弦相似度):

 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
/************************************************************************/
/* 对一篇文章分类获取其类别标签   N为KNN中的N的取值                                      */
/************************************************************************/
string Preprocess:: KNNClassificationCell( int N,vector< double >vsm,vector<string>categorization,map<string,vector<pair< int , int >>> &mymap,map< int ,vector< double > >&trainingsetVSM){
 
     clock_t start,finish;
     double totaltime;
     start= clock ();
 
 
     string classLabel;
     //map<int,vector<double> >trainingsetVSM=VSMConstruction(mymap);
     //vector<double>toBeClassifyDoc=GetSingleVSM(articleId,mymap);
     vector<pair< int , double > >SimilaritySore; //保存待分类样本与训练样本集的测试得分
     //计算相似度得分
     for (map< int ,vector< double > >::iterator it=trainingsetVSM.begin();it!=trainingsetVSM.end();it++)
     {
         double score=CalCosineofVectors(vsm,it->second);
         SimilaritySore.push_back(make_pair(it->first,score));
 
     }
     //将相似度运算结果从高到底排序
     stable_sort(SimilaritySore.begin(),SimilaritySore.end(),isLarger2);
     ostringstream out;
     string articleIds;
     out<< "(" ;
     int putComma=0;
     for (vector<pair< int , double > >::size_type j=0;j<N;j++)
     {
         out<<SimilaritySore[j].first;
         if (putComma<N-1)
         {
             out<< "," ;
 
         }
         putComma++;
 
 
     }
     out<< ")" ;
     articleIds=out.str(); //获得和待分类文档距离最近的前N个文档的id字符串
     vector<string> labels=GetClassification(articleIds);
     for (vector<string>::iterator it=labels.begin();it!=labels.end();it++)
     {
         trim(*it, " " );
     }
     vector<pair<string, int > >vectorAssit;
     for ( int i=0;i<categorization.size();i++)
     {
         int num=count_if(labels.begin(),labels.end(),GT_cls(categorization[i]));
         vectorAssit.push_back(make_pair(categorization[i],num));
     }
     stable_sort(vectorAssit.begin(),vectorAssit.end(),isLarger);
     finish= clock ();
     totaltime=( double )(finish-start)/CLOCKS_PER_SEC;
     cout<< "对一篇文章进行KNN分类的时间为" <<totaltime<<endl;
 
 
     return vectorAssit[0].first;
 
     
 
 
 
 
}
1
根据articleid 读取数据库获取类别的函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
************************************************************************/
/*      获得训练语料库中文章的类别标签                                                                */
/************************************************************************/
vector<string> Preprocess::GetClassification(string ArticleIds)
{   vector<string>labels;
     char * selectCategorization= new char [5000];
     memset (selectCategorization,50,5000);
     sprintf_s(selectCategorization,5000, "select Categorization from Article where ArticleId in%s" ,ArticleIds.c_str());
     CoInitialize(NULL);
     _ConnectionPtr pConn(__uuidof(Connection));
     _RecordsetPtr pRst(__uuidof(Recordset));
     pConn->ConnectionString=dbconnection;
     pConn->Open( "" , "" , "" ,adConnectUnspecified);
     pRst=pConn->Execute(selectCategorization,NULL,adCmdText);
     delete []selectCategorization;
     while (!pRst->rsEOF)
     {
         string label=(_bstr_t) pRst->GetCollect( "Categorization" );
         labels.push_back(label);
         pRst->MoveNext();
 
     }
     pRst->Close();
     pConn->Close();
     pRst.Release();
     pConn.Release();
     CoUninitialize();
     return labels;
     
 
}

对训练文档集合用KNN进行分类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/************************************************************************/
/* KNN分类器                                                               */
/************************************************************************/
vector<pair< int ,string> > Preprocess::KNNclassifier(map<string,vector<pair< int , int >>> &mymap,map< int ,vector< double > >&trainingsetVSM,map< int ,vector< double > >&testingsetVSM,vector<string>catigorization, int N)
{
     vector<pair< int ,string>>classifyResults;
     for (map< int ,vector< double > >::iterator it=trainingsetVSM.begin();it!=testingsetVSM.end();it++)
     {
         string label=KNNClassificationCell(N,it->second,catigorization,mymap,trainingsetVSM);
         pair< int ,string> temp=make_pair(it->first,label);
         classifyResults.push_back(temp);
 
     }
     return classifyResults;
 
 
}
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/801274
推荐阅读
相关标签
  

闽ICP备14008679号