function [fk_est,block_amps,fpgrid] = pebsi_lite_selfreg(y,fs,fmin,fmax,Lmax,Lsmin,Lsmax,tau)
% Implementation of the self-regularized PEBSI-Lite multi-pitch estimator,
% as presented in "An adaptive penalty multi-pitch estimator with
% self-regularization" by F. Elvander, et al. 
% The paper can be found at http://dx.doi.org/10.1016/j.sigpro.2016.02.015
%
% INPUT
% y             -       complex valued data vector, N-times-1. 
% fs            -       sampling frequency in Hz. It is also used as the
%                       FFT length when normalizing y, so it must be a
%                       positive integer.
% fmin          -       minimal considered pitch frequency, in Hz.
% fmax          -       maximal considered pitch frequency, in Hz.
% Lmax          -       maximal harmonic order, i.e., maximal number of
%                       harmonics per pitch. Default: 10.
% Lsmin         -       minimal number of considered sinusoids in the 
%                       signal. Default: 10.
% Lsmax         -       maximal number of considered sinusoids in the 
%                       signal. Default: 60. Must satisfy Lsmax >= Lsmin.
% tau           -       variance threshold. Default: 0.1.
%
% Each optional input (Lmax, Lsmin, Lsmax, tau) may be omitted or passed
% as [] to use its default value.
%
% OUTPUT
% fk_est        -       vector containing estimated fundamental frequencies 
%                       in Hz. Empty if no pitch candidate lies in
%                       [fmin, fmax].
% block_amps    -       full pitch-amplitude vector for all considered
%                       pitch candidates. Could be used for post-processing
%                       and thresholding to achieve sparser solutions.
%                       Empty if no pitch candidate lies in [fmin, fmax].
% fpgrid        -       frequency vector corresponding to the considered
%                       pitch candidates. Use together with block_amps.
%                       If no candidate lies in [fmin, fmax], it instead
%                       holds all estimated sinusoidal frequencies.
%
% ERRORS
% pebsi_lite_selfreg:invalidOrderRange is raised when Lsmax < Lsmin.
%
% EXAMPLE USAGE
% fk_est = pebsi_lite_selfreg(y,44.1e3,90,1000)
%
% NOTE TO USER
% Make sure that you have installed the CVX package, found at
% http://cvxr.com/cvx/download/, as well as added the Multi-pitch toolbox  
% found at http://www.morganclaypool.com/page/multi-pitch and the folder
% pebsi_aux_func to your MATLAB path before running this function.
%
% Filip Elvander, March 18 2016.

%% Default values

% Each optional argument is checked on its own, so passing e.g. Lsmax but
% omitting tau keeps the caller's Lsmax. The || short-circuits, so
% isempty is only evaluated for arguments that were actually passed.
if nargin < 5 || isempty(Lmax)
    Lmax = 10;
end
if nargin < 6 || isempty(Lsmin)
    Lsmin = 10;
end
if nargin < 7 || isempty(Lsmax)
    Lsmax = 60;
end
if nargin < 8 || isempty(tau)
    tau = 0.1;
end
if Lsmax < Lsmin
    error('pebsi_lite_selfreg:invalidOrderRange', ...
        'Lsmax (%d) must be greater than or equal to Lsmin (%d).', Lsmax, Lsmin);
end

%% Some additional internal options

upperLambdaLimit = 1; % upper lambda-limit to safe-guard against infinite loops
adjustWithNLS = 1; % set to "1" in order to refine the final estimates with non-linear LS
doFreqMatAdjust = 1; % Adjust each individual harmonic frequency
gridAdjustTol = 3; % tolerance in Hz for adjusting harmonics of each pitch
sinusoidSafetyMargin = 2; % safety margin for safeguarding against BIC selecting too few sinusoidal components
mapToFineGrid = 1; % set to "1" to snap pitch candidates to a 0.1 Hz grid on [fmin,fmax]
noWarmStart = 1; % set to "0" to keep solver state (U, D, phases) between lambda values
maxDL = 5; % number of dictionary-learning (phase update) iterations per lambda
doReweight = 1; % set to "1" to reweight the penalties between DL iterations

