function varargout= forwardBackwardImputationFast( data, mask, hmm, imputationAlgorithm, probComputation )
%FORWARDBACKWARDIMPUTATIONFAST Impute unreliable features using HMM state posteriors.
%   EST= FORWARDBACKWARDIMPUTATIONFAST(DATA, MASK, HMM, IMPUTATIONALGORITHM,
%   PROBCOMPUTATION) replaces the unreliable entries of DATA by a weighted
%   combination of per-state estimates. The weights are the state
%   posteriors obtained from a forward-backward pass over HMM.
%
%   [EST, VAREST]= FORWARDBACKWARDIMPUTATIONFAST(...) also returns the
%   variance of every entry of EST. Reliable entries get VAR_RELIABLE (0).
%
%   Inputs:
%     DATA                - D x T observations, one frame per column.
%     MASK                - D x T; nonzero marks a reliable component, zero
%                           marks an unreliable (to be imputed) component.
%     HMM                 - struct with at least the fields ncentres,
%                           priors and transp. A copy with transp cleared
%                           is passed to splitGMM (its exact requirements
%                           are not visible here).
%     IMPUTATIONALGORITHM - 'map' (mapImputationFast) or 'bcmi' (bcmi2Fast).
%     PROBCOMPUTATION     - forwarded to computeObservationProbFast
%                           ('diag' or 'bcp').
%
%   Outputs:
%     EST    - D x T; imputed values are clipped so they never exceed the
%              corresponding DATA value (DATA acts as an upper bound).
%     VAREST - D x T variance of the posterior mixture. For unreliable
%              entries it is taken about the unclipped mixture mean.
%
%   Errors with id 'forwardBackwardImputationFast:unknownAlgorithm' when
%   IMPUTATIONALGORITHM is not recognised.

VAR_RELIABLE= 0; %Variance of the estimate for the reliable components

switch imputationAlgorithm
    case 'map'
        impute= @mapImputationFast;
    case 'bcmi'
        impute= @bcmi2Fast;
    otherwise
        % Without this branch 'impute' stays undefined. The failure would
        % then only show up inside the imputation loop, far from its cause.
        error('forwardBackwardImputationFast:unknownAlgorithm', ...
            'Unknown imputation algorithm: %s', imputationAlgorithm);
end

wantVar= nargout > 1;

est= data;
if wantVar
    varest= zeros( size(data) ) + VAR_RELIABLE;
end

% Split the clean model into reliable and unreliable parts for every
% time instant t. The transition matrix is cleared first (presumably
% splitGMM only needs the mixture fields).
T= size(data,2);
K= hmm.ncentres;
gmm= hmm;
gmm.transp= [];
gmm_xr= cell(T, 1);
gmm_xu= cell(T, 1);
for t= 1:T
    [gmm_xr{t}, gmm_xu{t}]= splitGMM( gmm, data(:,t), mask(:,t) );
end

