赞
踩
如果有一个未知类别的样本s,其属性向量为x,
一般的KNN使用的是与该向量最近的k个样本作为判断其类别的依据。也就是说,在最近的k个样本中,哪种类别数量最多,该未知样本就被判定为那个类别。这个过程涉及到了求距离和排序两个操作。如果有1亿个样本,就要求1亿个距离,然后对这1亿个跨度进行排序。显然数量大了,算法运行效率就低了。
基于此,这篇文章提出了一种全新的思路(可能别人已经想出来了),利用样本质心来判定类别。对于未知样本s, 属性向量为x。如果已知样本有1亿个,共有3个类别,那么在进行分类时,只要把x与3个类别的质心求3次距离并求一个最小值。这极大地提高了运行效率。下面介绍具体推导过程。
定义1 如果在n维空间中有m个点,分别为
P
1
,
P
2
,
.
.
.
,
P
m
P_1,P_2,...,P_m
P1,P2,...,Pm 那么这m个点的质心P定义为:
P
=
1
m
∑
i
=
1
m
P
i
P=\frac{1}{m}\sum^{m}_{i=1}P_i
P=m1∑i=1mPi
定理1 如果已知m个点的质心为
P
P
P,再加入一个点
P
m
+
1
P_{m+1}
Pm+1那么它的质心
P
P
P将产生如下变化:
P
′
=
m
m
+
1
P
+
1
m
+
1
P
m
+
1
P'=\frac{m}{m+1}P+\frac{1}{m+1}P_{m+1}
P′=m+1mP+m+11Pm+1
证
:
定义2 如果样本分为c类,每一类可以计算一个质心,那么样本将有
c
c
c个质心,这些质心称为样本类别的中心
。
据此,可以提出使用新的思路进行样本分类的算法:
算法1 样本预处理算法
for each class cls:
gravity_sum = zero vector
for each sample of cls:
gravity_sum += sample's attribute vector
gravity_center = gravity_sum / count(cls)
算法2 增强学习算法
smp = the sample want to be add
attr_vec = the sample's attribute vector
cls = the sample's class
p = gravity_center of cls
n = amounts of cls
p = n / (n + 1) * p + 1 / (n + 1) * attr_vec
算法3 分类算法
attr_vec = the unknown sample's attribute vector
dis = {}
for each gravity_center:
d = (gracity_center - attr_vec).norm()
put d into dis
cls = argmin dis // this is the sample's class that was predicted
以iris数据集为例,验证一下这个算法。
数据如下,一共有3个类别:
经过多次运行,在预测集上的准确率保持在90%左右。
以下是代码
import numpy as np import pandas as pd data = pd.read_csv("iris.csv", header=None) data = data.sample(frac=1) # train classes = {} for i in range(100): cls = data.iloc[i,4]; x = data.iloc[i, 0:4].to_numpy(); if cls in classes: n = classes[cls]["n"]; center = classes[cls]["center"]; center = (n / (n+1)) * center + (1/ (n+1)) * x; classes[cls] = {"n": n + 1, "center": center}; else: classes[cls] = {"n": 1, "center": x}; # predict count = 0; for i in range(101, 150): cls = data.iloc[i,4]; x = data.iloc[i, 0:4].to_numpy(); max_distance = -1; max_likely_cls = ""; for key, value in classes.items(): distance = np.linalg.norm(value["center"] - x); if max_distance < 0 or distance < max_distance: max_distance = distance; max_likely_cls = key; if(max_likely_cls == cls): count = count + 1; print("正确率:", "%.2f" % (count * 100 / 50, ));
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。