function [zk, BlockSpectraOut, history,U,D]=PEBSI_ADMM(A,b,BlockSizes,lambdaVec,xiVec,alphak,rho,ell,Uin,Din)
% Fast ADMM solver for the optimization problem
%
%   minimize_x  1/2*||Ax-b||_2^2 + lambda*||x||_1 + xi*||Fx||_1    (ell == 2)
%   minimize_x      ||Ax-b||_1   + lambda*||x||_1 + xi*||Fx||_1    (ell == 1)
%
% where x = [x_[1]^T, ..., x_[d]^T]^T and each x_[k] is a complex column
% vector in C^(n_k). F is an extended first-difference matrix: for every
% block it produces the first element, the first differences within the
% block, and the last element.
%
% The problem is split as u1 = A*x, u2 = x, u3 = F*x and solved with
% scaled-form, over-relaxed ADMM with residual balancing of the penalty mu.
%
%      INPUT
%
%      A          - In C^(N x M), the dictionary, or matrix of explanatory
%                   variables.
%
%      b          - In C^N, the data, assumed to be a multi-pitch signal
%                   observed in noise.
%
%      BlockSizes - In N^d, size of each block, i.e., [n_1, n_2, ... n_d].
%                   NOTE: F is built from BlockSizes(1) only, so all blocks
%                   are assumed to have the same length and M must be a
%                   multiple of BlockSizes(1).
%
%      lambdaVec  - In [0,inf], weight of the ell_1 norm of x. Scalar, or
%                   an M x 1 vector of per-entry weights (e.g. for
%                   reweighting).
%
%      xiVec      - In [0,inf], weight of the ell_1 norm of F*x. Scalar, or
%                   a vector with one weight per row of F (M + M/n_1 rows).
%
%      alphak     - Over-relaxation parameter used before the u-updates
%                   (alphak == 1 means no relaxation).
%
%      rho        - Residual-balancing ratio: during the first 3000
%                   iterations mu is doubled when ||r|| > rho*||s|| and
%                   halved when ||s|| > rho*||r||.
%
%      ell        - 1 or 2, norm used in the data-fit term.
%
%      Uin, Din   - Optional warm start: cell arrays {u1,u2,u3} and
%                   {d1,d2,d3} as returned in U and D by a previous call.
%                   Pass [] (or omit) for a cold start. Din is only used
%                   when Uin is non-empty.
%
%      OUTPUT
%
%      zk               - Estimate of x^*, the vector that solves the
%                         minimization problem. It is taken as the split
%                         variable u2, so thresholded entries are exact
%                         zeros.
%
%      BlockSpectraOut  - 2-norm of each block in zk, i.e.,
%                         [||z_[1]||_2, ..., ||z_[d]||_2], stored in the
%                         first column of a noBlocks x length(lambdaVec)
%                         array.
%
%      history          - Struct with per-iteration fields mu, objval,
%                         r_norm, s_norm, eps_pri and eps_dual.
%
%      U, D             - Final primal split variables {u1,u2,u3} and
%                         scaled dual variables {d1,d2,d3}, for warm starts.
%                         NOTE(review): D is scaled by the final mu, but
%                         every call restarts at mu = 2, so a warm start
%                         begins from a rescaled dual -- confirm intended.
%

noIter = 5000;       % maximum number of ADMM iterations
noAdaptIter = 3000;  % mu is adapted only while i < noAdaptIter
ABSTOL = 1e-5;       % absolute stopping tolerance (same for ell = 1 and 2)
RELTOL = 1e-6;       % relative stopping tolerance (same for ell = 1 and 2)
QUIET = 1;           % set to 0 to print per-iteration diagnostics

if nargin < 9
    Uin = [];
end
if nargin < 10
    Din = [];
end
if ~isscalar(ell) || (ell ~= 1 && ell ~= 2)
    error('PEBSI_ADMM:invalidEll', 'ell must be 1 or 2.');
end

[N,M] = size(A);
Lmax = BlockSizes(1);
if mod(M, Lmax) ~= 0
    error('PEBSI_ADMM:invalidBlockSizes', ...
        'size(A,2) = %d is not a multiple of BlockSizes(1) = %d.', M, Lmax);
