function xest= vtsBinaryImputation( data, mask, nest, nvar, gmm_x, imputationAlgorithm, vtsOrder )
%VTSBINARYIMPUTATION Impute unreliable components using a VTS-adapted GMM.
%   XEST = VTSBINARYIMPUTATION(DATA, MASK, NEST, NVAR, GMM_X, ALGORITHM, VTSORDER)
%   processes DATA column by column (one column per frame). In frame t the
%   entries with MASK(:,t) false are treated as unreliable:
%     1. the GMM is split into reliable/unreliable parts (splitGMM);
%     2. the unreliable part is VTS-adapted with NEST(:,t) and NVAR(:,t);
%     3. each Gaussian produces a partial estimate with the selected rule;
%     4. the partial estimates are averaged with the Gaussian posteriors.
%   Reliable entries, and frames without unreliable entries, are copied
%   from DATA unchanged.
%
%   Inputs
%     data      - observations, one column per frame.
%     mask      - reliability mask indexed like DATA; true/nonzero marks a
%                 reliable entry (numeric 0/1 masks are accepted).
%     nest      - noise mean estimates, indexed like DATA.
%     nvar      - noise variance estimates, indexed like DATA.
%     gmm_x     - GMM struct; this function reads ncentres and priors, the
%                 helpers read centres, covars and covar_type (presumably a
%                 Netlab-style gmm -- confirm).
%     imputationAlgorithm - 'simple', 'simple2', 'mapr', 'mapr2', 'bcmir',
%                 'mapx' or 'bcmix'; any other value raises an error.
%     vtsOrder  - passed to adaptGMMFast: 0 adapts only the means, > 0 also
%                 adapts the covariances.
%
%   Output
%     xest      - same size as DATA, with the unreliable entries imputed.

switch imputationAlgorithm
    case 'simple'
        impute= @vtsSimpleMapImputationFast;
    case 'simple2'
        impute= @vtsSimpleMapImputationFast2;
    case 'mapr'
        impute= @vtsMapImputationRFast;
    case 'mapr2'
        impute= @vtsMapImputationRFast2;
    case 'bcmir'
        impute= @vtsBCMImputationRFast;
    case 'mapx'
        impute= @vtsMapImputationXFast;
    case 'bcmix'
        impute= @vtsBCMImputationXFast;
    otherwise
        % Fail up front instead of with an undefined 'impute' inside the loop
        % (or silently, when no frame happens to need imputation).
        error('vtsBinaryImputation:unknownAlgorithm', ...
            'Unknown imputation algorithm ''%s''.', char(imputationAlgorithm));
end

xest= data;

[params, frames]= size(data);
K= gmm_x.ncentres;
priors= gmm_x.priors(:);
p= zeros(K, 1);

for t= 1:frames
    % Coerce to logical so numeric 0/1 masks work with logical indexing
    reliable= logical(mask(:,t));
    unreliable= ~reliable;
    u= sum(unreliable);

    if u == 0
        continue;   % every component is reliable: keep the observation
    end

    y= data(:,t);
    xr= y(reliable);
    yu= y(unreliable);
    nu= nest(unreliable,t);
    sn_u= nvar(unreliable,t);

    % Split the clean GMM into reliable and unreliable parts
    [gmm_xr, gmm_xu_r]= splitGMM( gmm_x, y, reliable );

    % Precompute the VTS terms once per Gaussian; they are shared by the
    % adaptation below and by the imputation rules.
    G0= zeros(u, K);
    F0= zeros(u, K);
    for k= 1:K
        mx= gmm_xu_r.centres(k,:)';
        G0(:,k)= g( mx, nu );
        F0(:,k)= f( mx, nu );
    end
    gmm_yu= adaptGMMFast( gmm_xu_r, nu, sn_u, vtsOrder, G0, F0 );

    % Posterior of every Gaussian: reliable entries scored under gmm_xr,
    % unreliable entries under the adapted model gmm_yu.
    for k= 1:K
        pr= 1;
        if u < params
            pr= probability(gmm_xr, k, xr);
        end
        p(k)= pr * probability(gmm_yu, k, yu) * priors(k);
    end
    total= sum(p);
    if total > 0 && isfinite(total)
        p= p/total;
    else
        % All likelihoods underflowed (or the sum is not finite): fall back
        % to the priors instead of writing NaN into the estimate.
        p= priors/sum(priors);
    end

    % Per-Gaussian estimates of the unreliable components, combined with
    % the posterior weights.
    Mest= zeros(u, K);
    for k= 1:K
        Mest(:,k)= impute( gmm_xu_r, gmm_yu, k, nu, sn_u, yu, G0(:,k), F0(:,k) );
    end
    xest(unreliable,t)= Mest * p;
