当前位置:   article > 正文

python机器学习 基于决策树的MNIST数字分类 详细教程 数据集+源码+远程部署

基于决策树的mnist数字分类

数据集+源码:python机器学习 基于决策树的MNIST数字分类 详细教程 数据集+源码+远程部署

目录

​编辑

介绍

加载数据

数据处理

配置流水线

数据集训练集划分(8:2)

模型

模型训练

训练集测试集验证

网格搜索


介绍

决策树是一种非常受欢迎的机器学习算法,它可以用于分类和回归任务。在基于决策树的MNIST数字分类中,算法会学习如何从手写数字的图像像素值中提取特征,并根据这些特征来决定图像表示的数字(0到9)。

MNIST数据集是一个包含了手写数字的大型数据库,常用于训练各种图像处理系统。数据集包含60000个训练样本和10000个测试样本。每个样本是一个28x28像素的灰度图像。

要使用决策树进行MNIST数字分类,可以按照以下步骤:

  1. 数据预处理:加载MNIST数据集,并将图像的28x28像素矩阵平展成一个长度为784的一维数组。这样每个图像就变成了一个特征向量。

  2. 创建决策树模型:选择一种决策树算法,例如CART(分类和回归树)或者ID3等,然后使用训练数据集来训练模型。

  3. 训练模型:使用训练集中的图像和对应的标签来训练决策树模型。决策树将会学习如何根据像素值的特征进行分类。

  4. 评估模型:在测试数据集上评估决策树模型的性能。可以计算准确率(accuracy)、精确度(precision)、召回率(recall)和F1分数等指标来衡量模型的分类能力。

  5. 优化模型:可能需要调整决策树的一些参数,比如树的深度、最小分裂所需的样本数等,来优化模型的性能。

  6. 使用模型:一旦模型经过足够的训练并且评估指标令人满意,就可以用它来对新的手写数字图像进行分类。

虽然决策树模型相对简单并且易于理解,它们在处理像MNIST这样的高维数据时通常不如深度学习模型效果好。如果需要高精度的图像分类,可能会考虑使用卷积神经网络(CNN)等更复杂的算法。

[[  0.   0.   5.  13.   9.   1.   0.   0.]
 [  0.   0.  13.  15.  10.  15.   5.   0.]
 [  0.   3.  15.   2.   0.  11.   8.   0.]
 [  0.   4.  12.   0.   0.   8.   8.   0.]
 [  0.   5.   8.   0.   0.   9.   8.   0.]
 [  0.   4.  11.   0.   1.  12.   7.   0.]
 [  0.   2.  14.   5.  10.  12.   0.   0.]
 [  0.   0.   6.  13.  10.   0.   0.   0.]]
 

加载数据

  1. X, y = sklearn.datasets.load_digits(return_X_y=True)
  2. X = pd.DataFrame(X)
  3. y = pd.DataFrame(y)
  4. print(X.info)
  5. print(y.info)

数据处理

  1. cate_cols = [] # 离散特征
  2. num_cols = [] # 数值型特征
  3. # 获取各个特征的数据类型
  4. dtypes = X.dtypes
  5. for col, dtype in dtypes.items():
  6. if dtype == 'object':
  7. cate_cols.append(col)
  8. else:
  9. num_cols.append(col)

数值型特征

  1. class Num_Encoder(BaseEstimator, TransformerMixin):
  2. def __init__(self, cols=[], fillna=False, addna=False):
  3. self.fillna = fillna
  4. self.cols = cols
  5. self.addna = addna
  6. self.na_cols = []
  7. self.imputers = {}
  8. def fit(self, X, y=None):
  9. for col in self.cols:
  10. if self.fillna:
  11. self.imputers[col] = X[col].median()
  12. if self.addna and X[col].isnull().sum():
  13. self.na_cols.append(col)
  14. print(self.na_cols, self.imputers)
  15. return self
  16. def transform(self, X, y=None):
  17. df = X.loc[:, self.cols]
  18. for col in self.imputers:
  19. df[col].fillna(self.imputers[col], inplace=True)
  20. for col in self.na_cols:
  21. df[col + '_na'] = pd.isnull(df[col])
  22. return df

