当前位置:   article > 正文

一篇论文复现的整体思路和复现记录(三,基础实现篇)_复现论文

复现论文

工欲善其事,必先利其器,要写出好的MATLAB代码,先从最基础的代码开始做起。

1 ADMM原理

以下内容整理自Standford University的Boyd老师的课件论文
ADMM问题的基本形式:
最优化问题形式包括两组可分离自变量和线性等式约束:
min ⁡ x , z f ( x ) + g ( z ) s . t . A x + B z = c

minx,zf(x)+g(z)s.t.Ax+Bz=c
x,zmins.t.f(x)+g(z)Ax+Bz=c
写出该问题对应的拉格朗日函数式:
L ρ ( x , z , λ ) = f ( x ) + g ( z ) + y T ( A x + B z − c ) + ρ 2 ∥ A x + B z − c ∥ 2 2 L_{\rho}(\mathbf{x}, \mathbf{z}, \mathbf{\lambda}) =f(\mathbf{x})+g(\textbf{z})+\mathbf{y}^T(\mathbf{Ax}+\mathbf{Bz} - \mathbf{c})+\frac{\rho}{2}{\Vert\mathbf{Ax}+\mathbf{Bz} - \mathbf{c}\Vert^2_2} Lρ(x,z,λ)=f(x)+g(z)+yT(Ax+Bzc)+2ρAx+Bzc22
按如下步骤,按照Gauss-Seidel方法更新迭代:
x k + 1 = a r g m i n x L ρ ( x , z k , y k ) z k + 1 = a r g m i n z L ρ ( x k + 1 , z , y k ) y k + 1 = y k + ρ ( A x k + 1 + B z k + 1 − c )
xk+1=argminxLρ(x,zk,yk)zk+1=argminzLρ(xk+1,z,yk)yk+1=yk+ρ(Axk+1+Bzk+1c)
xk+1zk+1yk+1=xargminLρ(x,zk,yk)=zargminLρ(xk+1,z,yk)=yk+ρ(Axk+1+Bzk+1c)

进一步,如在拉格朗日函数中定义 u k = ( 1 / ρ ) y k \mathbf{u}^k = (1/\rho)\mathbf{y}^k uk=(1/ρ)yk,放缩对偶变量,拉格朗日函数变为:
L ρ ( x , z , λ ) = f ( x ) + g ( z ) + ρ 2 ∥ A x + B z − c + u ∥ 2 2 + c o n s t L_{\rho}(\mathbf{x}, \mathbf{z}, \mathbf{\lambda}) =f(\mathbf{x})+g(\textbf{z})+\frac{\rho}{2}{\Vert\mathbf{Ax}+\mathbf{Bz} - \mathbf{c}+\mathbf{u}\Vert^2_2+const} Lρ(x,z,λ)=f(x)+g(z)+2ρAx+Bzc+u22+const
相对应的迭代式子变为:
x k + 1 = a r g m i n x f ( x ) + ρ 2 ∥ A x + B z k − c + u k ∥ 2 2 z k + 1 = a r g m i n z g ( z ) + ρ 2 ∥ A x k + 1 + B z − c + u k ∥ 2 2 u k + 1 = u k + A x k + 1 + B z k + 1 − c
xk+1=argminxf(x)+ρ2Ax+Bzkc+uk22zk+1=argminzg(z)+ρ2Axk+1+Bzc+uk22uk+1=uk+Axk+1+Bzk+1c
xk+1zk+1uk+1=xargminf(x)+2ρAx+Bzkc+uk22=zargming(z)+2ρAxk+1+Bzc+uk22=uk+Axk+1+Bzk+1c

到此为止,上述仍然是一些抽象的概念,与实际问题并无什么联系。好在Boyd老师在网站上面挂出来了一些案例,但是这些案例难懂,我费了一些功夫终于看明白其中一二。

2 部分特殊表示

有些符号看得不是很清楚,总结于以下

