当前位置:   article > 正文

《统计学习方法》第7章 课后习题_统计学习导论第七章答案

统计学习导论第七章答案

这一章尤为复杂，我看了许多资料和博客；数学功底差总是吃亏的。

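7.1 比较感知机的对偶形式与线性可分支持向量机的对偶形式。

简答：相同点——两者都把 $w$ 表示为训练样本的线性组合 $w=\sum_{i=1}^{N}\alpha_i y_i x_i$，决策函数都是 $f(x)=\operatorname{sign}\left(\sum_{i=1}^{N}\alpha_i y_i (x_i\cdot x)+b\right)$，训练数据只以内积 $x_i\cdot x_j$（Gram 矩阵）的形式出现。不同点——感知机中 $\alpha_i=n_i\eta$（$n_i$ 为 $x_i$ 因误分类而被用于更新的次数），$b=\sum_{i=1}^{N}\alpha_i y_i$，由逐个误分类点迭代得到，解依赖初值和误分类点的选取顺序，并不唯一；线性可分支持向量机的 $\alpha$ 是凸二次规划 $\min_{\alpha}\frac{1}{2}\sum_{i=1}^{N}\sum_{j=1}^{N}\alpha_i\alpha_j y_i y_j (x_i\cdot x_j)-\sum_{i=1}^{N}\alpha_i$（约束 $\sum_{i=1}^{N}\alpha_i y_i=0$，$\alpha_i\ge 0$）的最优解，所得分离超平面唯一且几何间隔最大，只有支持向量对应的 $\alpha_i>0$。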

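7.2 对于正例点 $x_1=(1,2)^T, x_2=(2,3)^T, x_3=(3,3)^T$ 和负例点 $x_4=(2,1)^T, x_5=(3,2)^T$，可以根据线性可分支持向量机的对偶问题
$$\min_{\alpha}\ \frac{1}{2}\sum_{i=1}^{N}\sum_{j=1}^{N}\alpha_i\alpha_j y_i y_j (x_i\cdot x_j)-\sum_{i=1}^{N}\alpha_i,\qquad \text{s.t.}\ \sum_{i=1}^{N}\alpha_i y_i=0,\ \alpha_i\ge 0$$
求出 $\alpha$，再由 $w=\sum_{i}\alpha_i y_i x_i$ 和 $b=y_j-\sum_{i}\alpha_i y_i (x_i\cdot x_j)$（任取 $\alpha_j>0$）得到 $w,b$；这里我直接用自己写的程序求解。解得 $\alpha=(\tfrac12,0,2,0,\tfrac52)^T$，$w=(-1,2)^T$，$b=-2$，分离超平面为 $-x^{(1)}+2x^{(2)}-2=0$，分类决策函数为 $f(x)=\operatorname{sign}(-x^{(1)}+2x^{(2)}-2)$，支持向量为 $x_1,x_3,x_5$。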

参考资料：https://cuijiahua.com/blog/2017/11/ml_8_svm_1.html ，这篇文章讲解得很清楚。

转载注明:机器学习实战教程(八):支持向量机原理篇之手撕线性SVM | Jack Cui

自己写的程序问题较多，只能用来加深对算法流程的印象，性能并不好。

  1. #-*- coding:UTF-8 -*-
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  class SVM:
      """Support vector machine trained with SMO, following Li Hang's
      "Statistical Learning Methods", chapter 7.

      Attributes:
          X: training samples, N rows of M features.
          Y: training labels, each +1 or -1.
          C: penalty parameter; every alpha is kept inside [0, C].
          N, M: number of samples and number of features.
          alpha: Lagrange multipliers, numpy array of length N.
          b: bias of the decision function.
          model: kernel name, 'RBF' or 'liner' (alias 'linear').
          Gx: cached g(x_i) for every training sample.
          Ei: cached errors E_i = g(x_i) - y_i (a list, searched with min/max).
      """

      def __init__(self, datamat, labelmat, model='RBF'):
          """Store the training set and initialise alpha = 0, b = 0 and the caches.

          Args:
              datamat: N x M training samples (list of lists or array).
              labelmat: N labels, each +1 or -1.
              model: 'RBF' for the Gaussian kernel, 'liner' or 'linear' for the
                  linear kernel.

          Raises:
              ValueError: if ``model`` is not a supported kernel name.
          """
          if model not in ('RBF', 'liner', 'linear'):
              raise ValueError("unsupported model %r: use 'RBF' or 'liner'" % (model,))
          self.X = datamat
          self.Y = labelmat
          # C was chosen experimentally by the original author: for the RBF
          # kernel 30-33 gave the best results (0.06 error here, 0.04 for the
          # reference code). For the linear kernel a large C takes every x_i
          # into account and overfits, so a small value is used.
          self.C = 33.0 if model == 'RBF' else 0.02
          self.N, self.M = np.array(self.X).shape
          self.alpha = np.array([0.0] * self.N)
          self.b = 0.0
          self.model = model
          self.updateTrainParam()

      def K(self, xi, xj, e=0.1):
          """Return the kernel value K(xi, xj) of two samples (1-D sequences).

          'liner' / 'linear': the inner product xi . xj.
          'RBF': exp(-||xi - xj||^2 / e**2). The denominator is e**2, not the
          2 * sigma**2 of the book's Gaussian kernel.

          Raises:
              ValueError: if ``self.model`` is not a supported kernel name.
          """
          xi_arr = np.asarray(xi, dtype=float)
          xj_arr = np.asarray(xj, dtype=float)
          if self.model in ('liner', 'linear'):
              return float(np.dot(xi_arr, xj_arr))
          if self.model == 'RBF':
              diff = xi_arr - xj_arr
              return float(np.exp(-np.dot(diff, diff) / e ** 2))
          raise ValueError("unsupported model %r" % (self.model,))

      def computGx(self, i):
          """Return g(x_i) = sum_j alpha_j * y_j * K(x_j, x_i) + b."""
          gxi = 0.0
          for j in range(self.N):
              gxi += self.alpha[j] * self.Y[j] * self.K(self.X[j], self.X[i])
          return gxi + self.b

      def computyMg(self, i):
          """Return y_i * g(x_i), read from the Gx cache."""
          return self.Y[i] * self.Gx[i]

      def KKT(self, i, tol=1e-3):
          """Return True if alpha_i satisfies the KKT conditions within ``tol``.

          With yg = y_i * g(x_i) the conditions are
              alpha_i == 0      =>  yg >= 1
              0 < alpha_i < C   =>  yg == 1
              alpha_i == C      =>  yg <= 1
          yg is a float, so it is compared against 1 with tolerance ``tol``.
          """
          yg = self.computyMg(i)
          if self.alpha[i] <= 0:
              return yg >= 1 - tol
          if self.alpha[i] < self.C:
              return abs(yg - 1) <= tol
          return yg <= 1 + tol

      def computEi(self, i):
          """Return E_i = g(x_i) - y_i, read from the Gx cache."""
          return self.Gx[i] - self.Y[i]

      def updateTrainParam(self):
          """Recompute the Gx and Ei caches from scratch for every sample.

          __init__ calls this once. Afterwards train() keeps both caches up to
          date incrementally.
          """
          self.Gx = [self.computGx(i) for i in range(self.N)]
          self.Ei = [self.computEi(i) for i in range(self.N)]

      def train(self, max_itar, rate, tol=1e-3):
          """Optimise alpha and b with SMO.

          Each sweep first visits the non-bound alphas (0 < alpha < C), then
          the rest. Every alpha_i that violates the KKT conditions is optimised
          together with a partner alpha_j.

          Args:
              max_itar: maximum number of sweeps over the training set.
              rate: a pair update is rejected if alpha_j would change by less
                  than this amount.
              tol: tolerance of the KKT check.

          The sweep index is printed after every sweep. Training stops early
          once a full sweep changes nothing, because the next sweep would be
          identical.
          """
          for k in range(max_itar):
              non_bound = [i for i in range(self.N) if 0 < self.alpha[i] < self.C]
              seen = set(non_bound)
              order = non_bound + [i for i in range(self.N) if i not in seen]
              changed = 0
              for i in order:
                  if self.KKT(i, tol):
                      continue
                  if self._examine(i, order, rate):
                      changed += 1
              print(k)
              if changed == 0:
                  break

      def _examine(self, i, order, rate):
          """Optimise alpha_i together with some partner alpha_j.

          Returns True if a pair was updated. The partner is first chosen with
          the book's heuristic (maximise |E_i - E_j|). If that pair cannot make
          progress (for example, alpha_j is clipped at a bound), every other
          index in ``order`` is tried, so one unproductive pair does not stall
          training.
          """
          if self.Ei[i] > 0:
              first = self.Ei.index(min(self.Ei))
          else:
              first = self.Ei.index(max(self.Ei))
          candidates = [first] + [j for j in order if j != first]
          return any(self._take_step(i, j, rate) for j in candidates)

      def _take_step(self, i, j, rate):
          """Jointly optimise alpha_i and alpha_j, then update b and the caches.

          Returns True if the pair was updated. Returns False if it was skipped:
          i == j, non-positive eta, or alpha_j would move by less than ``rate``.
          """
          if i == j:
              return False
          xi, xj = self.X[i], self.X[j]
          yi, yj = self.Y[i], self.Y[j]
          a1old, a2old = self.alpha[i], self.alpha[j]
          kii = self.K(xi, xi)
          kjj = self.K(xj, xj)
          kij = self.K(xi, xj)
          # eta = K_ii + K_jj - 2*K_ij. It is 0 for duplicated samples, and the
          # update below divides by it.
          eta = kii + kjj - 2 * kij
          if eta <= 0:
              return False
          # Feasible range [L, H] for alpha_j: both alphas must stay in [0, C]
          # while y_i*alpha_i + y_j*alpha_j stays constant.
          if yi == yj:
              L = max(0.0, a2old + a1old - self.C)
              H = min(self.C, a2old + a1old)
          else:
              L = max(0.0, a2old - a1old)
              H = min(self.C, self.C + a2old - a1old)
          # Unconstrained optimum along that line, clipped to [L, H].
          a2new = a2old + yj * (self.Ei[i] - self.Ei[j]) / eta
          a2new = min(max(a2new, L), H)
          if abs(a2new - a2old) < rate:
              return False
          a1new = a1old + yi * yj * (a2old - a2new)
          # Take the bias from an updated alpha that lies strictly inside
          # (0, C). If both alphas sit at a bound, use the midpoint of the two
          # candidates.
          b1new = (-self.Ei[i] - yi * kii * (a1new - a1old)
                   - yj * kij * (a2new - a2old) + self.b)
          b2new = (-self.Ei[j] - yi * kij * (a1new - a1old)
                   - yj * kjj * (a2new - a2old) + self.b)
          if 0 < a1new < self.C:
              bnew = b1new
          elif 0 < a2new < self.C:
              bnew = b2new
          else:
              bnew = (b1new + b2new) / 2
          d1 = yi * (a1new - a1old)
          d2 = yj * (a2new - a2old)
          db = bnew - self.b
          self.alpha[i] = a1new
          self.alpha[j] = a2new
          self.b = bnew
          # g(x_k) depends on alpha_i, alpha_j and b for EVERY k, so every
          # cached entry must be refreshed, not only entries i and j.
          for k in range(self.N):
              xk = self.X[k]
              delta = d1 * self.K(xi, xk) + d2 * self.K(xj, xk) + db
              self.Gx[k] += delta
              self.Ei[k] += delta
          return True

      def computW(self):
          """Return w = sum_i alpha_i * y_i * x_i as a 1-D array of length M.

          The result is only a meaningful hyperplane normal for the linear
          kernel.
          """
          coef = self.alpha * np.asarray(self.Y, dtype=float)
          return coef.dot(np.asarray(self.X, dtype=float))

      def predict(self, X_test, Y_test):
          """Return the error rate of the trained model on (X_test, Y_test).

          Each sample is classified by sign(sum_j alpha_j*y_j*K(x, x_j) + b).
          The number of misclassified samples is divided by the number of
          test samples. An empty test set gives 0.0.
          """
          if len(X_test) == 0:
              return 0.0
          coef = np.asarray(self.Y, dtype=float) * self.alpha
          err_count = 0
          for xi, yi in zip(X_test, Y_test):
              kvec = np.array([self.K(xi, xj) for xj in self.X])
              fsign = np.sign(kvec.dot(coef) + self.b)
              if fsign != yi:
                  err_count += 1
          return err_count / len(X_test)
  def loadDataSet(filename):
      """Load a tab-separated data file with two feature columns and a label.

      Each non-blank line must look like ``x1<TAB>x2<TAB>label``. Blank lines,
      such as an empty last line, are skipped; otherwise float('') would raise
      ValueError.

      Returns:
          (datamat, labelmat): a list of [x1, x2] float pairs and a list of
          float labels, in file order.
      """
      datamat = []
      labelmat = []
      with open(filename) as fr:
          for line in fr:
              if not line.strip():
                  continue
              lineArr = line.strip().split('\t')
              datamat.append([float(lineArr[0]), float(lineArr[1])])
              labelmat.append(float(lineArr[2]))
      return datamat, labelmat
  def showDataSet(dataMat,labelMat):
      """Scatter-plot the samples as two series and show the figure.

      Samples whose label is > 0 form the first series; all others form the
      second. An empty series makes the column indexing below fail.
      """
      positives = [point for idx, point in enumerate(dataMat) if labelMat[idx] > 0]
      negatives = [point for idx, point in enumerate(dataMat) if not labelMat[idx] > 0]
      for group in (positives, negatives):
          columns = np.transpose(np.array(group))
          plt.scatter(columns[0], columns[1])
      plt.show()
  def showClassifer(dataMat, labelMat,w, b,alphas):
      """Plot the samples, the linear decision boundary and the support vectors.

      Args:
          dataMat: samples with two features each.
          labelMat: labels; samples with label > 0 form the first series.
          w: weight vector (w1, w2) of a linear SVM.
          b: bias, so the boundary is w1*x + w2*y + b = 0.
          alphas: Lagrange multipliers; samples with alpha > 0 are circled in
              red as support vectors.
      """
      # Sample points, split into the two classes.
      data_plus = [p for i, p in enumerate(dataMat) if labelMat[i] > 0]
      data_minus = [p for i, p in enumerate(dataMat) if not labelMat[i] > 0]
      data_plus_np = np.array(data_plus)
      data_minus_np = np.array(data_minus)
      plt.scatter(np.transpose(data_plus_np)[0], np.transpose(data_plus_np)[1], s=30, alpha=0.7)
      plt.scatter(np.transpose(data_minus_np)[0], np.transpose(data_minus_np)[1], s=30, alpha=0.7)
      # Decision boundary w1*x + w2*y + b = 0.
      a1, a2 = (float(v) for v in w)
      b = float(b)
      if a2 != 0:
          # Solve for y at the smallest and largest first coordinate.
          x1 = max(p[0] for p in dataMat)
          x2 = min(p[0] for p in dataMat)
          y1, y2 = (-b - a1 * x1) / a2, (-b - a1 * x2) / a2
          plt.plot([x1, x2], [y1, y2])
      elif a1 != 0:
          # w2 == 0: the boundary is the vertical line x = -b / w1. It cannot
          # be written as y = f(x), which would divide by zero.
          x_line = -b / a1
          ys = [p[1] for p in dataMat]
          plt.plot([x_line, x_line], [min(ys), max(ys)])
      # If w == 0 there is no boundary to draw.
      # Circle the support vectors (alpha > 0).
      for i, alpha in enumerate(alphas):
          if alpha > 0:
              x, y = dataMat[i]
              plt.scatter([x], [y], s=150, c='none', alpha=0.7, linewidth=1.5, edgecolor='red')
      plt.show()
  def main():
      """Solve exercise 7.2: train a linear SVM on five points, print w and b, and plot."""
      # Exercise 7.2: positive points x1..x3, negative points x4 and x5.
      x = [[1, 2], [2, 3], [3, 3], [2, 1], [3, 2]]
      y = [1, 1, 1, -1, -1]
      svm_lin = SVM(x, y, 'liner')
      # The exercise asks for the maximum-margin (hard-margin) hyperplane of a
      # linearly separable set. The class default C = 0.02 for the linear
      # kernel caps every alpha at 0.02, which is a soft margin. A large C
      # approximates the hard margin instead.
      svm_lin.C = 1000.0
      svm_lin.train(100, 0.0001)
      w = svm_lin.computW()
      print(w, svm_lin.b)
      showClassifer(x, y, w, svm_lin.b, svm_lin.alpha)
      # Earlier experiments on data files read from the current working
      # directory (testSet.txt, testSetRBF.txt, testSetRBF2.txt), kept for
      # reference:
      # X_lin, Y_lin = loadDataSet('testSet.txt')
      # svm_lin = SVM(X_lin, Y_lin, 'liner')
      # # svm_lin = SVM(X_lin, Y_lin, 'RBF')  # RBF is more general, but drawing its boundary is not implemented
      # svm_lin.train(100, 0.0001)
      # print(svm_lin.predict(X_lin, Y_lin))
      # showClassifer(X_lin, Y_lin, svm_lin.computW(), svm_lin.b, svm_lin.alpha)
      # X_test, Y_test = loadDataSet('testSetRBF2.txt')
      # X_RBF, Y_RBF = loadDataSet('testSetRBF.txt')
      # svm_RFB = SVM(X_RBF, Y_RBF, 'RBF')
      # svm_RFB.train(100, 0.0001)
      # print(svm_RFB.predict(X_test, Y_test))
      # showDataSet(X_test, Y_test)
  if __name__ == '__main__':
      # Run the exercise demo only when executed as a script, not on import.
      main()

看了一些博主的解答，还是没理解，先放一放。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/331317
推荐阅读
相关标签
  

闽ICP备14008679号