赞
踩
import numpy as np
import pandas as pd
from scipy.io import loadmat


def dict_update(y, d, x, n_components):
    """Run one K-SVD dictionary-update sweep over the first atoms of ``d``.

    For each atom ``i`` in ``range(n_components)`` the columns of ``y`` whose
    sparse code uses that atom (non-zero entries of ``x[i, :]``) are
    collected, the residual of those columns without atom ``i`` is formed,
    and its best rank-1 approximation (leading singular triplet) replaces
    both the atom ``d[:, i]`` and the corresponding coefficients in
    ``x[i, :]``. Atoms that no column uses are left untouched.

    Args:
        y: Data matrix; must be shape-compatible with ``d @ x``.
        d: Dictionary, one atom per column. Modified in place.
        x: Sparse coefficient matrix, one row per atom. Modified in place;
            entries that are zero on entry stay zero.
        n_components: Number of leading atoms (columns of ``d`` / rows of
            ``x``) to update.

    Returns:
        The same ``d`` and ``x`` objects that were passed in.
    """
    for i in range(n_components):
        # Columns of y that actually use atom i.
        index = np.nonzero(x[i, :])[0]
        if index.size == 0:
            continue

        # Zero the atom so that d @ x no longer contains its contribution.
        d[:, i] = 0

        # Residual restricted to the supporting columns. This equals
        # (y - d @ x)[:, index] but avoids forming the full product.
        r = y[:, index] - np.dot(d, x[:, index])

        # The leading singular triplet is the best rank-1 fit of the residual.
        u, s, v = np.linalg.svd(r, full_matrices=False)

        # New atom: first left singular vector (unit norm).
        d[:, i] = u[:, 0]

        # New coefficients: first singular value times first right singular
        # vector. Assigning through x[i, index] writes into x itself; only
        # *reading* with a fancy index produces a copy.
        x[i, index] = s[0] * v[0, :]
    return d, x


if __name__ == "__main__":
    # Imported here because only the sparse-coding step below needs it.
    from sklearn import linear_model

    train_data_mat = loadmat("../data/train_data2.mat")
    train_data = train_data_mat["Data"]
    train_label = train_data_mat["Label"]
    print(train_data.shape, train_label.shape)

    # Initialise the dictionary with the leading left singular vectors
    # of the training data.
    u, s, v = np.linalg.svd(train_data)
    n_comp = 50
    dict_data = u[:, :n_comp]

    # Alternate sparse coding (OMP) and dictionary update.
    max_iter = 10
    dictionary = dict_data
    y = train_data
    tolerance = 1e-6
    for _ in range(max_iter):
        # Sparse coding with the current dictionary.
        x = linear_model.orthogonal_mp(dictionary, y)
        e = np.linalg.norm(y - np.dot(dictionary, x))
        if e < tolerance:
            break
        # dict_update modifies dictionary and x in place.
        dict_update(y, dictionary, x, n_comp)

    # Final sparse codes and reconstruction with the learned dictionary.
    sparsecode = linear_model.orthogonal_mp(dictionary, y)
    train_restruct = dictionary.dot(sparsecode)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。