邻近算子( Proximity Operator \text{Proximity Operator} Proximity Operator

前面已经知道 x x x的更新值为
x + = a r g m i n x f ( x ) + ρ 2 ∥ A x + B z k − c + u k ∥ 2 2 {x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}{\Vert{Ax}+{Bz}^k - {c}+{u}^k\Vert^2_2} x+=xargminf(x)+2ρAx+Bzkc+uk22
v = − B z + c − u v = -Bz+c-u v=Bz+cu
x + = a r g m i n x f ( x ) + ρ 2 ∥ A x − v ∥ 2 2 {x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}{\Vert{Ax}-v\Vert^2_2} x+=xargminf(x)+2ρAxv22
又令 A = I A=I A=I
x + = a r g m i n x f ( x ) + ρ 2 ∥ x − v ∥ 2 2 (*)

x+=argminxf(x)+ρ2xv22
\tag{*} x+=xargminf(x)+2ρxv22(*)
该式子的右端项可以用 prox f , ρ ( v ) \textbf{prox}_{f,\rho}(v) proxf,ρ(v)来表示,我暂时将其翻译为:函数 f f f带惩罚项 ρ \rho ρ的邻近算子。所以,后面有一些式子会用式子(*)来简略表示。

向集合的投影( Projection \text{Projection} Projection

当函数 f f f足够简单, x x x的更新值成为前面所说到的邻近算子的形式,可以用解析的办法分析。其中一个例子是,假如说 f f f是一个非空闭凸集的指示函数( indicator function \text{indicator function} indicator function),那么也可以将x的更新值表示为:
x + = a r g m i n x f ( x ) + ρ 2 ∥ x − v ∥ 2 2 = Π C ( v ) {x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}{\Vert{x}-v\Vert^2_2}=\Pi_\mathcal{C}(v) x+=xargminf(x)+2ρxv22=ΠC(v)
其中, Π C \Pi_\mathcal{C} ΠC表示向 C \mathcal{C} C上的欧式范数的投影.

软阈值( Soft Thresholding \text{Soft Thresholding} Soft Thresholding

感谢前辈的文章12,基本了解软阈值的情况。
软阈值问题的形式为:
S κ ( a ) = { a − κ a > κ 0 ∣ a ∣ ≤ κ a + κ a < − κ = ( a − κ ) + − ( − a − κ ) + S_\kappa(a) =

{aκa>κ0|a|κa+κa<κ
=(a-\kappa)_+-(-a-\kappa)_+ Sκ(a)= aκ0a+κa>κaκa<κ=(aκ)+(aκ)+
是在求解优化问题形如:
a r g m i n x ∥ x − B ∥ 2 2 + λ ∥ x ∥ 1 \mathop{argmin}\limits_x\Vert x-B\Vert^2_2+\lambda\Vert x\Vert_1 xargminxB22+λx1
的时候作用的,用于标记 x x x更新值的取值。其中 B = [ b 1 , b 2 , … , b n ] B = [b_1, b_2, \dots, b_n] B=[b1,b2,,bn]。由于 ∥ x ∥ 1 \Vert x\Vert_1 x1并不可微,所以通过分类讨论的结果:
S λ / 2 ( b ) = { b − λ 2 b > λ 2 0 ∣ b ∣ < λ 2 b + λ 2 b < − λ 2 = ( b − λ 2 ) + − ( − b − λ 2 ) + S_{\lambda/2}(b) =
{bλ2b>λ20|b|<λ2b+λ2b<λ2
=(b-\frac{\lambda}{2})_+-(-b-\frac{\lambda}{2})_+
Sλ/2(b)= b2λ0b+2λb>2λb<2λb<2λ=(b2λ)+(b2λ)+

例如,在Boyd老师的论文中就提到, x i x_i xi更新值:
x i + = a r g m i n x i ( λ ∣ x i ∣ + ( ρ / 2 ) ( x i − v i ) 2 ) . x_i^+=\mathop{argmin}\limits_{x_i}(\lambda|x_i|+(\rho/2)(x_i-v_i)^2). xi+=xiargmin(λxi+(ρ/2)(xivi)2).
可以得到相应更新值:
x i + : = S λ / ρ ( v i ) x_i^+:=S_{\lambda/\rho}(v_i) xi+:=Sλ/ρ(vi)

3 部分代码解析

以下选择Lasso问题的代码进行分析。代码摘自网站,此处只是复刻编程的流程,整理一下思路。

3.1 代码架构

代码分为lasso.m, objective.m, shrinkage.m, factor.m等,在运行中实际起到作用分别是:

  1. function lasso.m是整个ADMM的算法执行流程,包括数据记录,数据迭代,数据输出,迭代终点判断等内容;
  2. objective.m是整个优化函数的目标函数;
  3. shrinkage.m表示整个优化函数的目标函数;
  4. factor.m则是根据A的形状的不同,进行的分解。
  5. 通过实际案例检验所写的代码是否正确反映了算法。

3.2 代码解构

3.2.1 lasso.m

lasso问题的目标函数为:
m i n i m i z e 1 2 ∥ A x − b ∥ 2 2 + λ ∥ x ∥ 1 minimize \quad \frac{1}{2}\Vert Ax-b\Vert^2_2+\lambda\Vert x\Vert_1 minimize21Axb22+λx1
写成ADMM方法所能够求解的格式:
m i n i m i z e 1 2 ∥ A x − b ∥ 2 2 + λ ∥ z ∥ 1 subject to x − z = 0

minimize12Axb22+λz1subject toxz=0
minimizesubject to21Axb22+λz1xz=0
迭代的表达式为:
x k + 1 : = ( A T A + ρ I ) − 1 ( A T b + ρ ( z k − u k ) ) z k + 1 : = S λ / ρ ( x k + 1 + u k ) u k + 1 : = u k + x k + 1 − z k + 1
xk+1:=(ATA+ρI)1(ATb+ρ(zkuk))zk+1:=Sλ/ρ(xk+1+uk)uk+1:=uk+xk+1zk+1
xk+1zk+1uk+1:=(ATA+ρI)1(ATb+ρ(zkuk)):=Sλ/ρ(xk+1+uk):=uk+xk+1zk+1

function [z, history] = lasso(A, b, lambda, rho, alpha)
% lasso  Solve lasso problem via ADMM
% [z, history] = lasso(A, b, lambda, rho, alpha);
% Solves the following problem via ADMM:
%   minimize 1/2*|| Ax - b ||_2^2 + \lambda || x ||_1
% The solution is returned in the vector x.
% history is a structure that contains the objective value, the primal and
% dual residual norms, and the tolerances for the primal and dual residual
% norms at each iteration.
% rho is the augmented Lagrangian parameter.
% alpha is the over-relaxation parameter (typical values for alpha are
% between 1.0 and 1.8).
% More information can be found in the paper linked at
%:http://www.stanford.edu/~boyd/papers/distr_opt_stat_learning_admm.html
%

t_start = tic;
Global constants and defaults
QUIET    = 0;
MAX_ITER = 1000;
ABSTOL   = 1e-4; 
RELTOL   = 1e-2;
Data preprocessing
[m, n] = size(A);

% save a matrix-vector multiply
Atb = A'*b;
ADMM solver
x = zeros(n,1);
z = zeros(n,1);
u = zeros(n,1);

% cache the factorization
[L U] = factor(A, rho);

if ~QUIET
    fprintf('%3s\t%10s\t%10s\t%10s\t%10s\t%10s\n', 'iter', ...
      'r norm', 'eps pri', 's norm', 'eps dual', 'objective');
end

for k = 1:MAX_ITER

    % x-update
    q = Atb + rho*(z - u);    % temporary value
    if( m >= n )    % if skinny
       x = U \ (L \ q);
    else            % if fat
       x = q/rho - (A'*(U \ ( L \ (A*q) )))/rho^2;
    end

    % z-update with relaxation
    zold = z;
    x_hat = alpha*x + (1 - alpha)*zold;
    z = shrinkage(x_hat + u, lambda/rho);

    % u-update
    u = u + (x_hat - z);

    % diagnostics, reporting, termination checks
    history.objval(k)  = objective(A, b, lambda, x, z);

    history.r_norm(k)  = norm(x - z);
    history.s_norm(k)  = norm(-rho*(z - zold));

    history.eps_pri(k) = sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z));
    history.eps_dual(k)= sqrt(n)*ABSTOL + RELTOL*norm(rho*u);

    if ~QUIET
        fprintf('%3d\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.2f\n', k, ...
            history.r_norm(k), history.eps_pri(k), ...
            history.s_norm(k), history.eps_dual(k), history.objval(k));
    end

    if (history.r_norm(k) < history.eps_pri(k) && ...
       history.s_norm(k) < history.eps_dual(k))
         break;
    end

end

if ~QUIET
    toc(t_start);
end
end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84

3.2.2 objective.m

function p = objective(A, b, lambda, x, z)
    p = ( 1/2*sum((A*x - b).^2) + lambda*norm(z,1) );
end
  • 1
  • 2
  • 3

3.2.3 shrinkage.m

此处表达的是式子:
z k + 1 : = S λ / ρ ( x k + 1 + u k ) = ( x k + 1 + u k − λ / ρ ) + − ( − x k + 1 − u k − λ / ρ ) z^{k+1}:=S_{\lambda/\rho}(x^{k+1}+u^k)=(x^{k+1}+u^k-\lambda/\rho)_+-(-x^{k+1}-u^k-\lambda/\rho) zk+1:=Sλ/ρ(xk+1+uk)=(xk+1+ukλ/ρ)+(xk+1ukλ/ρ)

function z = shrinkage(x, kappa)
    z = max( 0, x - kappa ) - max( 0, -x - kappa );
end
  • 1
  • 2
  • 3

3.2.4 factor.m

A A A是一个 m × n m\times n m×n的矩阵。当 m < n m<n m<n,由于在 x x x的更新式中, A T A + ρ I A^TA+\rho I ATA+ρI是一个 n × n n\times n n×n的矩阵,而提一个 1 / ρ 1/\rho 1/ρ后,变为 I + ( 1 / ρ ) A A T I+(1/\rho)AA^T I+(1/ρ)AAT,是一个 m × m m\times m m×m大小的矩阵。由于矩阵求逆的复杂度为 O ( n 3 ) O(n^3) O(n3),因此,后面这种做法更便于求解。结合稀疏矢量技术完成文中代码。

function [L U] = factor(A, rho)
    [m, n] = size(A);
    if ( m >= n )    % if skinny
       L = chol( A'*A + rho*speye(n), 'lower' );
    else            % if fat
       L = chol( speye(m) + 1/rho*(A*A'), 'lower' );
    end

    % force matlab to recognize the upper / lower triangular structure
    L = sparse(L);
    U = sparse(L');
end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

3.2.5 难点和实施

难点主要在于要矩阵表达式的运算和推导,容易出错,因此给复现带来了困难。实际案例(Example)见于链接,为上面函数的应用。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/967765
推荐阅读
相关标签
  

闽ICP备14008679号