function varargout= forwardBackwardImputation( data, mask, hmm, imputationAlgorithm, probComputation )
%FORWARDBACKWARDIMPUTATION Impute unreliable features using HMM state posteriors.
%   EST = FORWARDBACKWARDIMPUTATION(DATA, MASK, HMM, IMPUTATIONALGORITHM,
%   PROBCOMPUTATION) computes per-frame state posteriors with the
%   forward-backward algorithm. For each frame with unreliable components it
%   computes one estimate per state and combines them, weighted by the
%   posteriors.
%
%   [EST, VAREST] = FORWARDBACKWARDIMPUTATION(...) also returns the variance
%   of every estimate. Reliable components get variance 0. Unreliable
%   components get the diagonal of the mixture covariance
%   sum_k P(k,t)*(S_k + (x_k - xbar)*(x_k - xbar)').
%
%   DATA                - matrix with one observation vector per column (frame).
%   MASK                - same size as DATA; true/nonzero marks reliable
%                         components, false/zero marks unreliable ones.
%   HMM                 - struct; this function reads hmm.priors, hmm.transp
%                         and hmm.ncentres (K, the number of states). The
%                         imputation/probability helpers may read more fields.
%   IMPUTATIONALGORITHM - 'bmap' | 'map' | 'faubel' | 'bcmi' | 'eig' | 'svd'
%   PROBCOMPUTATION     - 'diag' | 'faubel' | 'bcp' | 'eig' | 'svd'
%
%   Imputed values are clipped so they never exceed the observed value
%   (bounded imputation: -Inf <= x_u <= y).
%
%   Raises an error for an unknown IMPUTATIONALGORITHM.

VAR_RELIABLE= 0; %Variance of the estimate for the reliable components

switch imputationAlgorithm
    case 'bmap'
        impute= @bmapImputation;
    case 'map'
        impute= @mapImputation;
    case 'faubel'
        impute= @boundedConditionalMeanImputation;
    case 'bcmi'
        impute= @boundedConditionalMeanImputation2;
    case 'eig'
        impute= @boundedConditionalMeanImputationEig;
    case 'svd'
        impute= @boundedConditionalMeanImputationSvd;
    otherwise
        % Fail fast instead of leaving 'impute' undefined and erroring later.
        error('forwardBackwardImputation:unknownImputationAlgorithm', ...
            'Unknown imputation algorithm ''%s''.', imputationAlgorithm);
end

wantVariance= nargout > 1;

est= data;
if wantVariance
    varest= zeros( size(data) ) + VAR_RELIABLE; %Variance of the reliable components
end

%1. Compute the forward-backward probabilities
B= computeObservationProb( data, mask, hmm, probComputation );
alfas= computeAlfas( hmm, B );
betas= computeBetas( hmm, B );
P= computeGammas( alfas, betas, hmm.priors', B ); %gammas: one column of state posteriors per frame

%2. Impute the missing values
T= size(data,2);
K= hmm.ncentres;
for t=1:T
    y= data(:,t);
    m= mask(:,t);
    u= nnz(~m); %number of unreliable components in this frame
    %Nothing to impute when every component is reliable
    if u == 0
        continue;
    end

    %Compute the partial (per-state) estimates
    xest= zeros(u, K);
    if wantVariance
        Sest= zeros( u, u, K );
    end
    for k=1:K
        if wantVariance
            [xest(:,k), Sest(:,:,k)]= impute( y, m, hmm, k );
        else
            xest(:,k)= impute( y, m, hmm, k );
        end
    end

    %Combine the partial estimates according to their posterior
    %probabilities
    y(~m)= sum(bsxfun(@times, xest, P(:,t)'), 2);
    est(:,t)= min([data(:,t) y], [], 2); % -Inf <= xu(t) <= y(t)

    %Compute the variance for every estimate (mixture covariance diagonal)
    if wantVariance
        dev= bsxfun(@minus, xest, y(~m)); %deviation of each state estimate from the combined mean

        Si= zeros( u );
        for k=1:K
            Si= Si + P(k,t)*(Sest(:,:,k) + dev(:,k)*dev(:,k)');
        end
        varest(~m,t)= diag(Si); %Variance of the unreliable components
    end
end

varargout{1}= est;
if wantVariance
    varargout{2}= varest;
end


%--------------------------------------------------------------------------
function alfas= computeAlfas (hmm, observationP)
%COMPUTEALFAS Forward (alpha) probabilities, normalized in every frame.
%   ALFAS = COMPUTEALFAS(HMM, OBSERVATIONP) returns a matrix the same size
%   as OBSERVATIONP (states x frames). Each column is normalized to sum to 1:
%       alfas(i,t) ~ (sum_j alfas(j,t-1)*A(j,i)) * B(i,t)
%   with A = hmm.transp and B = OBSERVATIONP. The first frame uses
%   hmm.priors (transposed to a column) in place of the recursion.
%
%   If the unnormalized mass of a frame drops to 1e-10 or below, the
%   recursion restarts. The previous frame is re-seeded from the priors
%   weighted by that frame's observation probabilities.

startProb= hmm.priors';
transition= hmm.transp;
normalize= @(v) v/sum(v);

alfas= zeros( size(observationP) );
alfas(:,1)= normalize( startProb.*observationP(:,1) );

nFrames= size(alfas,2);
for frame=2:nFrames
    obs= observationP(:,frame);
    candidate= (transition'*alfas(:,frame-1)).*obs;

    %The forward chain has (numerically) died out: restart it
    if sum(candidate) <= 1e-10
        restart= normalize( startProb.*observationP(:,frame-1) );
        candidate= (transition'*restart).*obs;
    end

    alfas(:,frame)= normalize( candidate );
end


%--------------------------------------------------------------------------
function betas= computeBetas (hmm, observationP)
%COMPUTEBETAS Backward (beta) probabilities, normalized in every frame.
%   BETAS = COMPUTEBETAS(HMM, OBSERVATIONP) returns a matrix the same size
%   as OBSERVATIONP (states x frames). The last column is all ones. Every
%   earlier column t is
%       betas(:,t) ~ A * (B(:,t+1) .* betas(:,t+1))
%   normalized to sum to 1, with A = hmm.transp and B = OBSERVATIONP.
%
%   If the unnormalized mass of a frame drops to 1e-10 or below, the next
%   frame's betas are replaced by hmm.priors (transposed to a column).

startProb= hmm.priors';
transition= hmm.transp;
normalize= @(v) v/sum(v);

betas= zeros( size(observationP) );
betas(:,end)= 1;

frame= size(betas,2) - 1;
while frame >= 1
    nextObs= observationP(:,frame+1);
    candidate= transition*(nextObs.*betas(:,frame+1));

    %The backward chain has (numerically) died out: restart from the priors
    if sum(candidate) <= 1e-10
        candidate= transition*(nextObs.*startProb);
    end

    betas(:,frame)= normalize( candidate );
    frame= frame - 1;
end


%--------------------------------------------------------------------------
function gammas= computeGammas( alfas, betas, priors, observationP )
%COMPUTEGAMMAS Per-frame state posteriors from forward/backward variables.
%   GAMMAS = COMPUTEGAMMAS(ALFAS, BETAS, PRIORS, OBSERVATIONP) returns
%   alfas.*betas with every column (frame) normalized to sum to 1.
%
%   Frames whose alfas.*betas mass is 1e-10 or below fall back to
%   OBSERVATIONP(:,t) weighted by PRIORS. PRIORS is broadcast with bsxfun
%   against the states-by-frames slice, so it is expected to be a column
%   with one entry per state (assumed; confirm at call sites).
%
%   Sums are taken explicitly along dimension 1. Otherwise a single-state
%   model, where the inputs are 1-by-T rows, would be normalized over all
%   frames instead of per frame.

gammas= alfas.*betas;
degenerate= sum(gammas, 1) <= 1e-10; %frames with (numerically) zero mass
if any(degenerate)
    gammas(:,degenerate)= bsxfun(@times, observationP(:,degenerate), priors);
end
gammas= bsxfun(@rdivide, gammas, sum(gammas, 1));


%--------------------------------------------------------------------------
function B= computeObservationProb( data, mask, hmm, probComputation )
%COMPUTEOBSERVATIONPROB Per-frame observation probabilities for every state.
%   B = COMPUTEOBSERVATIONPROB(DATA, MASK, HMM, PROBCOMPUTATION) returns a
%   hmm.ncentres-by-T matrix, where T = size(DATA,2). B(k,t) is the value
%   returned by the selected probability routine for frame t and state k,
%   divided by hmm.priors(k). Each column is then normalized to sum to 1.
%
%   PROBCOMPUTATION - 'diag' | 'faubel' | 'bcp' | 'eig' | 'svd'
%   Raises an error for an unknown PROBCOMPUTATION.

switch probComputation
    case 'diag'
        probability= @boundedProbabilityDiag;
    case 'faubel'
        probability= @boundedProbabilityFaubel;
    case 'bcp'
        probability= @boundedProbability2;
    case 'eig'
        probability= @boundedProbabilityEig;
    case 'svd'
        probability= @boundedProbabilitySvd;
    otherwise
        % Fail fast instead of leaving 'probability' undefined.
        error('forwardBackwardImputation:unknownProbComputation', ...
            'Unknown probability computation ''%s''.', probComputation);
end

T= size(data,2);
K= hmm.ncentres;
B= zeros( K, T );
for t=1:T
    y= data(:,t);
    m= mask(:,t);

    for k=1:K
        B(k,t)= probability( y, m, hmm, k );
    end
end

%The probability routine multiplies every element in B(k,:) by the a priori
%probability hmm.priors(k), so we divide B(k,:) by the prior to remove this
%effect. (hmm.priors is presumably a row vector, so its transpose is K-by-1.)
B= bsxfun(@times, B, 1./hmm.priors');

%Normalize the observation probabilities of every frame (column). Summing
%explicitly along dimension 1 keeps this per-frame when K == 1.
B= bsxfun(@times, B, 1./sum(B, 1));
