当前位置:   article > 正文

svm的代码实现_svm代码

svm代码

目录

代码

代码解读

导入pandas库,并且读取文件

可视化原始数据

使用SVM进行训练

可视化SVM结果

可视化原始数据,选取1维核3维的数据进行可视化

可视化超平面

进行坐标轴限制

找到支持向量[二维数组]可视化支持向量,并绘制图片


代码

  1. import pandas as pd
  2. data = pd.read_csv('iris.csv', header=None)
  3. import matplotlib.pyplot as plt
  4. data1 = data.iloc[:50, :]
  5. data2 = data.iloc[50:, :]
  6. plt.scatter(data1[1], data1[3], marker='+')
  7. plt.scatter(data2[1], data2[3], marker='o')
  8. plt.show()
  9. from sklearn.svm import SVC
  10. x = data.iloc[:, [1, 3]]
  11. y = data.iloc[:, -1]
  12. svc = SVC(kernel='linear', C=float('inf'), random_state=0)
  13. svc.fit(x, y)
  14. scores = svc.score
  15. print('score: ', scores)
  16. w = svc.coef_[0]
  17. b = svc.intercept_[0]
  18. import numpy as np
  19. x1 = np.linspace(0, 7, 300)
  20. x2 = -(w[0] * x1 + b) / w[1]
  21. x3 = (1 - (w[0] * x1 + b)) / w[1]
  22. x4 = (-1 - (w[0] * x1 + b)) / w[1]
  23. plt.scatter(data1[1], data1[3], marker='+', color='r')
  24. plt.scatter(data2[1], data2[3], marker='o', color='b')
  25. plt.plot(x1, x2, linewidth=2, color='r')
  26. plt.plot(x1, x3, linewidth=1, color='r', linestyle='--')
  27. plt.plot(x1, x4, linewidth=1, color='r', linestyle='--')
  28. plt.xlim(4, 7)
  29. plt.ylim(0, 5)
  30. vets = svc.support_vectors_
  31. plt.scatter(vets[:, 0], vets[:, 1], c='r', marker='x')
  32. plt.show()

代码解读

导入pandas库,并且读取文件
  1. import pandas as pd
  2. data = pd.read_csv('iris.csv', header=None)
可视化原始数据
  1. import matplotlib.pyplot as plt: # 这行代码导入了 matplotlib 库的 pyplot 模块,并将其别名为 plt。
  2. data1 = data.iloc[:50, :]: # 这行代码从 data DataFrame 中选取前 50 行所有列的数据,并将其存储为新的 DataFrame 对象,名为 data1。
  3. data2 = data.iloc[50:, :]: #这行代码从 data DataFrame 中选取第 50 行之后的所有行和所有列的数据,并将其存储为新的 DataFrame 对象,名为 data2。
  4. plt.scatter(data1[1], data1[3], marker='+'): #这行代码使用 matplotlib 的 scatter 函数绘制 data1 DataFrame 中的第 2 列(索引为 1)和第 4 列(索引为 3)的数据点,并使用加号作为标记。
  5. plt.scatter(data2[1], data2[3], marker='o'): #这行代码使用 matplotlib 的 scatter 函数绘制 data2 DataFrame 中的第 2 列(索引为 1)和第 4 列(索引为 3)的数据点,并使用圆圈作为标记。
  6. plt.show(): #这行代码使用 matplotlib 的 show 函数显示之前绘制的图形。

图片结果为:

使用SVM进行训练
  1. from sklearn.svm import SVC: #这行代码从 sklearn 库中导入了 SVC 类,它是用于支持向量机(Support Vector Machine)的类。
  2. x = data.iloc[:, [1, 3]]: #这行代码从 data DataFrame 中选取所有行的第 2 列和第 4 列的数据,并将其存储为新的 DataFrame 对象,名为 x。
  3. y = data.iloc[:, -1]: #这行代码从 data DataFrame 中选取所有行的最后一列的数据,并将其存储为新的 Series 对象,名为 y。
  4. svc = SVC(kernel='linear', C=float('inf'), random_state=0): #这行代码创建了一个 SVC 对象,名为 svc。其中,kernel 参数设置为 'linear',表示使用线性核函数;C 参数设置为无穷大,表示对分类错误的惩罚非常大;random_state 参数设置为 0,表示随机种子为 0。
  5. svc.fit(x, y): #这行代码使用 SVC 的 fit 函数对 x 和 y 进行训练。
可视化SVM结果
  1. w = svm.coef_[0] # 参数w[原始数据为二维数组]
  2. b = svm.intercept_[0] # 偏置项b[原始数据为一维数组]
  3. # 超平面方程:w1x1+w2x2+b=0
  4. # ->>x2 = -(w1x1+b)/w2
  5. import numpy as np
  6. x1 = np.linspace(0, 7, 300) # 在0~7之间产生300个数据
  7. x2 = -(w[0] * x1 + b) / w[1] # 超平面方程
  8. x3 = (1 - (w[0] * x1 + b)) / w[1] # 上超平面方程
  9. x4 = (-1 - (w[0] * x1 + b)) / w[1]# 下超平面方程
可视化原始数据,选取1维核3维的数据进行可视化
  1. plt.scatter(data1[1], data1[3], marker='+', color='r'): #这行代码使用 matplotlib 的 scatter 函数绘制 data1 DataFrame 中的第 2 列(索引为 1)和第 4 列(索引为 3)的数据点,并使用红色作为标记。
  2. plt.scatter(data2[1], data2[3], marker='o', color='b'): #这行代码使用 matplotlib 的 scatter 函数绘制 data2 DataFrame 中的第 2 列(索引为 1)和第 4 列(索引为 3)的数据点,并使用蓝色作为标记
可视化超平面
  1. plt.plot(x1, x2, linewidth=2, color='r'): #这行代码使用 matplotlib 的 plot 函数绘制从 x1 到 x2 的直线,并设置线条宽度为2,并使用红色作为线条颜色。
  2. plt.plot(x1, x3, linewidth=1, color='r', linestyle='--'): #这行代码使用 matplotlib 的 plot 函数绘制从 x1 到 x3 的直线,并设置线条宽度为1,使用红色作为线条颜色,并使用虚线作为线条样式。
  3. plt.plot(x1, x4, linewidth=1, color='r', linestyle='--'): #这行代码使用 matplotlib 的 plot 函数绘制从 x1 到 x4 的直线,并设置线条宽度为1,使用红色作为线条颜色,并使用虚线作为线条样式。
进行坐标轴限制
  1. plt.xlim(4, 7): #这行代码设置 x 轴的显示范围为 4 到 7。
  2. plt.ylim(0, 5): #这行代码设置 y 轴的显示范围为 0 到 5。
找到支持向量[二维数组]可视化支持向量,并绘制图片
  1. vets = svc.support_vectors_: #这行代码获取 svc 对象中支持向量的数据,并将其存储在变量 vets 中。
  2. plt.scatter(vets[:, 0], vets[:, 1], c='r', marker='x'): #这行代码使用 matplotlib 的 scatter 函数绘制 vets DataFrame 中的第一列和第二列的数据点,并使用红色作为标记和线条颜色。
  3. plt.show(): #这行代码使用 matplotlib 的 show 函数显示之前绘制的图形。

图片结果为:

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号