当前位置:   article > 正文

LDA代码实现

lda代码

机器学习线性分类器的LDA分类器的代码实现,由于要进行可视化的展现,所以这里我使用二变量进行二分类的任务,由于书上几乎都会给出它的推导过程所以在这里就省略掉它的过程推导,从使用层面上进行描述和编程语言进行实现。

# my LDA funciton
import numpy as np
import matplotlib.pyplot as plt
# 这个仅限于二维的方式
def my_LDA(X,Y):
    x_0 = X[np.where(Y == 0)[0],:]
    #y_0 = Y[np.where(Y == 0),:]
    x_1 = X[np.where(Y == 1)[0],:]
    #y_1 = Y[np.where(Y == 1),:]
    x0 = x_0.mean(axis = 0)
    x1 = x_1.mean(axis = 0)
    SW = np.zeros([x_0.shape[1],x_0.shape[1]])
    for i in range(x_0.shape[0]):
        for j in range(x_0.shape[1]):
            for k in range(x_0.shape[1]):
                SW[j,k] = SW[j,k] + (x_0[i,j]-x0[j])*(x_0[i,k]-x0[k]) 
    
    for i in range(x_1.shape[0]):
        for j in range(x_0.shape[1]):
            for k in range(x_0.shape[1]):
                SW[j,k] = SW[j,k] + (x_1[i,j]-x1[j])*(x_1[i,k]-x1[k])
    #返回的值是超平面的垂直方向和 过垂直面的一个点
    return np.dot(np.linalg.inv(SW), (x0-x1)), 1/2*(x0+x1)

x = np.random.random([30,2])
x1 = x + 0.5*np.random.random([30,2])+0.1
x0 = x - 0.5*np.random.random([30,2])-0.1

X = np.concatenate((x1, x0),axis = 0)
Y = np.zeros([60,1],dtype= np.int8)
Y[0:30,0] = 1



w , zuobiao = my_LDA(X,Y)
k = -w[0]/w[1]
b = zuobiao[1] - zuobiao[0]*k
plt.Figure(figsize = (10,8))
plt.plot(x1[:,0],x1[:,1],'r*',label = 'positive')
plt.plot(x0[:,0],x0[:,1],'b+',label  = 'negatiive')
plt.plot([0,b],[1.5,1.5*k+b],color = 'black',linestyle = '--',label = r'超平面')
plt.legend(['positive','negatiive', r'超平面'])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

结果展示:
在这里插入图片描述
通过画图可以看出,我们的分界面可以很好的区分这两类问题,简单来说就是二分类的问题。易知有2个红色的点被分错,1个蓝色的点被分错,其他的点都是分类正确的,而我们总体使用的样本为60个,所以该分类器的分类准确率为57/60 = 95%,可以认为这是一个比较好的结果。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/719867
推荐阅读
相关标签
  

闽ICP备14008679号