
Dictionary Learning: The K-SVD Algorithm


Paper

M. Aharon, M. Elad and A. Bruckstein, “K-SVD: An algorithm for designing overcomplete dictionaries for sparse representation,” in IEEE Transactions on Signal Processing, vol. 54, no. 11, pp. 4311-4322, Nov. 2006.

Problem Statement

$$\min_{D, X} \|Y - DX\|_F \quad \mathrm{s.t.} \quad \|x_i\|_0 < T_0, \ \forall i$$
where $Y \in R^{M \times L}$ is the original data, $D \in R^{M \times N}$ is the dictionary, and $X \in R^{N \times L}$ is the sparse code.

$M$ is the feature dimension of the data, $L$ is the number of samples, and $N$ is the dictionary size (number of atoms).

The goal of the optimization is a sparse representation of the original data: each column $x_i$ of $X$ must have fewer than $T_0$ nonzero entries, so every sample $y_i \approx D x_i$ is approximated by a combination of only a few dictionary atoms.

Solution Approach

Alternate between two optimization steps:

  • Fix $D$ and optimize $X$, using orthogonal matching pursuit (OMP)
  • Fix $X$ and optimize $D$ one atom at a time, using the singular value decomposition (SVD); the rank-1 update is derived below
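
The SVD step can be understood as follows (the derivation comes from the K-SVD paper cited above). When updating the $k$-th atom $d_k$ with all other atoms fixed, rewrite the objective as

$$\|Y - DX\|_F^2 = \left\| \left( Y - \sum_{j \neq k} d_j x_T^j \right) - d_k x_T^k \right\|_F^2 = \|E_k - d_k x_T^k\|_F^2,$$

where $x_T^j$ denotes the $j$-th row of $X$ and $E_k$ is the residual with atom $k$ removed. Restricting $E_k$ and $x_T^k$ to the columns where $x_T^k$ is nonzero (so the sparsity pattern is preserved), the best rank-1 approximation of the restricted residual $E_k^R = U \Sigma V^T$ gives the update $d_k = u_1$, with the nonzero entries of $x_T^k$ set to $\sigma_1 v_1^T$. This is exactly what the `dict_update` helper in the code below does for each row of `x`.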


Python Implementation

The K-SVD algorithm

import numpy as np
from sklearn import linear_model

def KSVD(Y, dict_size,
         max_iter=10,
         sparse_rate=0.2,
         tolerance=1e-6):

    assert dict_size <= Y.shape[1]

    def dict_update(y, d, x):
        # Update the dictionary atom by atom; d and x are modified in place.
        assert d.shape[1] == x.shape[0]

        for i in range(x.shape[0]):
            # Columns (samples) that actually use atom i.
            index = np.where(np.abs(x[i, :]) > 1e-7)[0]
            if len(index) == 0:
                continue

            # Residual without the contribution of atom i, restricted to the
            # columns that use it (this preserves the sparsity pattern of X).
            d[:, i] = 0
            r = (y - np.dot(d, x))[:, index]

            # Best rank-1 approximation of the restricted residual.
            u, s, v = np.linalg.svd(r, full_matrices=False)
            d[:, i] = u[:, 0]
            for j, k in enumerate(index):
                x[i, k] = s[0] * v[0, j]
        return d, x

    # Initialize the dictionary: left singular vectors of Y when possible,
    # otherwise randomly chosen columns of Y (overcomplete case).
    if dict_size > Y.shape[0]:
        dic = Y[:, np.random.choice(Y.shape[1], dict_size, replace=False)]
    else:
        u, s, v = np.linalg.svd(Y)
        dic = u[:, :dict_size]

    print('dict shape:', dic.shape)

    # Number of nonzero coefficients allowed per code (at least 1).
    n_nonzero_coefs_each_code = max(int(sparse_rate * dict_size), 1)

    for i in range(max_iter):
        # Sparse coding step: fix the dictionary, solve for X with OMP.
        x = linear_model.orthogonal_mp(dic, Y, n_nonzero_coefs=n_nonzero_coefs_each_code)
        e = np.linalg.norm(Y - dic @ x)
        if e < tolerance:
            break
        # Dictionary update step: fix X, update the atoms with SVD.
        dic, x = dict_update(Y, dic, x)

    # Final sparse codes for the learned dictionary.
    sparse_code = linear_model.orthogonal_mp(dic, Y, n_nonzero_coefs=n_nonzero_coefs_each_code)

    return dic, sparse_code
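
A minimal usage sketch (an illustrative addition, not from the original post): run KSVD on random data just to check the shapes of the outputs and the residual of the sparse approximation. The names Y_demo, D_demo, X_demo are made up for this example.

Y_demo = np.random.rand(20, 50)                    # 20-dimensional features, 50 samples
D_demo, X_demo = KSVD(Y_demo, dict_size=30)        # overcomplete: 30 atoms, 20 dimensions
print(D_demo.shape, X_demo.shape)                  # (20, 30) (30, 50)
print(np.linalg.norm(Y_demo - D_demo @ X_demo))    # residual of the sparse approximation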

Test

$Y = DX$

import numpy as np
import scipy.sparse as ss

# Generate a random sparse matrix X (every column gets at least one nonzero entry)
num_col_X = 30
num_row_X = 10
num_ele_X = 40
a = [np.random.randint(0, num_row_X) for _ in range(num_ele_X)]
b = [np.random.randint(0, num_col_X) for _ in range(num_ele_X - num_col_X)] + [i for i in range(num_col_X)]
c = [np.random.rand() * 10 for _ in range(num_ele_X)]
rows, cols, v = np.array(a), np.array(b), np.array(c)
sparseX = ss.coo_matrix((v, (rows, cols)), shape=(num_row_X, num_col_X))
X = sparseX.toarray()

# Randomly generate the dictionary D
num_row_D = 10
num_col_D = num_row_X
D = np.random.random((num_row_D, num_col_D))

# Generate Y
Y = D @ X
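
As a quick sanity check (an illustrative addition, not in the original post), print the matrix shapes and the actual sparsity of X. Duplicate (row, col) pairs in the COO triplets are summed, so the nonzero count can be slightly below num_ele_X.

# Sanity check: shapes and sparsity of the generated data
print(Y.shape, D.shape, X.shape)    # (10, 30) (10, 10) (10, 30)
print(np.count_nonzero(X), 'of', X.size, 'entries of X are nonzero')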

Original data
Complete dictionary

dic, code = KSVD(Y, 10)
Y_reconstruct = dic @ code
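
To quantify the reconstruction beyond the visual comparison, the relative Frobenius error can be computed (an illustrative addition, not in the original post); the same check applies to the undercomplete and overcomplete runs below.

# Relative reconstruction error (illustrative addition)
rel_err = np.linalg.norm(Y - Y_reconstruct) / np.linalg.norm(Y)
print('relative reconstruction error:', rel_err)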

Undercomplete dictionary

dic, code = KSVD(Y, 5)
Y_reconstruct = dic @ code

Overcomplete dictionary

dic, code = KSVD(Y, 15)
Y_reconstruct = dic @ code


Visualization helper

import matplotlib.pyplot as plt

def showmat(X, cmap='Oranges'):
    # Plot the magnitude of a matrix as a heat map (no axis ticks).
    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(111)
    X_abs = np.abs(X)
    ax.matshow(X_abs, vmin=np.min(X_abs), vmax=np.max(X_abs), cmap=cmap)
    ax.set_xticks([])
    ax.set_yticks([])

# Compare reconstruction vs. original data, learned codes vs. true codes,
# and learned dictionary vs. true dictionary.
showmat(Y_reconstruct), showmat(Y)
showmat(code, 'Greens'), showmat(X, 'Greens')
showmat(dic, 'Reds'), showmat(D, 'Reds')