doPrintOuts = 1; % set to zero to avoid print-outs

%% Normalize y

N = length(y);
t = 0:N-1;
% Scale y so that the peak of its fs-point FFT magnitude (scaled by
% 1/sqrt(N)) equals one. Presumably this makes the fixed lambda grid below
% independent of the signal level. Note that fft(y,fs) truncates y when
% N > fs.
y_per = 1/sqrt(N)*abs(fft(y,fs));
y = y/max(y_per);

%% Estimate individual sinusoidal frequencies

est_freqs = estimate_sinusoid_freqs(y,t,fs,Lsmin,Lsmax,sinusoidSafetyMargin);

%% Build frequency dictionary

% Pitch candidates: the estimated sinusoidal frequencies up to fmax
fpgrid = est_freqs(est_freqs <= fmax)';
% Map estimates to fine grid for better robustness
if mapToFineGrid
    large_grid = fmin:.1:fmax;
    fpgrid = map_to_grid(fpgrid,large_grid);
    fpgrid = fpgrid(fpgrid<=fmax);
    fpgrid = fpgrid(fpgrid>=fmin);
end
% The outer product below needs a row vector regardless of orientation
fpgrid = fpgrid(:).';
if isempty(fpgrid)
    % No pitch candidate in [fmin,fmax]: return empty estimates and expose
    % the raw sinusoidal frequencies instead.
    fk_est = [];
    block_amps = [];
    fpgrid = est_freqs(:)';
    return
end
P = length(fpgrid);

% Harmonic frequencies: column p holds 1..Lmax times pitch candidate p
freqMat = (1:Lmax).'*fpgrid;

% Lkest (from adjust_freqMat) sets the TV penalty scale through
% max(Lkest). Fall back to Lmax when the adjustment step is disabled.
% NOTE(review): assumes Lkest counts harmonics per pitch - confirm.
Lkest = Lmax;
if doFreqMatAdjust
    % Adjust each individual harmonic frequency
    [freqMat,Lkest] = adjust_freqMat(freqMat,est_freqs,gridAdjustTol);
end