end
Pfreq = M / Lmax;                 % number of (equal-size) blocks used for F
noBlocks = length(BlockSizes);
CumSumBlockSizes = [0, cumsum(BlockSizes(:).')];

% Scaled ADMM variables for the splits u1 = A*x, u2 = x, u3 = F*x.
mu = 2; % ADMM penalty parameter (adapted by residual balancing below)
uk1 = zeros(N,1);
uk2 = zeros(M,1);
uk3 = zeros(M+Pfreq,1);
dk1 = zeros(N,1);
dk2 = zeros(M,1);
dk3 = zeros(M+Pfreq,1);
if ~isempty(Uin)
    uk1 = Uin{1};
    uk2 = Uin{2};
    uk3 = Uin{3};
    if ~isempty(Din)
        dk1 = Din{1};
        dk2 = Din{2};
        dk3 = Din{3};
    end
end

% Extended first-difference matrix of one block, (Lmax+1) x Lmax:
% rows give x_1, x_1-x_2, ..., x_(Lmax-1)-x_Lmax, x_Lmax.
Fblock = diag([1 -ones(1,Lmax-1)]) + diag(ones(1,Lmax-1),-1);
Fblock = [Fblock; [zeros(1,Lmax-1), 1]];
% Block-diagonal over all blocks; kept sparse so products with F are cheap.
F = kron(speye(Pfreq), sparse(Fblock));
% The zk-update solves (A'A + I + F'F) z = Gq. The sparse SPD part I + F'F
% is Cholesky-factored once here (I + F'F = Rff'*Rff); the A'A part is
% handled with the Woodbury identity via the LU factors from factor2.
Rff = chol(speye(M) + F'*F);
[Lp, Up, P] = factor2(A, F, 1);

if ~QUIET
    fprintf('%3s\t%1s\t%10s\t%10s\t%10s\t%10s\t%10s\n', 'iter', ...
        'r norm', 'eps pri', 's norm', 'eps dual', 'objective','mu');
end
BlockSpectraOut = zeros(noBlocks, length(lambdaVec));
lambiter = 1;
lambda = lambdaVec;
for i = 1:noIter
    % ---- zk update: zk = (A'A + I + F'F) \ Gq via the Woodbury identity ----
    Gq = A'*(uk1+dk1) + (uk2+dk2) + F'*(uk3+dk3);
    IpFFTinvGq = Rff \ (Rff' \ Gq);
    woodburyCorr = A'*(Up\(Lp\(P*(A*IpFFTinvGq))));
    zk = IpFFTinvGq - Rff \ (Rff' \ woodburyCorr);

    % ---- keep previous u's (for the dual residual) and over-relax ----
    % Over-relaxation speeds up convergence (a factor 3-5 in practice).
    uk1old = uk1;
    uk2old = uk2;
    uk3old = uk3;
    Azk = alphak*(A*zk) + (1-alphak)*uk1;
    zk2 = alphak*zk + (1-alphak)*uk2;      % corresponds to uk2
    zk3 = alphak*(F*zk) + (1-alphak)*uk3;  % corresponds to uk3

    % ---- u1 update: prox of the ell_2 or ell_1 data-fit term ----
    if ell == 2
        uk1 = (b + mu*(Azk-dk1))/(1+mu);
    else % ell == 1
        uk1 = softthreshold(Azk-dk1-b, 1/mu, 1) + b;
    end

    % ---- u2 update: prox of the (weighted) ell_1 term ----
    uk2 = softthreshold(zk2-dk2, lambda/mu, 1);

    % ---- u3 update: prox of the (weighted) TV term ----
    uk3 = softthreshold(zk3-dk3, xiVec/mu, 1);

    % ---- scaled dual updates ----
    dk1 = dk1 - (Azk-uk1);
    dk2 = dk2 - (zk2-uk2);
    dk3 = dk3 - (zk3-uk3);

    % ---- residuals and residual balancing of mu ----
    sk = mu*(A'*(uk1-uk1old) + uk2-uk2old + F'*(uk3-uk3old));
    rk = [Azk-uk1; zk2-uk2; zk3-uk3];
    rNorm = norm(rk);
    sNorm = norm(sk);
    % Rescaling the scaled duals keeps the unscaled dual mu*dk unchanged.
    if i < noAdaptIter && rNorm > rho*sNorm
        mu = 2*mu;
        dk1 = dk1/2;
        dk2 = dk2/2;
        dk3 = dk3/2;
    elseif i < noAdaptIter && sNorm > rho*rNorm
        mu = mu/2;
        dk1 = dk1*2;
        dk2 = dk2*2;
        dk3 = dk3*2;
    end

    % ---- bookkeeping and stopping criterion ----
    history.mu(i) = mu;
    history.objval(i) = objective(b, lambda, uk1, uk2, uk3, ell, xiVec);
    history.r_norm(i) = rNorm;
    history.s_norm(i) = sNorm;
    normzk2 = sum(abs(zk).^2);
    normAzk = sum(abs(Azk).^2);
    normGzk = sqrt(normzk2*2 + normAzk);
    history.eps_pri(i) = sqrt(3*M)*ABSTOL + RELTOL*max(normGzk, norm([uk1;uk2;uk3]));
    % Use the conjugate transpose A' (as in the zk-update): A may be complex.
    history.eps_dual(i) = sqrt(2*M+N)*ABSTOL + RELTOL*norm(mu*[A'*dk1; dk2; F'*dk3]);

    if ~QUIET
        fprintf('%3d\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.2f\t%10.2f\n', i, ...
            history.r_norm(i), history.eps_pri(i), ...
            history.s_norm(i), history.eps_dual(i), history.objval(i), history.mu(i));
    end

    if history.r_norm(i) < history.eps_pri(i) && ...
            history.s_norm(i) < history.eps_dual(i)
        break;
    end
end
zk = uk2;
BlockSpectraOut(:,lambiter) = blockspectra(CumSumBlockSizes(2:end), zk);

U = cell(3,1); U{1} = uk1; U{2} = uk2; U{3} = uk3;
D = cell(3,1); D{1} = dk1; D{2} = dk2; D{3} = dk3;
end

function xout=softthreshold(xin,lambda,p)
% SOFTTHRESHOLD  Proximal operator of the (weighted) ell_1 or ell_2 norm.
%
%   p == 1: elementwise (complex) soft-thresholding,
%           xout = xin .* max(|xin|-lambda,0) ./ |xin|,
%           where lambda is a scalar or an array broadcastable to xin.
%   p == 2: group soft-thresholding of the whole array xin,
%           xout = xin * max(||xin||-lambda,0) / ||xin||, lambda scalar.
%
%   Entries (p == 1) or groups (p == 2) with magnitude <= lambda are
%   returned as exact zeros. This also covers lambda == 0 with xin == 0,
%   which would otherwise evaluate 0/0 = NaN. NaN inputs still propagate.
if p == 1
    maxf = max(abs(xin)-lambda, 0);
    xout = xin.*maxf./(lambda+maxf);
    % Denominator lambda+maxf is 0 only when lambda == 0 and xin == 0.
    xout(abs(xin) <= lambda) = 0;
elseif p == 2
    norm_xin = norm(xin);
    if norm_xin <= lambda
        xout = 0*xin;
    else
        maxf = max(norm_xin-lambda, 0);
        xout = xin*maxf/(lambda+maxf);
    end
else
    error('softthreshold:invalidP', 'p must be 1 or 2.');
end
end

function [L, U, P] = factor2(A,F, rho)
% FACTOR2  LU factorization of the Woodbury "capacitance" matrix
%
%   C = I_m + (1/rho) * A * (I_n + F'*F)^(-1) * A',   [m, n] = size(A),
%
% returned as sparse factors with P*C = L*U. The caller uses these to apply
% (rho*(I + F'F) + A'A)^(-1) through the Woodbury identity without forming
% the n x n inverse.
[m, n] = size(A);
% Build I + F'F sparsely; avoids materializing a dense n x n identity.
IpFtF = speye(n) + sparse(F)'*sparse(F);
[L, U, P] = lu(speye(m) + (1/rho)*(A*(IpFtF\A')));
% Force MATLAB to recognize the upper / lower triangular structure.
L = sparse(L);
U = sparse(U);
P = sparse(P);
end


function p = objective(b, lambda,u1, u2, u3,ell,gamma)
% OBJECTIVE  ADMM objective evaluated at the split variables.
%
%   p = f(u1 - b) + sum_k lambda_k*|u2_k| + sum_k gamma_k*|u3_k|
%
%   with f = 1/2*||.||_2^2 for ell == 2 and f = ||.||_1 for ell == 1.
%   lambda and gamma may each independently be a scalar (uniform weight)
%   or a vector of per-entry weights; the result is always a scalar.
%
% This function is based on a similar function in:
% http://www.stanford.edu/~boyd/papers/admm/group_lasso/group_lasso.html
p=0;
if ell==2
    p = p+ 1/2*sum(abs(u1 - b).^2);
end
if ell==1
    p = p+ sum(abs(u1 - b));
end
% Each weight is checked separately: a vector lambda with a scalar gamma
% (or vice versa) must still give a scalar objective.
p = p + weightedl1(lambda, u2) + weightedl1(gamma, u3);
end

function s = weightedl1(w, v)
% WEIGHTEDL1  Weighted ell_1 norm sum_k w_k*|v_k|; non-vector w weights all entries.
if numel(w) > 1
    s = w(:)'*abs(v(:));
else
    s = w*norm(v,1);
end
end

function p = blockspectra(cum_part, x)
% BLOCKSPECTRA  2-norm of each consecutive block of x.
%
%   cum_part holds the (1-based) index of the last element of each block,
%   i.e. the cumulative sum of the block sizes. Returns a 1 x numel(cum_part)
%   row vector [||x_[1]||_2, ..., ||x_[d]||_2] (1 x 0 for empty cum_part).
p = zeros(1, numel(cum_part));
start_ind = 1;
for i = 1:numel(cum_part)
    p(i) = norm(x(start_ind:cum_part(i)));
    start_ind = cum_part(i) + 1;
end
end

