赞
踩
摘要:
介绍了K近邻算法,记录了MindSporeAI框架使用部分wine数据集进行KNN实验的步聚和方法。包括环境准备、下载红酒数据集、加载数据和预处理、搭建模型、进行预测等。
用于分类和回归的非参数统计方法
Cover、Hart于1968年提出
机器学习最基础的算法之一。
确定样本类别
计算样本与所有训练样本的距离
找出最接近的k个样本
统计样本类别
投票
结果就是票数最多的类。
三个基本要素:
K值,样本分类由K个邻居的“多数表决”确定
K值太小容易产生噪声
K值太大类别界限模糊
距离度量,特征空间中两个样本间的相似度
距离越小越相似
Lp距离(p=2时,即为欧式距离)
曼哈顿距离
海明距离
分类决策规则
多数表决
基于距离加权的多数表决(权值与距离成反比)
(1)找出距离目标样本x_test最近的k个训练样本,保存至集合N中;
(2)统计集合N中各类样本个数 Ci,i=1,2,3,...,c;
(3)最终分类结果为Ci最大的那个类(argmaxCi)。
k取值重要。
根据问题和数据特点来确定。
带权重的k近邻算法
每个样本有不同的投票权重
回归预测输出为所有邻居的标签均值:
yi为k个目标邻居样本的标签值
带样本权重的回归预测函数:
ωi为第个i样本的权重
常用欧氏距离(欧几里得距离)
空间中两点x和y之间的欧氏距离公式:
注意将特征向量的每个分量归一化
减少不同尺度的干扰
大数值特征分量会淹没小数值特征分量
其它距离
Mahalanobis距离
Bhattacharyya距离
- %%capture captured_output
- # 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
- !pip uninstall mindspore -y
- !pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
- # 查看当前 mindspore 版本
- !pip show mindspore
输出:
- Name: mindspore
- Version: 2.2.14
- Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
- Home-page: https://www.mindspore.cn
- Author: The MindSpore Authors
- Author-email: contact@mindspore.cn
- License: Apache 2.0
- Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
- Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
- Required-by: mindnlp
官网链接:UCI Machine Learning Repository
http://archive.ics.uci.edu/dataset/109/wine
数据内容:
意大利同一地区、三个不同品种葡萄酒化学分析结果。
包括每种葡萄酒中所含13种成分的量:
Alcohol | 酒精 |
Malic acid | 苹果酸 |
Ash | 灰 |
Alcalinity of ash | 灰的碱度 |
Magnesium | 镁 |
Total phenols | 总酚 |
Flavanoids | 类黄酮 |
Nonflavanoid phenols | 非黄酮酚 |
Proanthocyanins | 原花青素 |
Color intensity | 色彩强度 |
Hue | 色调 |
OD280/OD315 of diluted wines | 稀释酒的OD280/OD315 |
Proline | 脯氨酸 |
方式一,从Wine数据集官网下载wine.data文件。
方式二,从华为云OBS中下载wine.data文件。
Key | Value | Key | Value |
Data Set Characteristics | Multivariate | Number of Instances | 178 |
Attribute Characteristics | Integer, Real | Number of Attributes | 13 |
Associated Tasks | Classification | Missing Values? | No |
- from download import download
-
- # 下载红酒数据集
- url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MachineLearning/wine.zip"
- path = download(url, "./", kind="zip", replace=True)
输出:
- Downloading data from https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MachineLearning/wine.zip (4 kB)
-
- file_sizes: 100%|██████████████████████████| 4.09k/4.09k [00:00<00:00, 2.35MB/s]
- Extracting zip file...
- Successfully downloaded / unzipped to ./
导入os、numpy、MindSpore、matplotlib等模块
用context.set_context()配置运行模式、后端信息、硬件等
读取Wine数据集wine.data
查看部分数据。
- %matplotlib inline
- import os
- import csv
- import numpy as np
- import matplotlib.pyplot as plt
-
- import mindspore as ms
- from mindspore import nn, ops
-
- ms.set_context(device_target="CPU")
-
- with open('wine.data') as csv_file:
- data = list(csv.reader(csv_file, delimiter=','))
- print(data[56:62]+data[130:133])
输出:
- [['1', '14.22', '1.7', '2.3', '16.3', '118', '3.2', '3', '.26', '2.03', '6.38', '.94', '3.31', '970'],
- ['1', '13.29', '1.97', '2.68', '16.8', '102', '3', '3.23', '.31', '1.66', '6', '1.07', '2.84', '1270'],
- ['1', '13.72', '1.43', '2.5', '16.7', '108', '3.4', '3.67', '.19', '2.04', '6.8', '.89', '2.87', '1285'],
- ['2', '12.37', '.94', '1.36', '10.6', '88', '1.98', '.57', '.28', '.42', '1.95', '1.05', '1.82', '520'],
- ['2', '12.33', '1.1', '2.28', '16', '101', '2.05', '1.09', '.63', '.41', '3.27', '1.25', '1.67', '680'],
- ['2', '12.64', '1.36', '2.02', '16.8', '100', '2.02', '1.41', '.53', '.62', '5.75', '.98', '1.59', '450'],
- ['3', '12.86', '1.35', '2.32', '18', '122', '1.51', '1.25', '.21', '.94', '4.1', '.76', '1.29', '630'],
- ['3', '12.88', '2.99', '2.4', '20', '104', '1.3', '1.22', '.24', '.83', '5.4', '.74', '1.42', '530'],
- ['3', '12.81', '2.31', '2.4', '24', '98', '1.15', '1.09', '.27', '.83', '5.7', '.66', '1.36', '560']]
三类样本(共178条)
自变量X为数据集的13个属性
因变量Y为数据集的3个类别
取样本的某两个属性进行2维可视化
可以看到在某两个属性上样本的分布情况以及可分性。
- X = np.array([[float(x) for x in s[1:]] for s in data[:178]], np.float32)
- Y = np.array([s[0] for s in data[:178]], np.int32)
- attrs = ['Alcohol', 'Malic acid', 'Ash', 'Alcalinity of ash', 'Magnesium', 'Total phenols',
- 'Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins', 'Color intensity', 'Hue',
- 'OD280/OD315 of diluted wines', 'Proline']
- plt.figure(figsize=(10, 8))
- for i in range(0, 4):
- plt.subplot(2, 2, i+1)
- a1, a2 = 2 * i, 2 * i + 1
- plt.scatter(X[:59, a1], X[:59, a2], label='1')
- plt.scatter(X[59:130, a1], X[59:130, a2], label='2')
- plt.scatter(X[130:, a1], X[130:, a2], label='3')
- plt.xlabel(attrs[a1])
- plt.ylabel(attrs[a2])
- plt.legend()
- plt.show()
将数据集按128:50划分为训练集(已知类别样本)和验证集(待验证样本):
- train_idx = np.random.choice(178, 128, replace=False)
- test_idx = np.array(list(set(range(178)) - set(train_idx)))
- X_train, Y_train = X[train_idx], Y[train_idx]
- X_test, Y_test = X[test_idx], Y[test_idx]
MindSpore算子
tile
square
ReduceSum
sqrt
TopK
矩阵运算并行计算
目标样本x和已分类训练样本X_train的距离
top k近邻
- class KnnNet(nn.Cell):
- def __init__(self, k):
- super(KnnNet, self).__init__()
- self.k = k
-
- def construct(self, x, X_train):
- #平铺输入x以匹配X_train中的样本数
- x_tile = ops.tile(x, (128, 1))
- square_diff = ops.square(x_tile - X_train)
- square_dist = ops.sum(square_diff, 1)
- dist = ops.sqrt(square_dist)
- #-dist表示值越大,样本就越接近
- values, indices = ops.topk(-dist, self.k)
- return indices
-
- def knn(knn_net, x, X_train, Y_train):
- x, X_train = ms.Tensor(x), ms.Tensor(X_train)
- indices = knn_net(x, X_train)
- topk_cls = [0]*len(indices.asnumpy())
- for idx in indices.asnumpy():
- topk_cls[Y_train[idx]] += 1
- cls = np.argmax(topk_cls)
- return cls
验证KNN算法
k=5
验证精度接近80%
- acc = 0
- knn_net = KnnNet(5)
- for x, y in zip(X_test, Y_test):
- pred = knn(knn_net, x, X_train, Y_train)
- acc += (pred == y)
- print('label: %d, prediction: %s' % (y, pred))
- print('Validation accuracy is %f' % (acc/len(Y_test)))
输出:
- label: 1, prediction: 1
- label: 3, prediction: 3
- label: 3, prediction: 3
- label: 3, prediction: 3
- label: 3, prediction: 3
- label: 3, prediction: 3
- label: 1, prediction: 1
- label: 3, prediction: 1
- label: 1, prediction: 1
- label: 1, prediction: 2
- label: 3, prediction: 3
- label: 1, prediction: 1
- label: 3, prediction: 3
- label: 1, prediction: 1
- label: 1, prediction: 1
- label: 3, prediction: 2
- label: 1, prediction: 1
- label: 3, prediction: 3
- label: 1, prediction: 1
- label: 1, prediction: 3
- label: 1, prediction: 1
- label: 1, prediction: 1
- label: 1, prediction: 3
- label: 1, prediction: 1
- label: 3, prediction: 2
- label: 1, prediction: 1
- label: 3, prediction: 2
- label: 3, prediction: 2
- label: 1, prediction: 1
- label: 3, prediction: 1
- label: 3, prediction: 1
- label: 1, prediction: 1
- label: 2, prediction: 3
- label: 2, prediction: 2
- label: 2, prediction: 2
- label: 2, prediction: 2
- label: 2, prediction: 2
- label: 2, prediction: 2
- label: 2, prediction: 3
- label: 2, prediction: 2
- label: 2, prediction: 3
- label: 2, prediction: 2
- label: 2, prediction: 2
- label: 2, prediction: 2
- label: 2, prediction: 3
- label: 2, prediction: 2
- label: 2, prediction: 2
- label: 2, prediction: 2
- label: 2, prediction: 2
- label: 2, prediction: 2
- Validation accuracy is 0.720000
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。