当前位置:   article > 正文

机器学习算法决策树C4.5之c++实现(不调用外源库)_决策树算法 c++语言实现

决策树算法 c++语言实现

 目前很多学习机器学习的小伙伴,一上来就使用现成的sklearn机器学习包,写两行代码、调调参数就能跑起来。这看似方便,实则有时不利于个人能力的发展。要知道,现在公司需要的算法工程师不仅要会调参(这种工作,入门几个月的人就可以胜任),还要能深入底层、优化代码,并能自己动手搭建算法。

本文章适合以下几类人:

1)初学者,了解机器学习的实现过程

2)想提升自己的代码能力

第一步:原理

     决策树可以被简单地看成一系列if和else判断。其优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关的特征数据;其缺点:可能会产生过度匹配(过拟合)问题。讲解决策树详细理论的博客网上有很多,这里就不再赘述,感兴趣的读者可以参考:十大经典算法之C4.5算法综述 - 知乎

第二步:代码实现

  1. #include <vector>
  2. #include <set>
  3. #include <map>
  4. #include <string>
  5. #include <fstream>
  6. #include <sstream>
  7. #include <iostream>
  8. #include <math.h>
  9. using namespace std;
  10. /*******树的构造*******/
  /// A node of the decision tree: a label, two counters and links to children.
  /// Children are held as raw pointers; DestroyTree() releases a whole subtree.
  struct TreeNode {
      std::string m_sAttribute;            // node name
      // Default member initializers keep the counters well-defined even when a
      // node is default-initialized (e.g. `TreeNode n;`), not only when it is
      // value-initialized through `new TreeNode()`.
      int m_iDeciNum = 0;                  // number of "yes" records
      int m_iUnDecinum = 0;                // number of "no" records
      std::vector<TreeNode*> m_vChildren;  // child nodes, in insertion order
  };
  17. TreeNode* CreateTreeNode(string value)
  18. {
  19. TreeNode* pNode = new TreeNode();
  20. pNode->m_sAttribute = value;
  21. return pNode;
  22. }
  23. bool FindNode(TreeNode* pRoot, std::string& item)
  24. {
  25. if (pRoot->m_sAttribute == item)
  26. return true;
  27. bool found = false;
  28. vector<TreeNode*>::iterator i = pRoot->m_vChildren.begin();
  29. while (!found && i < pRoot->m_vChildren.end())
  30. {
  31. found = FindNode(*i, item);
  32. ++i;
  33. }
  34. return found;
  35. }
  36. void ConnectTreeNodes(TreeNode* pParent, TreeNode* pChild)
  37. {
  38. if (pParent != NULL)
  39. {
  40. pParent->m_vChildren.push_back(pChild);
  41. }
  42. }
  43. void PrintTreeNode(TreeNode* pNode)
  44. {
  45. if (pNode != NULL)
  46. {
  47. printf("value of this node is: %d.\n", pNode->m_sAttribute);
  48. printf("its children is as the following:\n");
  49. std::vector<TreeNode*>::iterator i = pNode->m_vChildren.begin();
  50. while (i < pNode->m_vChildren.end())
  51. {
  52. if (*i != NULL)
  53. printf("%s\t", (*i)->m_sAttribute);
  54. ++i;
  55. }
  56. printf("\n");
  57. }
  58. else
  59. {
  60. printf("this node is null.\n");
  61. }
  62. printf("\n");
  63. }
  64. void PrintTree(TreeNode* pRoot)
  65. {
  66. PrintTreeNode(pRoot);
  67. if (pRoot != NULL)
  68. {
  69. std::vector<TreeNode*>::iterator i = pRoot->m_vChildren.begin();
  70. while (i < pRoot->m_vChildren.end())
  71. {
  72. PrintTree(*i);
  73. ++i;
  74. }
  75. }
  76. }
  77. void DestroyTree(TreeNode* pRoot)
  78. {
  79. if (pRoot != NULL)
  80. {
  81. std::vector<TreeNode*>::iterator i = pRoot->m_vChildren.begin();
  82. while (i < pRoot->m_vChildren.end())
  83. {
  84. DestroyTree(*i);
  85. ++i;
  86. }
  87. delete pRoot;
  88. }
  89. }
  90. /*******树的构造*******/
  91. class DecisionTree {
  92. private:
  93. struct attrItem
  94. {
  95. std::vector<int> itemNum; //itemNum[0] = itemLine.size()
  96. //itemNum[1] = decision num
  97. set<int> itemLine; //可用行
  98. };
  99. //重点
  100. struct attributes
  101. {
  102. string attriName; //属性名字
  103. vector<double> statResult;
  104. map<string, attrItem*> attriItem;//存放子目录的信息
  105. };
  106. vector<attributes*> statTree;
  107. int attriNum;
  108. vector<vector<string>> infos;
  109. map<string, int> attr_clum;//作用不大
  110. public:
  111. DecisionTree() {
  112. attriNum = 0;
  113. }
  114. vector<vector<string>>& getInfos()
  115. {
  116. return infos;
  117. }
  118. vector<attributes*>& getStatTree()
  119. {
  120. return statTree;
  121. }
  122. int pretreatment(string filename, set<int>& readLineNum, vector<int>& readClumNum);
  123. int statister(vector<vector<string>>& infos, vector<attributes*>& statTree,
  124. set<int>& readLine, vector<int>& readClumNum);
  125. int compuDecisiNote(vector<attributes*>& statTree, int deciNum, int lineNum, vector<int>& readClumNum);
  126. double info_D(int deciNum, int sum);
  127. void resetStatTree(vector<attributes*>& statTree, vector<int>& readClumNum);
  128. double Info_attr(map<string, attrItem*>& attriItem, double& splitInfo, int lineNum);
  129. void CreatTree(TreeNode* &treeHead, vector<attributes*>& statTree, vector<vector<string>>& infos,
  130. set<int>& readLine, vector<int>& readClumNum, int deep);
  131. };
  132. /*
  133. * @function CreatTree 预处理函数,负责读入数据,并生成信息矩阵和属性标记
  134. * @param: filename 文件名
  135. * @param: readLineNum 可使用行set
  136. * @param: readClumNum 可用属性vector 0可用 1不可用
  137. * @return int 返回函数执行状态
  138. */
  139. int DecisionTree::pretreatment(string filename, set<int>& readLineNum, vector<int>& readClumNum)
  140. {
  141. }
  142. /*
  143. * @function Info_attr info_D 总信息量
  144. * @param: deciNum 有效信息数
  145. * @param: sum 总信息量
  146. * @return double 总信息量比例
  147. */
  148. double DecisionTree::info_D(int deciNum, int sum)
  149. {
  150. double pi = (double)deciNum / (double)sum;
  151. double result = 0.0;
  152. if ((1.0 - pi) < 0.000001 || (pi - 0.0) < 0.000001)
  153. {
  154. return result;
  155. }
  156. result = pi * (log(pi) / log((double)2)) + (1 - pi)*(log(1 - pi) / log((double)2));
  157. return -result;
  158. }
  159. /*
  160. * @function Info_attr 总信息量
  161. * @param: deciNum 有效信息数
  162. * @param: sum 总信息量
  163. * @return double
  164. */
  165. double DecisionTree::Info_attr(map<string, attrItem*>& attriItem, double& splitInfo, int lineNum)
  166. {
  167. double result = 0.0;
  168. for (map<string, attrItem*>::iterator item = attriItem.begin();
  169. item != attriItem.end();
  170. ++item
  171. )
  172. {
  173. double pi = (double)(item->second->itemNum[0]) / (double)lineNum;
  174. splitInfo += pi * (log(pi) / log((double)2));
  175. double sub_attr = info_D(item->second->itemNum[1], item->second->itemNum[0]);
  176. result += pi * sub_attr;
  177. }
  178. splitInfo = -splitInfo;
  179. return result;
  180. }
  181. /*
  182. * @function compuDecisiNote 计算C4.5
  183. * @param: statTree 为状态树,此树动态更新,但是由于是DFS对数据更新,所以不必每次新建状态树
  184. * @param: deciNum Yes的数据量
  185. * @param: lineNum 计算set的行数
  186. * @param: readClumNum 用于计算的set
  187. * @return int 信息量最大的属性号
  188. */
  189. int DecisionTree::compuDecisiNote(vector<attributes*>& statTree, int deciNum, int lineNum, vector<int>& readClumNum)
  190. {
  191. double max_temp = 0;
  192. int max_attribute = 0;
  193. //总的yes行的信息量
  194. double infoD = info_D(deciNum, lineNum);
  195. for (int i = 0; i < attriNum - 1; i++)
  196. {
  197. if (readClumNum[i] == 0)
  198. {
  199. double splitInfo = 0.0;
  200. //info
  201. double info_temp = Info_attr(statTree[i]->attriItem, splitInfo, lineNum);
  202. statTree[i]->statResult.push_back(info_temp);
  203. //gain
  204. double gain_temp = infoD - info_temp;
  205. statTree[i]->statResult.push_back(gain_temp);
  206. //split_info
  207. statTree[i]->statResult.push_back(splitInfo);
  208. //gain_info
  209. double temp = gain_temp / splitInfo;
  210. statTree[i]->statResult.push_back(temp);
  211. //得到最大值*/
  212. if (temp > max_temp)
  213. {
  214. max_temp = temp;
  215. max_attribute = i;
  216. }
  217. }
  218. }
  219. return max_attribute;
  220. }
  221. /*
  222. * @function resetStatTree 清理状态树
  223. * @param: statTree 状态树
  224. * @param: readClumNum 需要清理的属性set
  225. * @return void
  226. */
  227. void DecisionTree::resetStatTree(vector<attributes*>& statTree, vector<int>& readClumNum)
  228. {
  229. for (int i = 0; i < readClumNum.size() - 1; i++)
  230. {
  231. if (readClumNum[i] == 0)
  232. {
  233. map<string, attrItem*>::iterator it_end = statTree[i]->attriItem.end();
  234. for (map<string, attrItem*>::iterator it = statTree[i]->attriItem.begin();
  235. it != it_end; it++)
  236. {
  237. delete it->second;
  238. }
  239. statTree[i]->attriItem.clear();
  240. statTree[i]->statResult.clear();
  241. }
  242. }
  243. }
  244. int main(int argc, char* argv[]) {
  245. string filename = "tree.txt";
  246. DecisionTree dt;
  247. int attr_node = 0;
  248. TreeNode* treeHead = nullptr;
  249. set<int> readLineNum;
  250. vector<int> readClumNum;
  251. int deep = 0;
  252. if (dt.pretreatment(filename, readLineNum, readClumNum) == 0)
  253. {
  254. dt.CreatTree(treeHead, dt.getStatTree(), dt.getInfos(), readLineNum, readClumNum, deep);
  255. }
  256. return 0;
  257. }

第三步:运行过程

运行结果

开发环境使用VS2010及以上版本均可,无需额外配置,也没有调用第三方库;只要会用Visual Studio进行C++开发,就能运行本程序。

由于此程序不调用任何外部库,读者可以看清楚算法每一步的原理。要想学好机器学习算法,必须打好基础,不要好高骛远。另外,程序都带有注释,应该很好理解;实在不懂的,可以来问店主。

代码的下载路径(新窗口打开链接)机器学习算法决策树C4.5之c++实现(不调用外源库)

有问题可以私信或者留言,有问必答

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/611958
推荐阅读
相关标签
  

闽ICP备14008679号