To do a good job, one must first sharpen one's tools: writing good MATLAB code starts from the most basic code.
The following material is organized from the slides and paper of Prof. Boyd at Stanford University.
The basic form of the ADMM problem:
The optimization problem has two separable blocks of variables and a linear equality constraint:
$$\min_{\mathbf{x},\mathbf{z}} \; f(\mathbf{x})+g(\mathbf{z}) \qquad \mathrm{s.t.} \quad \mathbf{Ax}+\mathbf{Bz}=\mathbf{c}$$
The corresponding augmented Lagrangian of this problem is:
$$L_{\rho}(\mathbf{x}, \mathbf{z}, \mathbf{y}) =f(\mathbf{x})+g(\mathbf{z})+\mathbf{y}^T(\mathbf{Ax}+\mathbf{Bz} - \mathbf{c})+\frac{\rho}{2}\Vert\mathbf{Ax}+\mathbf{Bz} - \mathbf{c}\Vert^2_2$$
Update the iterates Gauss-Seidel style, as follows:
$$\mathbf{x}^{k+1} = \mathop{\arg\min}\limits_{\mathbf{x}} \; L_\rho(\mathbf{x}, \mathbf{z}^k, \mathbf{y}^k)$$
$$\mathbf{z}^{k+1} = \mathop{\arg\min}\limits_{\mathbf{z}} \; L_\rho(\mathbf{x}^{k+1}, \mathbf{z}, \mathbf{y}^k)$$
$$\mathbf{y}^{k+1} = \mathbf{y}^k + \rho(\mathbf{Ax}^{k+1}+\mathbf{Bz}^{k+1}-\mathbf{c})$$
Further, defining the scaled dual variable $\mathbf{u}^k = (1/\rho)\mathbf{y}^k$ in the Lagrangian, it becomes:
$$L_{\rho}(\mathbf{x}, \mathbf{z}, \mathbf{u}) =f(\mathbf{x})+g(\mathbf{z})+\frac{\rho}{2}\Vert\mathbf{Ax}+\mathbf{Bz} - \mathbf{c}+\mathbf{u}\Vert^2_2+\mathrm{const}$$
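Where does the $\mathrm{const}$ term come from? Completing the square on the residual $\mathbf{r}=\mathbf{Ax}+\mathbf{Bz}-\mathbf{c}$ with $\mathbf{u}=(1/\rho)\mathbf{y}$ makes it explicit (a one-line check, added here for clarity):
$$\mathbf{y}^T\mathbf{r}+\frac{\rho}{2}\Vert\mathbf{r}\Vert^2_2 = \frac{\rho}{2}\Vert\mathbf{r}+\mathbf{u}\Vert^2_2-\frac{\rho}{2}\Vert\mathbf{u}\Vert^2_2$$
so $\mathrm{const}=-\frac{\rho}{2}\Vert\mathbf{u}\Vert^2_2$, which depends on neither $\mathbf{x}$ nor $\mathbf{z}$.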
The corresponding update steps become:
$$\mathbf{x}^{k+1} = \mathop{\arg\min}\limits_{\mathbf{x}} \; f(\mathbf{x})+\frac{\rho}{2}\Vert\mathbf{Ax}+\mathbf{Bz}^k-\mathbf{c}+\mathbf{u}^k\Vert^2_2$$
$$\mathbf{z}^{k+1} = \mathop{\arg\min}\limits_{\mathbf{z}} \; g(\mathbf{z})+\frac{\rho}{2}\Vert\mathbf{Ax}^{k+1}+\mathbf{Bz}-\mathbf{c}+\mathbf{u}^k\Vert^2_2$$
$$\mathbf{u}^{k+1} = \mathbf{u}^k+\mathbf{Ax}^{k+1}+\mathbf{Bz}^{k+1}-\mathbf{c}$$
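To make the loop mechanics concrete before going further, here is a toy, fully worked instance of the scaled iteration (my own construction, not from Boyd's materials): $f(x)=\frac{1}{2}\Vert x-p\Vert^2_2$, $g(z)=\frac{1}{2}\Vert z-q\Vert^2_2$, with constraint $x-z=0$ (so $A=I$, $B=-I$, $c=0$). Both subproblems have closed forms, and the iterates should converge to the consensus point $(p+q)/2$.

% Toy scaled-form ADMM instance (my own construction, not Boyd's code):
% f(x) = 1/2*||x - p||^2, g(z) = 1/2*||z - q||^2, constraint x - z = 0,
% i.e. A = I, B = -I, c = 0; x and z should converge to (p + q)/2.
p = [1; 2]; q = [3; 0];                  % arbitrary data
rho = 1.0;                               % augmented Lagrangian parameter
x = zeros(2,1); z = zeros(2,1); u = zeros(2,1);
for k = 1:100
    x = (p + rho*(z - u)) / (1 + rho);   % x-update: minimize over x with z^k, u^k fixed
    z = (q + rho*(x + u)) / (1 + rho);   % z-update: uses the fresh x^{k+1}
    u = u + x - z;                       % scaled dual update: u + (Ax + Bz - c)
end
disp([x, (p + q)/2]);                    % the two columns should agree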
Toy cases aside, the material above is still abstract and not yet tied to any actual problem. Fortunately, Prof. Boyd has posted some worked examples on his website; they are not easy to follow, and it took me some effort to understand a portion of them.
Some of the notation was not clear to me at first, so I summarize it below.
As shown earlier, the update for $x$ is
$${x}^{+} = \mathop{\arg\min}\limits_x f({x})+ \frac{\rho}{2}\Vert{Ax}+{Bz}^k - {c}+{u}^k\Vert^2_2$$
Let $v = -Bz+c-u$; then
$${x}^{+} = \mathop{\arg\min}\limits_x f({x})+ \frac{\rho}{2}\Vert{Ax}-v\Vert^2_2$$
Further, let $A=I$:
$${x}^{+} = \mathop{\arg\min}\limits_x f({x})+ \frac{\rho}{2}\Vert{x}-v\Vert^2_2 \qquad (*)$$
The right-hand side can be written as $\textbf{prox}_{f,\rho}(v)$, which I tentatively read as: the proximal operator of $f$ with penalty $\rho$. Some later expressions will therefore be abbreviated using equation (*).
When $f$ is simple enough, the update for $x$ takes the proximal-operator form above and can be worked out analytically. As one example, if $f$ is the indicator function of a nonempty closed convex set $\mathcal{C}$, the update for $x$ can also be written as:
$${x}^{+} = \mathop{\arg\min}\limits_x f({x})+ \frac{\rho}{2}\Vert{x}-v\Vert^2_2=\Pi_\mathcal{C}(v)$$
where $\Pi_\mathcal{C}$ denotes the Euclidean-norm projection onto $\mathcal{C}$.
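As a concrete instance (my own illustration, not from the paper): if $\mathcal{C}$ is the nonnegative orthant, the projection is elementwise clipping at zero, independent of $\rho$.

% Example (my own): for C = {x : x >= 0}, the prox of the indicator
% function is the Euclidean projection Pi_C(v) = max(v, 0) elementwise.
v = [1.5; -0.3; 0.0; -2.0];
x_plus = max(v, 0);          % Pi_C(v) for the nonnegative orthant
disp(x_plus');               % prints: 1.5  0  0  0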
Thanks to the articles of earlier authors [1][2], I now have a basic grasp of soft thresholding.
The soft-thresholding operator has the form:
$$S_\kappa(a) = \begin{cases} a-\kappa, & a > \kappa \\ 0, & |a| \le \kappa \\ a+\kappa, & a < -\kappa \end{cases} \;=\; (a-\kappa)_+ - (-a-\kappa)_+$$
It comes into play when solving optimization problems of the form
$$\mathop{\arg\min}\limits_x\Vert x-B\Vert^2_2+\lambda\Vert x\Vert_1$$
where it determines the value of the update for $x$. Here $B = [b_1, b_2, \dots, b_n]$. Since $\Vert x\Vert_1$ is not differentiable, a case-by-case analysis yields:
$$S_{\lambda/2}(b) = \begin{cases} b-\frac{\lambda}{2}, & b > \frac{\lambda}{2} \\ 0, & |b| \le \frac{\lambda}{2} \\ b+\frac{\lambda}{2}, & b < -\frac{\lambda}{2} \end{cases} \;=\; \left(b-\frac{\lambda}{2}\right)_+ - \left(-b-\frac{\lambda}{2}\right)_+$$
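The case analysis itself is short, so I spell it out here (my own filling-in of the step). Writing the scalar objective as $h(x)=(x-b)^2+\lambda|x|$:
$$x>0:\quad h'(x)=2(x-b)+\lambda=0 \;\Rightarrow\; x=b-\frac{\lambda}{2},\ \text{valid iff } b>\frac{\lambda}{2}$$
$$x<0:\quad h'(x)=2(x-b)-\lambda=0 \;\Rightarrow\; x=b+\frac{\lambda}{2},\ \text{valid iff } b<-\frac{\lambda}{2}$$
$$x=0:\quad 0\in 2(0-b)+\lambda\,[-1,1] \;\Leftrightarrow\; |b|\le\frac{\lambda}{2}$$
where the last case uses the subdifferential of $|x|$ at zero.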
For example, Prof. Boyd's paper gives the update for each component $x_i$:
$$x_i^+=\mathop{\arg\min}\limits_{x_i}\left(\lambda|x_i|+(\rho/2)(x_i-v_i)^2\right).$$
which yields the update
$$x_i^+:=S_{\lambda/\rho}(v_i)$$
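A quick numerical sanity check of this update (my own sketch, not from the paper): the closed form should match a brute-force grid search on the scalar subproblem.

% Sanity check (my own): S_{lambda/rho}(v) should minimize
% lambda*|x| + (rho/2)*(x - v)^2 over scalar x.
lambda = 0.8; rho = 2.0; v = 1.7;
kappa  = lambda/rho;
closed = max(0, v - kappa) - max(0, -v - kappa);   % soft threshold S_kappa(v)
xs     = linspace(-5, 5, 200001);                  % fine brute-force grid
costs  = lambda*abs(xs) + (rho/2)*(xs - v).^2;
[~, i] = min(costs);
fprintf('closed form: %.4f, grid minimizer: %.4f\n', closed, xs(i));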
Below, the code for the Lasso problem is chosen for analysis. The code is taken from the website; here I only retrace the programming flow to organize my thoughts.
The code is split into lasso.m, objective.m, shrinkage.m, and factor.m; their roles during a run are described in turn below.
The objective of the Lasso problem is:
$$\mathrm{minimize} \quad \frac{1}{2}\Vert Ax-b\Vert^2_2+\lambda\Vert x\Vert_1$$
Written in the form that the ADMM method can solve:
$$\mathrm{minimize} \quad \frac{1}{2}\Vert Ax-b\Vert^2_2+\lambda\Vert z\Vert_1 \qquad \text{subject to} \quad x-z=0$$
The iteration expressions are:
$$x^{k+1} := (A^TA+\rho I)^{-1}(A^Tb+\rho(z^k-u^k))$$
$$z^{k+1} := S_{\lambda/\rho}(x^{k+1}+u^k)$$
$$u^{k+1} := u^k+x^{k+1}-z^{k+1}$$
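The x-update is just the normal equation of its subproblem; setting the gradient to zero shows this (a one-line check added for completeness):
$$\nabla_x\left[\frac{1}{2}\Vert Ax-b\Vert^2_2+\frac{\rho}{2}\Vert x-z^k+u^k\Vert^2_2\right]=A^T(Ax-b)+\rho(x-z^k+u^k)=0$$
$$\Rightarrow\quad (A^TA+\rho I)\,x^{k+1}=A^Tb+\rho(z^k-u^k)$$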
function [z, history] = lasso(A, b, lambda, rho, alpha)
% lasso  Solve lasso problem via ADMM
%
% [z, history] = lasso(A, b, lambda, rho, alpha);
%
% Solves the following problem via ADMM:
%
%   minimize 1/2*|| Ax - b ||_2^2 + \lambda || x ||_1
%
% The solution is returned in the vector x.
%
% history is a structure that contains the objective value, the primal and
% dual residual norms, and the tolerances for the primal and dual residual
% norms at each iteration.
%
% rho is the augmented Lagrangian parameter.
%
% alpha is the over-relaxation parameter (typical values for alpha are
% between 1.0 and 1.8).
%
% More information can be found in the paper linked at:
% http://www.stanford.edu/~boyd/papers/distr_opt_stat_learning_admm.html

t_start = tic;

% Global constants and defaults
QUIET    = 0;
MAX_ITER = 1000;
ABSTOL   = 1e-4;
RELTOL   = 1e-2;

% Data preprocessing
[m, n] = size(A);
Atb = A'*b;                       % save a matrix-vector multiply

% ADMM solver
x = zeros(n,1);
z = zeros(n,1);
u = zeros(n,1);

[L, U] = factor(A, rho);          % cache the factorization

if ~QUIET
    fprintf('%3s\t%10s\t%10s\t%10s\t%10s\t%10s\n', 'iter', ...
        'r norm', 'eps pri', 's norm', 'eps dual', 'objective');
end

for k = 1:MAX_ITER
    % x-update
    q = Atb + rho*(z - u);        % temporary value
    if( m >= n )                  % if skinny
        x = U \ (L \ q);
    else                          % if fat
        x = q/rho - (A'*(U \ ( L \ (A*q) )))/rho^2;
    end

    % z-update with relaxation
    zold = z;
    x_hat = alpha*x + (1 - alpha)*zold;
    z = shrinkage(x_hat + u, lambda/rho);

    % u-update
    u = u + (x_hat - z);

    % diagnostics, reporting, termination checks
    history.objval(k)  = objective(A, b, lambda, x, z);
    history.r_norm(k)  = norm(x - z);
    history.s_norm(k)  = norm(-rho*(z - zold));
    history.eps_pri(k) = sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z));
    history.eps_dual(k)= sqrt(n)*ABSTOL + RELTOL*norm(rho*u);

    if ~QUIET
        fprintf('%3d\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.2f\n', k, ...
            history.r_norm(k), history.eps_pri(k), ...
            history.s_norm(k), history.eps_dual(k), history.objval(k));
    end

    if (history.r_norm(k) < history.eps_pri(k) && ...
        history.s_norm(k) < history.eps_dual(k))
        break;
    end
end

if ~QUIET
    toc(t_start);
end
end
function p = objective(A, b, lambda, x, z)
    % lasso objective: 1/2*||A*x - b||_2^2 + lambda*||z||_1
    p = ( 1/2*sum((A*x - b).^2) + lambda*norm(z,1) );
end
The shrinkage function below expresses the formula:
$$z^{k+1}:=S_{\lambda/\rho}(x^{k+1}+u^k)=(x^{k+1}+u^k-\lambda/\rho)_+-(-x^{k+1}-u^k-\lambda/\rho)_+$$
function z = shrinkage(x, kappa)
    % elementwise soft threshold: S_kappa(x) = (x - kappa)_+ - (-x - kappa)_+
    z = max( 0, x - kappa ) - max( 0, -x - kappa );
end
$A$ is an $m \times n$ matrix. When $m < n$, the matrix $A^TA+\rho I$ in the update for $x$ is $n \times n$; after pulling out a factor of $1/\rho$ (and applying the matrix inversion lemma), the solve can instead go through $I+(1/\rho)AA^T$, which is only $m \times m$. Since matrix inversion costs $O(n^3)$, the latter route is cheaper to compute. Combining this with sparse-matrix techniques completes the code in the text.
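Before the factorization code, here is a quick numerical check of that equivalence (my own sketch; small and dense for clarity):

% Numerical check (my own sketch): for fat A (m < n), solving
% (A'*A + rho*I)*x = q directly agrees with the m-by-m route
% x = q/rho - A'*((I + (1/rho)*A*A') \ (A*q))/rho^2 used in the x-update.
rng(0); m = 5; n = 20; rho = 1.0;
A = randn(m, n); q = randn(n, 1);
x_direct = (A'*A + rho*eye(n)) \ q;
x_lemma  = q/rho - (A'*((eye(m) + (1/rho)*(A*A')) \ (A*q)))/rho^2;
fprintf('max difference: %.2e\n', max(abs(x_direct - x_lemma)));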
function [L, U] = factor(A, rho)
    [m, n] = size(A);
    if ( m >= n )       % if skinny
        L = chol( A'*A + rho*speye(n), 'lower' );
    else                % if fat
        L = chol( speye(m) + 1/rho*(A*A'), 'lower' );
    end
    % force matlab to recognize the upper / lower triangular structure
    L = sparse(L);
    U = sparse(L');
end
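Finally, a minimal usage sketch of my own (not the official example script; the sizes and parameter choices here are arbitrary): generate a small random sparse-regression instance and call lasso on it.

% Minimal usage sketch (my own, not the official example script).
rng(1);
m = 50; n = 200;                        % fat A: more features than samples
x0 = full(sprandn(n, 1, 0.05));         % sparse ground-truth coefficients
A  = randn(m, n);
b  = A*x0 + 0.01*randn(m, 1);           % noisy measurements
lambda = 0.1*norm(A'*b, Inf);           % common heuristic for the penalty
[z, history] = lasso(A, b, lambda, 1.0, 1.0);  % rho = 1, alpha = 1
fprintf('nonzeros recovered: %d of %d\n', nnz(z), nnz(x0));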
The main difficulty lies in carrying out and deriving the matrix expressions, where mistakes are easy to make; this is what makes reproduction hard. A practical example (Example) of applying the functions above can be found at the link.