离散型特征

  1. class Cat_Encoder(BaseEstimator, TransformerMixin):
  2. def __init__(self, cols, max_n_cat=7, onehot_cols=[], orders={}):
  3. self.cols = cols
  4. self.onehot_cols = onehot_cols
  5. self.cats = {}
  6. self.max_n_cat = max_n_cat
  7. self.orders = orders
  8. def fit(self, X, y=None):
  9. df_cat = X.loc[:, self.cols]
  10. for n, c in df_cat.items():
  11. df_cat[n].fillna('NAN', inplace=True)
  12. df_cat[n] = c.astype('category').cat.as_ordered()
  13. if n in self.orders:
  14. df_cat[n].cat.set_categories(self.orders[n], ordered=True, inplace=True)
  15. cats_count = len(df_cat[n].cat.categories)
  16. if cats_count <= 2 or cats_count > self.max_n_cat:
  17. self.cats[n] = df_cat[n].cat.categories
  18. if n in self.onehot_cols:
  19. self.onehot_cols.remove(n)
  20. elif n not in self.onehot_cols:
  21. self.onehot_cols.append(n)
  22. print(self.onehot_cols)
  23. return self
  24. def transform(self, df, y=None):
  25. X = df.loc[:, self.cols]
  26. for col in self.cats:
  27. X[col].fillna('NAN', inplace=True)
  28. X.loc[:, col] = pd.Categorical(X[col], categories=self.cats[col], ordered=True)
  29. X.loc[:, col] = X[col].cat.codes
  30. if len(self.onehot_cols):
  31. df_1h = pd.get_dummies(X[self.onehot_cols], dummy_na=True)
  32. df_drop = X.drop(self.onehot_cols, axis=1)
  33. return pd.concat([df_drop, df_1h], axis=1)
  34. return X

配置流水线

  1. num_pipeline = Pipeline([
  2. ('num_encoder', Num_Encoder(cols=num_cols, fillna='median', addna=True)),
  3. ])
  4. X_num = num_pipeline.fit_transform(X)
  5. cat_pipeline = Pipeline([
  6. ('cat_encoder', Cat_Encoder(cols=cate_cols))
  7. ])
  8. X_cate = cat_pipeline.fit_transform(X)

数据集训练集划分(8:2)

  1. X = pd.concat([X_num, X_cate], axis=1)
  2. print(X.shape, y.shape)
  3. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=2022)
  4. print('【训练集】', X_train.shape, y_train.shape)
  5. print('【测试集】', X_test.shape, y_test.shape)

模型

  1. dtc = DecisionTreeClassifier(max_depth=5,
  2. criterion='gini',
  3. random_state=2022)

模型训练

dtc.fit(X_train, y_train)

训练集测试集验证

 

  1. y_train_pred = dtc.predict(X_train)
  2. y_test_pred = dtc.predict(X_test)
  3. accuracy_train = metrics.accuracy_score(y_train, y_train_pred)
  4. accuracy_test = metrics.accuracy_score(y_test, y_test_pred)
  5. print('训练集的accuracy: ', accuracy_train)
  6. print('测试集的accuracy: ', accuracy_test)

网格搜索

  1. dtcmodel = DecisionTreeClassifier()
  2. param_grid_dtc = {
  3. 'criterion': ['gini', 'entropy'],
  4. 'splitter': ['best', 'random'],
  5. 'max_depth': range(2, 10, 2),
  6. 'min_samples_split': range(1, 5, 1),
  7. 'max_features': ['auto', 'sqrt', 'log2']
  8. }
  9. dtcmodel_grid = GridSearchCV(estimator=dtcmodel,
  10. param_grid=param_grid_dtc,
  11. verbose=1,
  12. n_jobs=-1,
  13. cv=2)

  1. dtcmodel_grid.fit(X_train, y_train)
  2. print('【DTC】', dtcmodel_grid.best_score_)

  1. best_modeldtc = DecisionTreeClassifier(criterion=dtcmodel_grid.best_estimator_.get_params()['criterion'],
  2. splitter=dtcmodel_grid.best_estimator_.get_params()['splitter'],
  3. max_depth=dtcmodel_grid.best_estimator_.get_params()['max_depth'],
  4. min_samples_split=dtcmodel_grid.best_estimator_.get_params()['min_samples_split'],
  5. max_features=dtcmodel_grid.best_estimator_.get_params()['max_features'],)
best_modeldtc.fit(X_train, y_train)

  1. y_train_pred = best_modeldtc.predict(X_train)
  2. y_test_pred = best_modeldtc.predict(X_test)
  3. accuracy_train = metrics.accuracy_score(y_train, y_train_pred)
  4. accuracy_test = metrics.accuracy_score(y_test, y_test_pred)
  5. print('训练集的accuracy: ', accuracy_train)
  6. print('测试集的accuracy: ', accuracy_test)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小惠珠哦/article/detail/936817
推荐阅读
相关标签
  

闽ICP备14008679号