
MNIST & CatBoost: Saving a Model and Making Predictions


Installation

pip install catboost

Dataset

MNIST classification (60,000 rows, 784 features), uploaded to CSDN

Code

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
train = pd.read_csv('./input/mnist/train.csv')
train.head()

X = train.iloc[:, 1:]  # pixel features (every column after the label)
y = train['label']  # labels
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)  # split into training and test sets
def plot_digits(instances, images_per_row=10):
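    '''Plot a batch of digits as a single image grid.
    
    :param instances: subset of the dataset, one flattened 28x28 image per row
    :type instances: numpy.ndarray
    :param images_per_row: number of images shown in each row
    '''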
    size = 28
    images_per_row = min(len(instances), images_per_row)
    images = [instance.reshape(size, size) for instance in instances]
    n_rows = (len(instances) - 1) // images_per_row + 1
    row_images = []
    n_empty = n_rows * images_per_row - len(instances)
    images.append(np.zeros((size, size * n_empty)))
    for row in range(n_rows):
        rimages = images[row * images_per_row: (row + 1) * images_per_row]
        row_images.append(np.concatenate(rimages, axis=1))
    image = np.concatenate(row_images, axis=0)
    plt.imshow(image, cmap='gray_r')
    plt.axis("off")
    
plt.figure()
plot_digits(X_train[:100].values, images_per_row=10)
plt.show()
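
The grid arithmetic inside plot_digits can be checked in isolation. The sketch below is only an illustration: grid_layout is a hypothetical helper, not part of the original code, and it repeats the same ceiling-division and padding calculation without drawing anything.

```python
def grid_layout(n_images, images_per_row=10):
    """Return (n_rows, n_empty) for tiling n_images into a grid.

    Hypothetical helper mirroring the arithmetic used in plot_digits.
    """
    images_per_row = min(n_images, images_per_row)
    n_rows = (n_images - 1) // images_per_row + 1  # ceiling division
    n_empty = n_rows * images_per_row - n_images   # blank cells padding the last row
    return n_rows, n_empty


print(grid_layout(100))  # 100 images, 10 per row
print(grid_layout(25))   # last row is only half full
```

With 100 images and 10 per row the grid is exactly full, so the zero-width padding block appended in plot_digits is never reached by the row slices.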

# Define the model
clf = CatBoostClassifier()
# Train
model = clf.fit(X_train, y_train)
0:	learn: 2.2139620	total: 975ms	remaining: 16m 13s
1:	learn: 2.1344069	total: 1.95s	remaining: 16m 15s
2:	learn: 2.0559619	total: 2.92s	remaining: 16m 10s
3:	learn: 1.9850790	total: 3.89s	remaining: 16m 7s
......
996:	learn: 0.1231917	total: 16m 35s	remaining: 3s
997:	learn: 0.1231500	total: 16m 36s	remaining: 2s
998:	learn: 0.1231068	total: 16m 37s	remaining: 999ms
999:	learn: 0.1230654	total: 16m 38s	remaining: 0us
# Evaluate
print('accuracy:', model.score(X_test, y_test))
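
For a classifier, score reports the fraction of test samples whose predicted label matches the true one. A single number can hide digits that are confused more often than others, so a per-class breakdown is sometimes useful. The sketch below uses made-up labels and two hypothetical helpers, accuracy and per_class_accuracy, purely to show the calculation.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of positions where the prediction equals the true label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def per_class_accuracy(y_true, y_pred):
    """Accuracy computed separately for each true label."""
    total = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {label: hits[label] / total[label] for label in sorted(total)}


# Made-up labels, for illustration only
y_true_demo = [0, 1, 1, 2, 2, 2]
y_pred_demo = [0, 1, 2, 2, 2, 0]

print('accuracy:', accuracy(y_true_demo, y_pred_demo))
print('per class:', per_class_accuracy(y_true_demo, y_pred_demo))
```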
# Save
model.save_model('mnist.model')
# Load
ccc = CatBoostClassifier()
ccc.load_model('mnist.model')
# Predict
index = random.randint(0, len(X_test) - 1)  # pick a random test sample (randint includes both endpoints)
_X = X_test.values[index]
_y = y_test.values[index]  # ground truth
predict = ccc.predict(_X)[0]  # predicted label

_X = _X.reshape(28, 28)
plt.imshow(_X, cmap='gray_r')
plt.title('original {}'.format(_y))
plt.show()

print('index:', index)
print('original:', _y)
print('predicted:', predict)

index: 7534
original: 6
predicted: 6

index: 6510
original: 4
predicted: 4

index: 7311
original: 6
predicted: 6

ipynb

Download link