% 1. State posteriors P (K x T) from the forward-backward recursions
B= computeObservationProbFast( data, mask, gmm_xr, gmm_xu, K, probComputation );
alfas= computeAlfas( hmm, B );
betas= computeBetas( hmm, B );
P= computeGammas( alfas, betas, hmm.priors', B );

% 2. Impute the unreliable components frame by frame
for t= 1:T
    m= mask(:,t);
    unrel= ~m;
    u= nnz(unrel);
    if u == 0
        continue;   % fully reliable frame: nothing to impute
    end

    y= data(:,t);

    % Per-state partial estimates: column k of xest (and slice k of Sest)
    % is what the imputation function returns for state k
    xest= zeros(u, K);
    if wantVar
        Sest= zeros(u, u, K);
        for k= 1:K
            [xest(:,k), Sest(:,:,k)]= impute( y, m, gmm_xu{t}, k );
        end
    else
        for k= 1:K
            xest(:,k)= impute( y, m, gmm_xu{t}, k );
        end
    end

    % Combine the partial estimates according to their posterior
    % probabilities
    yu= xest*P(:,t);
    y(unrel)= yu;
    est(:,t)= min([data(:,t) y], [], 2); % -Inf <= xu(t) <= y(t)

    if wantVar
        % Mixture variance: sum_k P(k)*(S_k + (x_k - mean)^2). Only the
        % diagonal is returned, so it is accumulated directly instead of
        % building the full u x u covariance matrix.
        dev= bsxfun(@minus, xest, yu);
        vu= zeros(u, 1);
        for k= 1:K
            vu= vu + P(k,t)*(diag(Sest(:,:,k)) + dev(:,k).^2);
        end
        varest(unrel,t)= vu; %Variance of the unreliable components
    end
end

varargout{1}= est;
if wantVar
    varargout{2}= varest;
end


%--------------------------------------------------------------------------
function alfas= computeAlfas (hmm, observationP)
%COMPUTEALFAS Scaled forward (alpha) recursion of the HMM.
%   ALFAS= COMPUTEALFAS(HMM, OBSERVATIONP) returns a K x T matrix. Each
%   column is normalised to sum to one.
%   HMM.priors is expected as a 1 x K row (it is transposed to a column
%   here). HMM.transp(i,j) is the transition weight from state i to
%   state j. OBSERVATIONP is the K x T matrix of observation likelihoods.
%
%   If the propagated mass in a frame drops to 1e-10 or below, the
%   recursion is restarted from the previous frame using the priors
%   instead of the propagated alphas.

transMat= hmm.transp;
initProb= hmm.priors';
nFrames= size(observationP, 2);
alfas= zeros( size(observationP) );

% First frame: prior times likelihood, normalised
first= initProb.*observationP(:,1);
alfas(:,1)= first/sum(first);

for t= 2:nFrames
    % cur(i) = (sum_j alfas(j,t-1)*transMat(j,i)) * observationP(i,t)
    cur= (transMat'*alfas(:,t-1)).*observationP(:,t);

    if sum(cur) <= 1e-10
        % Underflow: re-seed frame t-1 from the priors and propagate again
        restart= initProb.*observationP(:,t-1);
        restart= restart/sum(restart);
        cur= (transMat'*restart).*observationP(:,t);
    end

    alfas(:,t)= cur/sum(cur);
end


%--------------------------------------------------------------------------
function betas= computeBetas (hmm, observationP)
%COMPUTEBETAS Scaled backward (beta) recursion of the HMM.
%   BETAS= COMPUTEBETAS(HMM, OBSERVATIONP) returns a K x T matrix. The
%   last column is all ones. Every earlier column is normalised to sum
%   to one.
%   HMM.priors is expected as a 1 x K row (it is transposed to a column
%   here). HMM.transp(i,j) is the transition weight from state i to
%   state j. OBSERVATIONP is the K x T matrix of observation likelihoods.
%
%   If the propagated mass drops to 1e-10 or below, the priors are used
%   in place of the next frame's betas.

transMat= hmm.transp;
initProb= hmm.priors';
nFrames= size(observationP, 2);
betas= zeros( size(observationP) );
betas(:,end)= 1;

for t= nFrames-1:-1:1
    nextLik= observationP(:,t+1);
    cur= transMat*(nextLik.*betas(:,t+1));

    if sum(cur) <= 1e-10
        % Underflow: restart from the priors instead of betas(:,t+1)
        cur= transMat*(nextLik.*initProb);
    end

    betas(:,t)= cur/sum(cur);
end


%--------------------------------------------------------------------------
function gammas= computeGammas( alfas, betas, priors, observationP )
%COMPUTEGAMMAS State posteriors from scaled forward/backward variables.
%   GAMMAS= COMPUTEGAMMAS(ALFAS, BETAS, PRIORS, OBSERVATIONP) returns the
%   K x T matrix ALFAS.*BETAS with every column normalised over states.
%   For frames where the product mass is 1e-10 or below, the column is
%   replaced by PRIORS .* OBSERVATIONP(:,t) before normalisation.
%   PRIORS must be a K x 1 column (it is broadcast across frames).
%
%   All sums are taken explicitly along dimension 1 (states). With a
%   single-state model (K == 1) the default sum() would run along the
%   1 x T row, i.e. across time, and yield wrong posteriors.

gammas= alfas.*betas;
colMass= sum(gammas, 1);
degenerate= colMass <= 1e-10;
if any(degenerate)
    gammas(:,degenerate)= bsxfun(@times, observationP(:,degenerate), priors);
end
gammas= bsxfun(@times, gammas, 1./sum(gammas, 1));


%--------------------------------------------------------------------------
function B= computeObservationProbFast( data, mask, hmm_xr, hmm_xu, K, probComputation )
%COMPUTEOBSERVATIONPROBFAST Per-frame observation likelihood of every state.
%   B= COMPUTEOBSERVATIONPROBFAST(DATA, MASK, HMM_XR, HMM_XU, K, PROBCOMPUTATION)
%   returns a K x T matrix. B(k,t) is the value returned by the selected
%   probability function for frame t and state k, and each column is
%   normalised to sum to one over the K states.
%
%   DATA and MASK are D x T (one frame per column). HMM_XR and HMM_XU are
%   T-element cell arrays holding the per-frame reliable/unreliable model
%   parts. PROBCOMPUTATION selects 'diag' (boundedProbabilityDiagFast) or
%   'bcp' (boundedProbability2Fast). Any other value raises an error with
%   id 'forwardBackwardImputationFast:unknownProbComputation'.

switch probComputation
    case 'diag'
        probability= @boundedProbabilityDiagFast;
    case 'bcp'
        probability= @boundedProbability2Fast;
    otherwise
        % Fail early instead of calling an undefined 'probability' below
        error('forwardBackwardImputationFast:unknownProbComputation', ...
            'Unknown probability computation: %s', probComputation);
end

T= size(data,2);
B= zeros( K, T );
for t= 1:T
    y= data(:,t);
    m= mask(:,t);
    xr= hmm_xr{t};
    xu= hmm_xu{t};
    for k= 1:K
        B(k,t)= probability( y, m, xr, xu, k );
    end
end

% Normalise over states. Dimension 1 is given explicitly: with K == 1,
% B is 1 x T and the default sum() would run across time instead.
B= bsxfun(@times, B, 1./sum(B, 1));