end


%-------------------------------------------------------------------------
function gmm_y= adaptGMM( gmm_x, nest, nvar, order )
%ADAPTGMM VTS adaptation of a GMM, computing the expansion terms itself.
%   For every component c with clean mean mu:
%     centre     <- mu + g(mu, nest)
%     covariance <- J*Sx*J + diag((1-s).^2 .* nvar),  s = f(mu, nest),
%                   J = diag(s)      (only when ORDER > 0)
%   For 'diag' models the covariance update is applied elementwise. For
%   other models the result is symmetrised from its upper triangle. With
%   ORDER not > 0 the covariances are copied unchanged.

symmetrize= @(X) triu(X) + triu(X, 1)';
gmm_y= gmm_x;

for c= 1:gmm_x.ncentres
    mu= gmm_x.centres(c,:)';
    gmm_y.centres(c,:)= (mu + g( mu, nest ))';

    if order > 0
        slope= f( mu, nest );
        if strcmp(gmm_x.covar_type,'diag')
            varX= gmm_x.covars(c,:)';
            gmm_y.covars(c,:)= (slope.^2 .* varX + (1-slope).^2 .* nvar)';
        else
            SigX= gmm_x.covars(:,:,c);
            J= diag(slope);
            gmm_y.covars(:,:,c)= symmetrize( J*SigX*J + diag((1-slope).^2 .* nvar) );
        end
    end
end


%-------------------------------------------------------------------------
function gmm_y= adaptGMMFast( gmm_x, nest, nvar, order, G0, F0 )
%ADAPTGMMFAST VTS adaptation of a GMM from precomputed expansion terms.
%   Column c of G0 / F0 holds the VTS terms for component c
%   (vtsBinaryImputation fills them with g(mx_c, nest) and f(mx_c, nest)).
%   Means become mx_c + G0(:,c). When ORDER > 0 each covariance becomes
%   diag(F0(:,c))*Sx*diag(F0(:,c)) + diag((1-F0(:,c)).^2 .* nvar):
%   elementwise for 'diag' models, symmetrised from the upper triangle
%   otherwise. NEST itself is not read here.

symmetrize= @(X) triu(X) + triu(X, 1)';
gmm_y= gmm_x;

