function rbm = rbmtrain(rbm, x, opts)
%RBMTRAIN Train a restricted Boltzmann machine with one-step contrastive
%divergence (CD-1) on mini-batches.
%
%   rbm = rbmtrain(rbm, x, opts)
%
%   Inputs
%     rbm  - struct with weights W (numHidden x numVisible), visible bias b
%            and hidden bias c (column vectors), velocity terms vW, vb, vc of
%            the same sizes, learning rate alpha, and the momentum values
%            initMomentum (early epochs) and momentum (later epochs).
%     x    - training data, one sample per row; float or logical.
%     opts - struct. Required fields: numepochs, batchsize (size(x,1) must be
%            a multiple of batchsize). Optional fields:
%              contInput            (0)     1 = continuous input (see below)
%              doPlotError          (false) draw a progress figure
%              plotFilename         ('')    save that figure as PDF at the
%                                           end (requires doPlotError)
%              resetMomentumAtEpoch (7)     first epoch that uses rbm.momentum
%                                           instead of rbm.initMomentum
%              weightcost           (0)     L2 weight decay, scaled by alpha
%              sparsityTarget       (none)  enables the sparsity penalty,
%                                           tuned by sparsityLambda (0.95) and
%                                           sparsityCost (0.001)
%              valData              (none)  validation rows; the average free
%                                           energy of training and validation
%                                           data is tracked per epoch
%
%   Output
%     rbm  - updated parameters; rbm.epochErr holds the average squared
%            reconstruction error of every epoch.
%
%Modified by Michi W. to support continuous input (in case 'opts.contInput'==1)
%In this case, the input is assumed to be zero mean and unit variance.
%According to Hinton, the learning rate has to be one or two orders of
%magnitude smaller than for binary RBM training.

% input is continuous (standard normal distributed)?
[contInput,opts] = checkParams(opts,'contInput',0);

% plot reconstruction error over epochs?
[doPlotError,opts] = checkParams(opts,'doPlotError',false);

% save plot to file (only meaningful when the plot is drawn)
[plotFilename,opts] = checkParams(opts,'plotFilename','');

% reset momentum from rbm.initMomentum to rbm.momentum this epoch
[resetMomentumAtEpoch,opts] = checkParams(opts,'resetMomentumAtEpoch',7);

% L2 weight decay; defaults to no decay instead of requiring the field
[weightcost,opts] = checkParams(opts,'weightcost',0);

useSparsity = isfield(opts,'sparsityTarget');
useValData  = isfield(opts,'valData');

if(useSparsity)
    [sparsityLambda,opts] = checkParams(opts,'sparsityLambda',0.95);
    [sparsityCost,opts] = checkParams(opts,'sparsityCost',0.001);
end

if(~isempty(plotFilename) && ~doPlotError)
    warning('rbmtrain:plotFilenameIgnored', ...
        'opts.plotFilename is ignored because opts.doPlotError is false.');
end

if(doPlotError)
    scrsz = get(0,'ScreenSize');
    handle = figure('OuterPosition',[scrsz(1) scrsz(2) scrsz(3)/2 scrsz(4)]);
end

assert(isfloat(x)|islogical(x), 'x must be a float or logical');
m = size(x, 1);
assert(m > 0, 'x must contain at least one sample');
bs = opts.batchsize;
numbatches = m / bs;

assert(rem(numbatches, 1) == 0, 'numbatches not integer');

epochErr = zeros(opts.numepochs,1);

freeEnergyTrain = [];
freeEnergyVal = [];
if(useValData)
    freeEnergyTrain = zeros(opts.numepochs,1);
    freeEnergyVal = zeros(opts.numepochs,1);
end

% running estimate of the mean hidden activation for the sparsity penalty;
% kept across epochs and seeded from the first mini-batch so that it is not
% biased towards zero at the start of every epoch
q = [];

% rbm.alpha is not modified during training, so the decay factor is constant
weightFactor = rbm.alpha*weightcost;