W = exp(2i*pi*t(:)*freqMat(:).'/fs)/sqrt(N);

if doPrintOuts
   fprintf('Dictionary constructed\n') 
end

%% Least-squares estimate of phases

a_pLS = pseudoLS(y,freqMat,t,fs);
% Rotate every atom by its estimated phase. This is a column scaling,
% equal to W*diag(exp(1i*phi)) without forming the diagonal matrix.
W = bsxfun(@times, W, exp(1i*angle(a_pLS(:))).');

%% Line-search for optimal regularization level

LkestMax = max(Lkest);

lambda_vec = [.001,.005,0.01:0.005:upperLambdaLimit]; % increasing
nLambda = length(lambda_vec);

block_sparsity_save = zeros(nLambda,1);
sigma2_save = zeros(nLambda,1);
solution_save = cell(nLambda,1);

BlockSizes = Lmax*ones(1,P); % one block of Lmax harmonics per pitch
U = []; D = [];
xADMM = zeros(P*Lmax,1);
phi = zeros(size(xADMM));
Witer = W;
rho = 5; alphak = 1.8; ell = 2;
lambda_4_initFactor = 10;

for kk = 1:nLambda
    if noWarmStart
        % Restart solver state and dictionary for every lambda
        U = []; D = [];
        phi = zeros(size(xADMM));
        Witer = W;
    end

    lambda = lambda_vec(kk);
    lambda2 = lambda;             % l1 penalty
    lambda4 = LkestMax/2*lambda;  % TV penalty

    rew_l1 = 1; rew_tv = 1;
    for ii = 1:maxDL
        % DL-update: rotate the atoms by the phases of the previous solution
        Witer = bsxfun(@times, Witer, exp(1i*phi(:)).');
        lambdaVec = rew_l1*lambda2; % reweighting l1

        if ii == 1
            % Weaker TV penalty in the first DL iteration
            xiVec = rew_tv*lambda4/lambda_4_initFactor;
        else
            xiVec = rew_tv*lambda4;
        end
        [xADMM,~,~,U,D] = PEBSI_ADMM(Witer,y,BlockSizes,lambdaVec,xiVec,alphak,rho,ell,U,D);

        % If the change in phase is small, cancel the DL update
        if max(abs(angle(xADMM))) > 1e-2
            phi = angle(xADMM);
        else
            phi = zeros(size(xADMM));
        end

        if doReweight && ii > 1
            rew_l1 = 1./(abs(xADMM) + .01);
            rew_tv = ones(length(xADMM)+P,1);
        end
    end
    z = xADMM;

    % Number of pitch blocks with a non-zero amplitude
    z_reshape = reshape(z,Lmax,P);
    block_sparsity_save(kk) = sum(norms(z_reshape,2,1)>0);
    sigma2_save(kk) = get_sigma2hat(Witer,freqMat,z,y);
    solution_save{kk} = {[lambda2,lambda4],z};

    % Break if the residual variance (relative to var(y)) has increased by
    % at least tau compared with the first, least regularized, solution
    if (sigma2_save(kk) - sigma2_save(1))/var(y) >= tau
        break
    end

    % Break once every amplitude is zero
    if ~any(abs(z) > 0)
        break
    end
end

block_sparsity_save = block_sparsity_save(1:kk);
sigma2_save = sigma2_save(1:kk);
solution_save = solution_save(1:kk);

%% Decide optimal solution and return

opt_index = select_optimal_index(sigma2_save,block_sparsity_save,var(y),tau);
solution_opt = solution_save{opt_index}{2};

sol_reshape = reshape(solution_opt,Lmax,P);
block_amps = norms(sol_reshape,2,1);

%%
% OPTIONAL: adjust final estimate using non-linear least squares
if adjustWithNLS
    % Two outputs are requested, as in the original call site
    [fk_est,~] = nls_adjustment_constrained(y,t,fs,solution_opt,fpgrid,Lmax);
else
    % Pitch candidates with a non-zero amplitude block
    fk_est = fpgrid(block_amps > 0);
end

if doPrintOuts
   fprintf('Estimation complete\n') 
end
end


function est_freqs = estimate_sinusoid_freqs(y,t,fs,Lsmin,Lsmax,margin)
% Fit sinusoidal models of order Lsmin..Lsmax using freq_shiftinv and a
% least-squares amplitude fit. Pick the order that minimizes the BIC
% N*log(residual variance) + (5/2*order)*log(N). Return the sorted
% frequencies (in Hz) of the model `margin` orders above that one, clamped
% to Lsmax.
N = length(y);
M = floor(N/3);
nOrders = Lsmax - Lsmin + 1;
BIC = zeros(nOrders,1);
freqs = cell(nOrders,1);

for idx = 1:nOrders
    order = Lsmin + idx - 1;
    w = freq_shiftinv(y,order,M);   % frequencies in rad/sample
    X = exp(1i*t(:)*w(:).');
    amp = X\y;                      % least-squares amplitudes
    sigma2 = var(y - X*amp);
    BIC(idx) = N*log(sigma2) + (5/2*order)*log(N);
    freqs{idx} = w;
end

[~,bic_index] = min(BIC);
% Take the optimal BIC order plus a safety margin
margin_order = min(bic_index + margin,nOrders);
est_freqs = sort(freqs{margin_order}*fs/(2*pi));
end


function opt_index = select_optimal_index(sigma2_save,block_sparsity_save,var_y,tau)
% Select the index of the final solution. Solutions are assumed to be
% ordered by increasing lambda.
% - The first index whose relative variance increase
%   (sigma2 - sigma2(1))/var_y reaches tau is the cutoff.
% - The solution just before the cutoff fixes the number of active pitch
%   blocks.
% - The earliest solution with that same number of active blocks is
%   returned.
% If the cutoff is never reached, the first solution is used.
d_sigma2 = (sigma2_save - sigma2_save(1))/var_y;
cutoff = find(d_sigma2 >= tau,1,'first');
if isempty(cutoff)
    opt_index = 1;
    return
end
opt_index = max(1,cutoff - 1);
opt_index = find(block_sparsity_save == block_sparsity_save(opt_index),1,'first');
end