function [w_rls_hist,fpgridHist] = PEARLS(d,lambda,rls_xi,Lmax,fs,fmin,fmax,fdist)
%
%   This is an implementation of the PEARLS algorithm described in "Online
%   Estimation of Multiple Harmonic Signals" by Elvander et al., published
%   in IEEE/ACM Transactions on Audio, Speech, and Language Processing.
%   DOI: 10.1109/TASLP.2016.2634118
%
%   Requires the functions proximal_gradient_update, rls_update and
%   dictionaryUpdate to be on the MATLAB path.
%
%
% INPUT
% d         -       the complex-valued signal, given as an N-times-1
%                   vector (a row vector is converted to a column).
% lambda    -       forgetting factor, scalar on the interval (0,1).
% rls_xi    -       the smoothness parameter, called \xi in the paper.
% Lmax      -       the maximum number of considered harmonics for each
%                   pitch.
% fs        -       the sampling frequency in Hz.
% fmin      -       the minimum considered pitch frequency, in Hz.
% fmax      -       the maximum considered pitch frequency, in Hz.
% fdist     -       the initial resolution of the candidate pitch grid,
%                   i.e., the distance in Hz between two consecutive
%                   candidate pitches.
%
% OUTPUT
% w_rls_hist    -       the trajectory of the filter coefficients, i.e.,
%                       the amplitude of each harmonic of the candidate
%                       pitches, given as a (P*Lmax)-times-N matrix whose
%                       n-th column holds the coefficients after sample n.
%                       N is the length of the signal d, P is the number
%                       of candidate pitches and Lmax is the number of
%                       harmonics for each pitch. Row (p-1)*Lmax+l belongs
%                       to harmonic l of candidate pitch p. The first 100
%                       columns are zero, since the RLS update starts only
%                       after 100 samples.
% fpgridHist    -       the trajectory of the candidate pitch frequencies,
%                       given as a P-times-N matrix whose n-th column holds
%                       the pitch grid in use at sample n.


%% SETTINGS AND INITIALIZATION

%%%%%%%%%%%%%% ADDITIONAL SETTINGS %%%%%%%%%%%%%%%%%
% Set to zero if no dictionary update should be performed
doDictionaryLearning = 1;

% Set to zero if no update speed-up should be performed
doActiveUpdate = 1;

% Nbr of samples for dictionary update
nbrSamplesForPitch = floor(45*1e-3*fs); % use, e.g., 45 ms of the signal

% Settings for speed-up
waitingPeriod = 9e-3; % The waiting period during which a pitch block can be excluded from updating (in s, i.e., 9 ms).
blockUpdateThreshold = floor(fs*waitingPeriod); % As above, expressed in nbr of samples.
zeroUpdateThreshold = floor(blockUpdateThreshold/10); % Determine how often to check whether a pitch block should be set to zero
speedUpHorizon = blockUpdateThreshold; % Determine when to start activating/deactivating pitch blocks

% The number of dictionary rows (time samples) stored in memory per cycle
dictionaryLength = 2000;

% Initial values for the penalty parameters (re-estimated from the data
% every 401 samples once n >= Delta, see below)
gamma = 4;
gamma2 = 80;

% The proximal gradient step-size
stepSize = 1e-4;
maxIter = 20; % maximum number of iterations

% Set to one to print sample number
doPrint = 1;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%%%%%%%%%%%%%% INITIALIZATION %%%%%%%%%%%%%%%%%%%%%%
d = d(:); % the windowed inner products below assume a column vector
N = length(d);
fpgrid = fmin:fdist:fmax; % the pitch frequency grid
P = length(fpgrid);
nbrOfVariables = P*Lmax;
% Lmax-times-P matrix of harmonic frequencies. freqMat(:) orders the
% dictionary columns as harmonics 1..Lmax of pitch 1, then of pitch 2, ...
freqMat = (1:Lmax)'*fpgrid;
t = 0:N-1;
tTemp = 0:dictionaryLength-1;
AInner = 2*pi*tTemp(:)*freqMat(:)'/fs; % original dictionary
AInnerNoPhase = AInner;
A = exp(1i*AInner);   % row k <-> k-th sample of the current dictionary cycle
AOld = A;             % dictionary of the previous cycle (equal to A at start)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


%%%%%%%%%%%%% WINDOW FOR GAMMA UPDATE %%%%%%%%%%%%%%%%%%%%%
% Number of samples until the forgetting weight lambda^k drops below 1%
Delta = floor(log(0.01)/log(lambda));
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%%%%%%%%%%%%%%%%% INITIALIZE ALL VARIABLES %%%%%%%%%%%%%%
% NOTE(review): the first sample enters Rn/rn here and again in the n = 1
% update inside the loop; confirm this initialization is intended.
xn = A(1,:)';
dn = d(1);
Rn = xn*xn';
rn = xn*dn;
w_hat = zeros(nbrOfVariables,1);

% RLS
w_rls = zeros(nbrOfVariables,1);
w_rls_hist = zeros(nbrOfVariables,N);