for c= 1:gmm_x.ncentres
    gmm_y.centres(c,:)= (gmm_x.centres(c,:)' + G0(:,c))';

    if order > 0
        fc= F0(:,c);
        noiseTerm= (1-fc).^2 .* nvar;
        if strcmp(gmm_x.covar_type,'diag')
            gmm_y.covars(c,:)= (fc.^2 .* gmm_x.covars(c,:)' + noiseTerm)';
        else
            Fc= diag(fc);
            gmm_y.covars(:,:,c)= symmetrize( Fc*gmm_x.covars(:,:,c)*Fc + diag(noiseTerm) );
        end
    end
end


%-------------------------------------------------------------------------
function y= g ( x, n )
%G VTS mismatch term y = log(1 + exp(n - x)), evaluated stably.
%   With d = n - x this is computed as max(d, 0) + log1p(exp(-|d|)), which
%   is algebraically identical but never overflows for large d (the naive
%   form returns Inf once exp(d) overflows), and keeps full relative
%   precision when exp(d) is tiny (the naive form rounds 1 + exp(d) to 1).
%   Works elementwise on arrays of matching size.
d= n - x;
y= max(d, 0) + log1p(exp(-abs(d)));


%--------------------------------------------------------------------------
function y= f( x, n )
%F Logistic term 1 ./ (1 + exp(n - x)), elementwise.
%   This equals the derivative of x + log(1 + exp(n - x)) with respect to x,
%   i.e. the first-order VTS slope of the noisy observation.
y= 1 ./ (exp(n - x) + 1);


%--------------------------------------------------------------------------
function p= probability( gmm, k, x )
%PROBABILITY Gaussian density of column vector x under component k of gmm.
%   'diag' models: product of the per-dimension normal densities.
%   Any other covar_type: MVNPDF with the full matrix covars(:,:,k).

mu= gmm.centres(k,:)';

if ~strcmp(gmm.covar_type,'diag')
    p= mvnpdf( x, mu, gmm.covars(:,:,k) );
    return;
end

invStd= 1./sqrt(gmm.covars(k,:)');
p= prod( normpdf(invStd.*(x-mu)).*invStd );


%--------------------------------------------------------------------------
% function Y= forceSymmetric( X )
% Y= triu(X) + triu(X, 1)';


%--------------------------------------------------------------------------
function xest= vtsSimpleMapImputation( gmm_xu, gmm_yu, k, nest, nvar, y )
%VTSSIMPLEMAPIMPUTATION Subtract the VTS mismatch at the clean mean, clip at y.
%   Computes y - g(mx, nest) for the clean mean mx of component k and
%   returns its elementwise minimum with y.
mismatch= g(gmm_xu.centres(k,:)', nest);
xest= min(y, y - mismatch);


%--------------------------------------------------------------------------
function xest= vtsSimpleMapImputationFast( gmm_xu, gmm_yu, k, nest, nvar, y, g0, f0 )
%VTSSIMPLEMAPIMPUTATIONFAST y minus the precomputed mismatch g0, clipped at y.
%   Only y and g0 are read; the other arguments exist so that every
%   imputation rule shares the same signature.
xest= min(y, y - g0);


%--------------------------------------------------------------------------
function xest= vtsSimpleMapImputationFast2( gmm_xu, gmm_yu, k, nest, nvar, y, g0, f0 )
%VTSSIMPLEMAPIMPUTATIONFAST2 y minus the precomputed mismatch g0 (no clipping).
%   Only y and g0 are read; the remaining arguments keep the common
%   imputation-rule signature.
xest= minus(y, g0);


%--------------------------------------------------------------------------
function varargout= vtsMapImputationR( gmm_xu, gmm_yu, k, nest, nvar, y )
%VTSMAPIMPUTATIONR MAP imputation through the mismatch r = y - x.
%   Variant that computes the VTS terms itself, g0 = g(mx, nest) and
%   f0 = f(mx, nest) for the clean mean mx of component k. It then
%   delegates to vtsMapImputationRFast, so the two variants cannot drift
%   apart (they were previously verbatim copies).
%
%   Outputs
%     xest   - min(y - E[r|y], y).
%     m_r_y  - E[r|y]            (only when nargout > 1).
%     S_r_y  - Cov[r|y] matrix   (only when nargout > 1).
mx= gmm_xu.centres(k,:)';
nout= max(nargout, 1);
[varargout{1:nout}]= vtsMapImputationRFast( gmm_xu, gmm_yu, k, nest, nvar, y, ...
    g( mx, nest ), f( mx, nest ) );


%--------------------------------------------------------------------------
function varargout= vtsMapImputationRFast( gmm_xu, gmm_yu, k, nest, nvar, y, g0, f0 )
%VTSMAPIMPUTATIONRFAST MAP imputation through the mismatch r = y - x.
%   Under the first-order VTS model the mismatch r has mean g0, slope
%   dr/dx = f0 - 1 and slope dr/dn = 1 - f0. r and the observation y
%   (mean/covariance from gmm_yu) are treated as jointly Gaussian. r is
%   conditioned on the observed y, and the clean estimate is y - E[r|y],
%   clipped at y.
%
%   Inputs
%     gmm_xu, gmm_yu - clean and VTS-adapted GMMs over the unreliable dims.
%     k      - component index.
%     nest   - not read here; kept so all imputation rules share a signature.
%     nvar   - noise variances of the unreliable dims (column).
%     y      - observed unreliable components (column).
%     g0, f0 - precomputed g(mx, nest) and f(mx, nest) for component k.
%
%   Outputs
%     xest   - min(y - E[r|y], y).
%     m_r_y  - E[r|y]                (only when nargout > 1).
%     S_r_y  - Cov[r|y] as a matrix  (only when nargout > 1; diagonal for
%              'diag' models).

my= gmm_yu.centres(k,:)';

k0= f0-1;   % dr/dx
h0= 1-f0;   % dr/dn

if strcmp(gmm_xu.covar_type,'diag')
    sx= gmm_xu.covars(k,:)';
    sy= gmm_yu.covars(k,:)';
    sr= k0.^2 .* sx +  h0.^2 .* nvar;    % Var[r]
    sry= k0.*sx.*f0 + h0.^2.*nvar;       % Cov[r, y]
    z= (y-my)./sy;
    m_r_y= g0 + sry.*z;
    S_r_y= diag( sr - sry.^2./sy );
else
    Sx= gmm_xu.covars(:,:,k);
    Sy= gmm_yu.covars(:,:,k);
    Snh0= diag(h0.^2 .* nvar);
    Sr= diag(k0)*Sx*diag(k0) + Snh0;     % Cov[r]
    Sry= diag(k0)*Sx*diag(f0) + Snh0;    % Cov[r, y]
    A= Sry/Sy;                           % regression gain without inv(Sy)
    m_r_y= g0 + A*(y-my);
    S_r_y= Sr - A*Sry';
end

varargout{1}= min(y-m_r_y, y);
if nargout > 1
    varargout{2}= m_r_y;
    varargout{3}= S_r_y;
end


%--------------------------------------------------------------------------
function varargout= vtsMapImputationRFast2( gmm_xu, gmm_yu, k, nest, nvar, y, g0, f0 )
%VTSMAPIMPUTATIONRFAST2 MAP imputation through r = y - x, without clipping.
%   Identical to vtsMapImputationRFast except that the estimate is
%   y - E[r|y] instead of min(y - E[r|y], y). The conditional statistics
%   are taken from vtsMapImputationRFast, which removes a verbatim copy of
%   the same algorithm.
%
%   Outputs
%     xest   - y - E[r|y].
%     m_r_y  - E[r|y]            (only when nargout > 1).
%     S_r_y  - Cov[r|y] matrix   (only when nargout > 1).
[~, m_r_y, S_r_y]= vtsMapImputationRFast( gmm_xu, gmm_yu, k, nest, nvar, y, g0, f0 );

varargout{1}= y - m_r_y;
if nargout > 1
    varargout{2}= m_r_y;
    varargout{3}= S_r_y;
end


%--------------------------------------------------------------------------
function xest= vtsBCMImputationR ( gmm_xu, gmm_yu, k, nest, nvar, y )
%VTSBCMIMPUTATIONR Imputation with a truncated-normal mismatch mean.
%   Takes E[r|y] and diag(Cov[r|y]) from vtsMapImputationR. It then adds
%   the mean shift of a normal truncated to r >= 0:
%       shift = s .* normpdf(m./s) ./ normcdf(m./s).
%   Non-finite shifts (0/0, x/0) are replaced by 0.
%   Returns min(y - (m + shift), y).
[~, mr, Sr]= vtsMapImputationR( gmm_xu, gmm_yu, k, nest, nvar, y );

sd= sqrt(abs(diag(Sr)));   % abs presumably guards against negative variance estimates
ratio= mr./sd;
shift= sd.*normpdf(ratio)./normcdf(ratio);
shift(~isfinite(shift))= 0;

xest= min( y-(mr + shift), y );


%--------------------------------------------------------------------------
function xest= vtsBCMImputationRFast ( gmm_xu, gmm_yu, k, nest, nvar, y, g0, f0 )
%VTSBCMIMPUTATIONRFAST Truncated-normal mismatch imputation (precomputed g0/f0).
%   Takes E[r|y] and diag(Cov[r|y]) from vtsMapImputationRFast. It then
%   adds the mean shift of a normal truncated to r >= 0:
%       shift = s .* normpdf(m./s) ./ normcdf(m./s).
%   Non-finite shifts are replaced by 0.
%   Returns min(y - (m + shift), y).
[~, mr, Sr]= vtsMapImputationRFast( gmm_xu, gmm_yu, k, nest, nvar, y, g0, f0 );

sd= sqrt(abs(diag(Sr)));   % abs presumably guards against negative variance estimates
ratio= mr./sd;
shift= sd.*normpdf(ratio)./normcdf(ratio);
shift(~isfinite(shift))= 0;

xest= min( y-(mr + shift), y );


%--------------------------------------------------------------------------
function varargout= vtsMapImputationX ( gmm_xu, gmm_yu, k, nest, nvar, y )
%VTSMAPIMPUTATIONX MAP imputation of the clean vector x given y.
%   Variant that computes the VTS slope f0 = f(mx, nest) itself for the
%   clean mean mx of component k. It then delegates to
%   vtsMapImputationXFast, so the two variants cannot drift apart (they
%   were previously verbatim copies).
%
%   Outputs
%     xest   - min(E[x|y], y).
%     m_x_y  - E[x|y]            (only when nargout > 1).
%     S_x_y  - Cov[x|y] matrix   (only when nargout > 1).
mx= gmm_xu.centres(k,:)';
nout= max(nargout, 1);
% vtsMapImputationXFast never reads its g0 argument, so [] is passed.
[varargout{1:nout}]= vtsMapImputationXFast( gmm_xu, gmm_yu, k, nest, nvar, y, ...
    [], f( mx, nest ) );


%--------------------------------------------------------------------------
function varargout= vtsMapImputationXFast ( gmm_xu, gmm_yu, k, nest, nvar, y, g0, f0 )
%VTSMAPIMPUTATIONXFAST MAP imputation of the clean vector x given y.
%   x and y are treated as jointly Gaussian, with Cov[x, y] = Sx*diag(f0).
%   x is conditioned on the observed y, and the result is clipped at y.
%   Only gmm_xu, gmm_yu, k, y and f0 are read; the other arguments keep
%   the common imputation-rule signature.
%
%   Outputs
%     xest   - min(E[x|y], y).
%     m_x_y  - E[x|y]                (only when nargout > 1).
%     S_x_y  - Cov[x|y] as a matrix  (only when nargout > 1; diagonal for
%              'diag' models).

muX= gmm_xu.centres(k,:)';
muY= gmm_yu.centres(k,:)';

if ~strcmp(gmm_xu.covar_type,'diag')
    SigX= gmm_xu.covars(:,:,k);
    SigY= gmm_yu.covars(:,:,k);
    SigXY= SigX*diag(f0);
    gain= SigXY/SigY;          % regression gain without inv(SigY)
    condMean= muX + gain*(y-muY);
    condCov= SigX - gain*SigXY';
else
    varX= gmm_xu.covars(k,:)';
    varY= gmm_yu.covars(k,:)';
    covXY= varX .* f0;
    condMean= muX + covXY.*((y-muY)./varY);
    condCov= diag( varX - covXY.^2./varY );
end

outs= {min(condMean, y), condMean, condCov};
varargout= outs(1:max(nargout, 1));


%--------------------------------------------------------------------------
function xest= vtsBCMImputationX ( gmm_xu, gmm_yu, k, nest, nvar, y )
%VTSBCMIMPUTATIONX Imputation with a truncated-normal clean-speech mean.
%   Takes E[x|y] and diag(Cov[x|y]) from vtsMapImputationX. It then
%   subtracts the mean shift of a normal truncated to x <= y:
%       shift = s .* normpdf((y-m)./s) ./ normcdf((y-m)./s).
%   Non-finite shifts are replaced by 0.
%   Returns min(m - shift, y).
[~, mx, Sx]= vtsMapImputationX( gmm_xu, gmm_yu, k, nest, nvar, y );

sd= sqrt(abs(diag(Sx)));   % abs presumably guards against negative variance estimates
ratio= (y-mx)./sd;
shift= sd.*normpdf(ratio)./normcdf(ratio);
shift(~isfinite(shift))= 0;

xest= min( mx - shift, y );


%--------------------------------------------------------------------------
function xest= vtsBCMImputationXFast ( gmm_xu, gmm_yu, k, nest, nvar, y, g0, f0 )
%VTSBCMIMPUTATIONXFAST Truncated-normal clean-speech imputation (precomputed f0).
%   Takes E[x|y] and diag(Cov[x|y]) from vtsMapImputationXFast. It then
%   subtracts the mean shift of a normal truncated to x <= y:
%       shift = s .* normpdf((y-m)./s) ./ normcdf((y-m)./s).
%   Non-finite shifts are replaced by 0.
%   Returns min(m - shift, y).
[~, mx, Sx]= vtsMapImputationXFast( gmm_xu, gmm_yu, k, nest, nvar, y, g0, f0 );

sd= sqrt(abs(diag(Sx)));   % abs presumably guards against negative variance estimates
ratio= (y-mx)./sd;
shift= sd.*normpdf(ratio)./normcdf(ratio);
shift(~isfinite(shift))= 0;

xest= min( mx - shift, y );