for i = 1 : opts.numepochs
    kk = randperm(m);
    err = 0;
    
    if(i<resetMomentumAtEpoch)
        momentum = rbm.initMomentum;
    else
        momentum = rbm.momentum;
    end
    if(i==resetMomentumAtEpoch)
        fprintf('Resetting momentum from %f to %f\n',rbm.initMomentum,rbm.momentum);
    end
    
    tic;
    weightupdateslog = zeros(numel(rbm.W),1);
    
    for l = 1 : numbatches
        batch = x(kk((l - 1) * bs + 1 : l * bs), :);
        
        if(islogical(batch))
            batch = double(batch);
        end
        
        % positive phase: hidden states driven by the data
        % (sigmrnd is a project helper, presumably sigmoid + Bernoulli sampling)
        v1 = batch;
        h1 = sigmrnd(rbm.c(:,ones(1,bs))' + v1 * rbm.W');
        
        % negative phase: one Gibbs step back to the visible layer
        if(contInput)
            % mean value reconstruction of the unit-variance Gaussian input
            v2 = rbm.b(:,ones(1,bs))' + h1 * rbm.W;
        else
            v2 = sigmrnd(rbm.b(:,ones(1,bs))' + h1 * rbm.W);
        end
        % hidden probabilities instead of samples for the negative statistics
        % (recommended by Hinton; the original toolbox sampled here)
        h2 = sigm(rbm.c(:,ones(1,bs))' + v2 * rbm.W');
        
        c1 = h1' * v1;
        c2 = h2' * v2;
        
        % sum explicitly over the batch dimension: with batchsize 1 a plain
        % sum() would collapse the unit dimension to a scalar instead
        rbm.vW = momentum * rbm.vW + rbm.alpha * (c1 - c2)        / bs;
        rbm.vb = momentum * rbm.vb + rbm.alpha * sum(v1 - v2, 1)' / bs;
        rbm.vc = momentum * rbm.vc + rbm.alpha * sum(h1 - h2, 1)' / bs;
        
        if(weightFactor>0)
            % use weight decay
            rbm.vW = rbm.vW - weightFactor*rbm.W;
        end
        if(useSparsity)
            % use sparsity penalty: push each unit's estimated mean
            % activation q towards opts.sparsityTarget
            batchMean = mean(h1,1).';
            if(isempty(q))
                q = batchMean;
            else
                q = sparsityLambda*q + (1-sparsityLambda)*batchMean;
            end
            dSP = sparsityCost*(opts.sparsityTarget-q);
            rbm.vW = rbm.vW + dSP(:,ones(1,size(rbm.vW,2)));
            rbm.vc = rbm.vc + dSP;
        end
        
        rbm.W = rbm.W + rbm.vW;
        rbm.b = rbm.b + rbm.vb;
        rbm.c = rbm.c + rbm.vc;
        
        err = err + sum(sum((v1 - v2) .^ 2)) / bs;
        
        weightupdateslog = weightupdateslog + rbm.vW(:);
    end
    elTime = toc;
    disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)  '. Average reconstruction error is: ' num2str(err / numbatches) ' Elapsed time(s): ' num2str(elTime)]);
    
    weightupdateslog = weightupdateslog/numbatches;
    epochErr(i) = err/numbatches;
    
    if(useValData)
        % compute average free energy on training and validation set.
        % If gap between curves grows, this is an indicator for overfitting.
        freeEnergyTrain(i) = mean(local_freeEnergy(rbm, x, contInput));
        freeEnergyVal(i) = mean(local_freeEnergy(rbm, opts.valData, contInput));
    end
    
    if(doPlotError)
        drawnow;
        % redraw every 10 epochs and always after the last one, so that the
        % figure (and the optional PDF) reflects the final state
        if(mod(i,10)==0 || i==opts.numepochs)
            local_plotProgress(handle, i, rbm, x, epochErr, weightupdateslog, ...
                freeEnergyTrain, freeEnergyVal, q, opts);
        end
    end
end

rbm.epochErr = epochErr;

% the figure only exists when plotting was enabled
if(doPlotError && ~isempty(plotFilename))
    print(handle,'-dpdf',plotFilename);
end

end


function F = local_freeEnergy(rbm, v, contInput)
% Per-sample free energy (column vector) of the rows of v.
%   binary visible units:          F(v) = -v*b - sum(softplus(c + W*v))
%   unit-variance Gaussian visible: F(v) = 0.5*||v - b||^2 - sum(softplus(c + W*v))
v = double(v);
n = size(v,1);
act = rbm.c(:,ones(1,n))' + v * rbm.W';
% numerically stable log(1+exp(act)); the naive form overflows to Inf
logTerm = max(act,0) + log1p(exp(-abs(act)));
hiddenTerm = sum(logTerm,2);
if(contInput)
    d = v - rbm.b(:,ones(1,n))';
    F = 0.5*sum(d.^2,2) - hiddenTerm;
else
    F = -v*rbm.b - hiddenTerm;
end
end


function local_plotProgress(handle, epoch, rbm, x, epochErr, weightupdateslog, ...
    freeEnergyTrain, freeEnergyVal, q, opts)
% Redraw the progress figure for epochs 1..epoch: reconstruction error,
% weight and weight-update histograms, hidden activations of up to 500
% training rows, and optionally the free-energy curves (opts.valData) and the
% estimated mean hidden activations (opts.sparsityTarget).
useValData  = isfield(opts,'valData');
useSparsity = isfield(opts,'sparsityTarget');

nRow = 2;
if(useValData || useSparsity)
    nRow = 3;
end
figure(handle);
subplot(nRow,2,1);
plot(epochErr(1:epoch));
xlabel('epoch');
ylabel('reconstruction error');
title(['\alpha = ', num2str(rbm.alpha)]);
subplot(nRow,2,2);
hist(rbm.W(:),100);
title('Weight Histogram');
% recreate the third axes from scratch (presumably to drop the previous
% image and its colorbar; a plain cla was tried before)
hs = subplot(nRow,2,3);
delete(hs);
subplot(nRow,2,3);
hold off;
vSample = double(x(1:min(500,size(x,1)), :));
hSample = sigmrnd(repmat(rbm.c', size(vSample,1), 1) + vSample * rbm.W');
imagesc(hSample);
title('hidden activations');
xlabel('hidden units');
ylabel('input sample');
colormap gray;
colorbar;
subplot(nRow,2,4);
hist(weightupdateslog,100);
title('Weight Update Histogram');
if(useValData)
    subplot(nRow,2,5);
    hold off;
    plot(freeEnergyTrain(1:epoch));
    hold on;
    plot(freeEnergyVal(1:epoch),'g');
    legend('train','validation');
    title('free energy');
end
if(useSparsity && ~isempty(q))
    subplot(nRow,2,6);
    hold off;
    plot(q);
    hold on;
    plot(opts.sparsityTarget*ones(size(q)),'g');
    legend('estimated mean activation','sparsity target');
    xlabel('hidden unit index');
end
end