fpgridHist = zeros(P,N);

% Counters for block-update
blockNotUpdatedSince = zeros(1,P);
hasBeenUntouchedSince = zeros(1,P);


activeBlocks = 1:P;
inactiveBlocks = [];

activeIndices = 1:P*Lmax;
indexMatrix = reshape(activeIndices,Lmax,P); % column p = coefficient indices of pitch p
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%% ALGORITHM

%%%%%%%%%%%%%%%%%%%% ALGORITHM BEGINS HERE %%%%%%%%%%%%%%%%%%%%%%%%
for n=1:N
    if doPrint && mod(n,100)==0
        fprintf('%d av %d\n',n,N)
    end
    
    %%%%%%%%%%%%%%%%%%%%%%% SAVE PRESENT GRID %%%%%%%%%%%%%%%%%%%%%%
    fpgridHist(:,n) = fpgrid(:);
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    
    
    %%%%%%%%%%%%%%%%%%%%%%% NEW SAMPLE %%%%%%%%%%%%%%%%%%%%%%
    % Row of the current dictionary that corresponds to sample n
    % (1..dictionaryLength); the last row is reached when n is a multiple
    % of dictionaryLength.
    atCycleEnd = mod(n,dictionaryLength)==0;
    sampleIndex = mod(n-1,dictionaryLength)+1;
    xn = A(sampleIndex,:)';
    
    if atCycleEnd
        % The last row of the current dictionary has just been used: keep
        % it as AOld and build the dictionary for samples
        % n+1 .. n+dictionaryLength, zero-padded in time past the signal end.
        AOld = A;
        upperTimeIndex = min(N,(n+dictionaryLength));
        tTemp = t((n+1):upperTimeIndex);
        if upperTimeIndex-n<dictionaryLength
           tTemp = [tTemp,zeros(1,dictionaryLength-(upperTimeIndex-n))];
        end
        % NOTE(review): this uses the initial freqMat, not the possibly
        % learned fpgrid; confirm that resetting the grid here is intended.
        AInner = 2*pi*tTemp(:)*freqMat(:)'/fs; % original dictionary
        AInnerNoPhase = AInner;
        A = exp(1i*AInner);
    end
    
    % Exponentially weighted correlation matrix and cross-correlation vector
    dn = d(n);
    Rn = lambda*Rn + xn*xn';
    rn = lambda*rn + xn*dn;
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    
    
    
    %%%%%%%%%%%%%%%%%% UPDATE PENALTY PARAMETERS %%%%%%%%%%%%%%%%%
    % gamma and gamma2 are set from the largest magnitude of the
    % lambda-weighted inner product between the dictionary columns and the
    % last Delta samples of d.
    if n>=Delta && mod(n,401)==0
        if atCycleEnd
            % A already holds the next cycle; the rows up to sample n are
            % in AOld, and no older cycle is stored.
            ACurr = AOld;
            APrev = zeros(0,nbrOfVariables);
        else
            ACurr = A;
            APrev = AOld;
        end
        % Only the current cycle (up to sampleIndex) and one previous cycle
        % are stored, so the window is limited to the available rows.
        winLen = min(Delta, sampleIndex+size(APrev,1));
        nFromCurr = min(winLen, sampleIndex);
        nFromPrev = winLen - nFromCurr;
        dWeighted = d(n-(winLen-1):n).*lambda.^((winLen-1):-1:0)';
        innerProd = ACurr(sampleIndex-(nFromCurr-1):sampleIndex,:)'*dWeighted(nFromPrev+1:end);
        if nFromPrev>0
            innerProd = APrev(end-(nFromPrev-1):end,:)'*dWeighted(1:nFromPrev) + innerProd;
        end
        maxNorm = max(abs(innerProd));
        gamma = 0.1*maxNorm;
        gamma2 = 1*maxNorm; 
    end
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    if doActiveUpdate
        % Re-activate pitch blocks that have been inactive for longer than
        % the waiting period (blockUpdateThreshold samples).
        activationCandidates = find(blockNotUpdatedSince>blockUpdateThreshold);
        if ~isempty(activationCandidates)
            activeBlocks = sort(union(activeBlocks,activationCandidates));
            inactiveBlocks = setdiff((1:P),activeBlocks);
            newIndices = indexMatrix(:,activationCandidates);
            activeIndices = sort(union(activeIndices,newIndices(:)'));
            
            blockNotUpdatedSince(activeBlocks) = zeros(size(activeBlocks));
        end
        blockNotUpdatedSince(inactiveBlocks) = blockNotUpdatedSince(inactiveBlocks)+1;
    end
    
    
    % Update only active part
    w_ell = w_hat(activeIndices);
    Rn_small = Rn(activeIndices,activeIndices);
    rn_small = rn(activeIndices);
    
    %%%%%%%%%%%%%%%%%% FILTER UPDATE %%%%%%%%%%%%%%%%%%%
    w_ell = proximal_gradient_update(w_ell,Rn_small,rn_small,length(activeBlocks),Lmax,gamma,gamma2,maxIter,stepSize);
    w_hat(activeIndices) = w_ell;
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    
    %%%%%%%%%%%%%%%%%% RLS FILTER UPDATE %%%%%%%%%%%%%%%%%%
    % Coefficients of inactive blocks are reset to zero on every update.
    if n>100
        w_rls_new = rls_update(w_rls(activeIndices),Rn_small,rn_small,Lmax,rls_xi); % Paper version
        w_rls = zeros(nbrOfVariables,1);
        w_rls(activeIndices) = w_rls_new; 
    end
    w_rls_hist(:,n) = w_rls;
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
    if doActiveUpdate
        hasBeenUntouchedSince = hasBeenUntouchedSince+1;
        
        % Deactivate blocks whose coefficient norm is small
        if n>speedUpHorizon
            zeroCandidates = find(hasBeenUntouchedSince>zeroUpdateThreshold);
            if ~isempty(zeroCandidates)
                % Euclidean norm of each pitch block (column) of w_hat
                w_norms = sqrt(sum(abs(reshape(w_hat,Lmax,P)).^2,1));
                setToZero = find(w_norms<0.05); % 0.01
                setToZero = intersect(setToZero,zeroCandidates);
                
                if ~isempty(setToZero)
                    inactiveBlocks = sort(union(setToZero,inactiveBlocks));
                    activeBlocks = setdiff((1:P),inactiveBlocks);
                    activeIndices = indexMatrix(:,activeBlocks);
                    activeIndices = sort(activeIndices(:)');
                    hasBeenUntouchedSince(setToZero) = zeros(size(setToZero));
                end
            end
            % NOTE(review): second increment in the same iteration; once
            % n > speedUpHorizon the counter grows by 2 per sample, which
            % effectively halves zeroUpdateThreshold. Confirm intended.
            hasBeenUntouchedSince = hasBeenUntouchedSince+1;
        end
    end
      
    %%%%%%%%%%%%%%%%%% DICTIONARY LEARNING SCHEME %%%%%%%%%%%%%%%%%%
    isLearningInstant = doDictionaryLearning && n >= 1000 && mod(n-1,100)==0;
    if isLearningInstant
        % Computed here so the check below does not depend on the speed-up
        % branch having defined w_norms (e.g. when doActiveUpdate = 0).
        w_norms = sqrt(sum(abs(reshape(w_hat,Lmax,P)).^2,1));
    end
    if isLearningInstant && max(w_norms)>0.01
        
        currentIndexTime = n;
        updateHorizon = 600;
        startIndexTime = max(n-nbrSamplesForPitch+1,2);
        stopIndexTime = currentIndexTime + updateHorizon;
        if stopIndexTime>N
            updateHorizon = updateHorizon - (stopIndexTime-N);
        end
        stopIndexTime = min(stopIndexTime,N);
        pitchLimit = fdist/2;
        
        refSignal = d(startIndexTime:currentIndexTime);
        % Dictionary rows aligned with refSignal (same length, also when
        % startIndexTime was clamped above).
        currIndexCurrA = sampleIndex;
        startIndexCurrA = currIndexCurrA - (currentIndexTime-startIndexTime);
        % Do not run past the end of the current dictionary cycle
        stopIndexCurrA = min(dictionaryLength,currIndexCurrA+updateHorizon);
               
        if startIndexCurrA<=0
            % The reference window starts in the previous cycle: use old A.
            % NOTE(review): assumes the window fits within one previous
            % cycle, i.e. nbrSamplesForPitch <= dictionaryLength.
            startIndexOldA = dictionaryLength-(abs(startIndexCurrA));
            startIndexCurrA = 1;
            [ANew,AInnerNew,AInnerNoPhaseNew,AOldNew,fpgridNew,~,changeFlag] = ...
                dictionaryUpdate(w_rls,refSignal,pitchLimit,A,AInner,AInnerNoPhase,fpgrid,t,fs,Lmax,P,...
                dictionaryLength,startIndexTime,stopIndexTime,currIndexCurrA,startIndexCurrA,stopIndexCurrA,AOld,startIndexOldA);
            if changeFlag
               AOld = AOldNew; 
            end
        else
            [ANew,AInnerNew,AInnerNoPhaseNew,~,fpgridNew,~,changeFlag] = ...
                dictionaryUpdate(w_rls,refSignal,pitchLimit,A,AInner,AInnerNoPhase,fpgrid,t,fs,Lmax,P,...
                dictionaryLength,startIndexTime,stopIndexTime,currIndexCurrA,startIndexCurrA,stopIndexCurrA,[]);
        end
        if changeFlag
            A = ANew;
            AInner = AInnerNew;
            AInnerNoPhase = AInnerNoPhaseNew;
            fpgrid = fpgridNew;
        end
        
    end
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
end
%%%%%%%%%%%%%%%%%%%%% END OF ALGORITHM %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
