当前位置:   article > 正文

python实现ID3算法对鸢尾花数据集分类_机器学习id3算法鸢尾花

机器学习id3算法鸢尾花

        理论部分大家可以自行学习。这里的代码是根据算法逻辑手动编写的,没有使用 sklearn。

  1. import numpy as np
  2. import pandas as pd
  3. from collections import Counter
  4. class DecisionTree:
  5. class Node:
  6. def __init__(self, feature_index=None, threshold=None, value=None, left=None, right=None):
  7. self.feature_index = feature_index # 特征索引
  8. self.threshold = threshold # 分割阈值
  9. self.value = value # 叶子节点预测值
  10. self.left = left # 左子树
  11. self.right = right # 右子树
  12. def __init__(self, max_depth=None, min_samples_split=2):
  13. self.max_depth = max_depth # 最大深度
  14. self.min_samples_split = min_samples_split # 分割的最小样本数
  15. self.root = None # 根节点
  16. def fit(self, X, y):
  17. self.root = self._build_tree(X, y, depth=0)
  18. def _build_tree(self, X, y, depth):
  19. n_samples, n_features = X.shape
  20. n_classes = len(np.unique(y))
  21. if depth == self.max_depth or n_samples < self.min_samples_split or n_classes == 1: # 满足停止条件
  22. value = self._most_common_label(y)
  23. return self.Node(value=value)
  24. best_feature_index, best_threshold = self._find_best_split(X, y)
  25. if best_feature_index is None or best_threshold is None:# 满足停止条件
  26. value = self._most_common_label(y)
  27. return self.Node(value=value)
  28. left_indices = X[:, best_feature_index] < best_threshold
  29. right_indices = ~left_indices
  30. left_branch = self._build_tree(X[left_indices], y[left_indices], depth+1)# 递归构建子树
  31. right_branch = self._build_tree(X[right_indices], y[right_indices], depth+1)
  32. return self.Node(feature_index=best_feature_index, threshold=best_threshold,
  33. left=left_branch, right=right_branch)
  34. def _find_best_split(self, X, y):
  35. n_samples, n_features = X.shape
  36. best_info_gain = -1
  37. best_feature_index = None
  38. best_threshold = None
  39. entropy_parent = self._entropy(y)# 计算父节点的熵
  40. for feature_index in range(n_features):
  41. unique_values = np.unique(X[:, feature_index])
  42. for threshold in unique_values:
  43. left_indices = X[:, feature_index] < threshold
  44. right_indices = ~left_indices
  45. entropy_left = self._entropy(y[left_indices]) # 计算子节点的熵和信息增益
  46. entropy_right = self._entropy(y[right_indices])
  47. info_gain = self._information_gain(entropy_parent, y[left_indices], y[right_indices])
  48. if info_gain > best_info_gain:# 选择信息增益最大的分割点
  49. best_info_gain = info_gain
  50. best_feature_index = feature_index
  51. best_threshold = threshold
  52. return best_feature_index, best_threshold
  53. def _entropy(self, y): #计算熵
  54. counter = Counter(y)
  55. probabilities = [count / len(y) for count in counter.values()]
  56. entropy = -sum(p * np.log2(p) for p in probabilities)
  57. return entropy
  58. def _information_gain(self, entropy_parent, y_left, y_right):
  59. n_total = len(y_left) + len(y_right)
  60. p_left, p_right = len(y_left) / n_total, len(y_right) / n_total
  61. info_gain = entropy_parent - (p_left * self._entropy(y_left) + p_right * self._entropy(y_right))
  62. return info_gain #计算信息增益
  63. def _most_common_label(self, y):
  64. counter = Counter(y)
  65. most_common = counter.most_common(1)
  66. return most_common[0][0]
  67. def predict(self, X):
  68. return [self._traverse_tree(x, self.root) for x in X]
  69. def _traverse_tree(self, x, node): #进行分类
  70. if node.value is not None:
  71. return node.value
  72. if x[node.feature_index] < node.threshold:
  73. return self._traverse_tree(x, node.left)
  74. else:
  75. return self._traverse_tree(x, node.right)
  76. data = pd.read_excel("C:/Users/wxc/Desktop/xuexi/python/pythonProject/机器学习/决策树/train.xlsx")
  77. x_train = np.array(data.iloc[:, 1:5])
  78. y_train = np.array(data.iloc[:, 6])
  79. tree = DecisionTree(max_depth=4, min_samples_split=1)
  80. tree.fit(x_train, y_train)
  81. data1 = pd.read_excel("C:/Users/wxc/Desktop/xuexi/python/pythonProject/机器学习/决策树/test.xlsx",header=None) # 新样本特征
  82. x_test = np.array(data1.iloc[:, 1:5])
  83. y_test = np.array(data1.iloc[:, 6])
  84. predictions = tree.predict(x_test)
  85. print("预测值为:", predictions)
  86. c = 0
  87. for i in range(len(y_test)):
  88. if y_test[i] == predictions[i]:
  89. c= c+1
  90. print('准确率')
  91. print(c/(len(y_test)))

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/701805
推荐阅读
相关标签
  

闽ICP备14008679号