% Start from a clean workspace. clearvars removes variables only; unlike
% "clear all" it keeps breakpoints and already-loaded functions in memory.
clearvars

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% set paths
% Matlab_Path: where the matlab scripts are
% Bin_Path: where the C++ binaries are
% Model_Path: where to store the models
%
% Model_Path must end with a file separator, because the model file name
% is appended directly to it further below.

%Matlab_Path = 'C:\Users\pehrob\Documents\RobustASR\V2\Matlab\libspn\';
%Bin_Path = 'C:\Users\pehrob\Documents\RobustASR\V2\CPP\bin\';
%Model_Path = 'C:\Users\pehrob\Documents\RobustASR\V2\Models\';

% Matlab_Path = '/afs/spsc.tugraz.at/user/peharz/RobustASR/V2/Matlab/libspn/';
% Bin_Path = '/afs/spsc.tugraz.at/user/peharz/RobustASR/V2/CPP/bin/';
% Model_Path = '/afs/spsc.tugraz.at/user/peharz/RobustASR/V2/Models/';

% Root directory of the installation; adapt this to your machine.
RobustASR='/home/jmorales/SpeechData/Softwares/SPNv2';

Matlab_Path = [RobustASR '/Matlab/libspn/'];
Bin_Path = [RobustASR '/CPP/bin/'];
Model_Path = [RobustASR '/Models/'];

% Make sure the model directory exists before any model file is written
% into it. A failure is only reported as a warning, so the script keeps
% its previous behavior in that case.
if ~exist(Model_Path, 'dir')
    [modelDirOk, modelDirMsg] = mkdir(Model_Path);
    if ~modelDirOk
        warning('SPN:modelDir', 'Could not create model directory %s: %s', ...
            Model_Path, modelDirMsg);
    end
end

% Make the libspn MATLAB functions callable.
addpath(Matlab_Path);


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% call your data TRAINX, TRAINL, VALX, VALL
%
% TRAINX:     every row is a sample, every column is a dimension. 
%             size(TRAINX,2) MUST be L1 * L2 (see below, parameters for PD
%             learning).
%             It is assumed that the data is organized as a 2D-array, 
%             where samples consist of L2-many blocks of length L1.
%             For example, when L1 = 40 and L2 = 11, then 
%             TRAINX(i,1:40) and TRAINX(i,41:80) are the first 
%             two blocks of the i'th sample. You can interpret the blocks
%             as the first two rows (or columns) of an image.
%
%             In other words, the first dimension "toggles faster" in 
%             the linear representation of 2D data. Don't be confused
%             by the fact that the PD algorithm refers to dim1 as "width"
%             and to dim2 as "height". "width" and "height" are just 
%             arbitrary names for the two dimensions.
%             
%             You can ensure correct indexing as follows:
%             If your i'th sample is a matrix R{i}, then set
%             TRAINX(i,:) = R{i}(:)';
%             [L1,L2] = size(R{i});
%             Of course size(R{i}) must be the same for all i.
%
%
% VALX:       organized the same way as TRAINX
%             used for early stopping
%
%
% TRAINL:     labels for train data
%             TRAINL MUST contain ALL numbers between 0 and 
%             max(TRAINL)! Otherwise the CPP-binaries will throw an 
%             exception.
%             Sorry, this is a bit inconvenient if you want to do 
%             quick debugging, but it should decrease the chance of 
%             inconsistencies 
%
%
% VALL:       labels for dev data
%             all VALL must be <= max(TrainLabelsPD)



%%%%
%%%%  DELETE the following block and insert something useful for 
%%%%  TRAINX, TRAINL, VALX and VALL
%%%%

%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% DELETE ME! START %%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%
% Placeholder data so that the script runs out of the box.
% Every statement up to the END banner is meant to be replaced by your own
% code that defines TRAINX, TRAINL, VALX and VALL.

% Brings the structs Tr and Dev into the workspace.
load('TrDev');                                       % <- DELETE THIS!

% Overwrite the stored labels with random integers in 0..4.
Tr.Post = floor(5 * rand(size(Tr.Post)));            % <- DELETE THIS!
Dev.Post = floor(5 * rand(size(Dev.Post)));          % <- DELETE THIS!

% Keep the first 5000 training and the first 2200 validation samples.
% Feat is transposed so that every row of TRAINX / VALX is one sample
% (presumably Feat stores one sample per column -- confirm for your data).
TRAINX = Tr.Feat(:, 1:5000)';                        % <- DELETE THIS!
TRAINL = Tr.Post(1:5000);                            % <- DELETE THIS!
VALX = Dev.Feat(:, 1:2200)';                         % <- DELETE THIS!
VALL = Dev.Post(1:2200);                             % <- DELETE THIS!

% The loaded structs are no longer needed.
clear Tr Dev                                         % <- DELETE THIS!

%%%%%%%%%%%%%%%%%%%%%%%%
%%%% DELETE ME! END %%%%
%%%%%%%%%%%%%%%%%%%%%%%%



%%%%%%%%%%%%%%%
% PD Learning
%
% numGauss: how many input distributions (Gaussians) per dimension?
%
% numSums: how many sum nodes per rectangle? 
%          (exception: the root (top) rectangle has 
%          length(unique(labels)) many sum nodes)
%
% batchSize: batch size for incremental EM
%
% sparsePrior: penalty factor for creating a new sum node
%
% sparsePriorAnneal: actual sparsePrior grows linearly within the first 
%                    "sparsePriorAnneal" iterations      
%
% coarseRes, coarseResX, coarseResY: structure determining parameters
%
% L1, L2: length of 1st and 2nd dimension
%
% numSamplesTrainPD: number of samples to use for PD learning
%
% numSamplesDevPD: number of samples for Development set
%
% numIterPD: max number of iterations
%
% earlyStoppingKrel, earlyStoppingKabs: early stopping, how many 
%           iterations are allowed with no relative/absolute 
%           increase of the dev likelihood.
%
% 
%            USE AS MANY SAMPLES AS POSSIBLE FOR PD LEARNING!
%            I.e. set numSamplesTrainPD as large as possible, so that 
%            training still finishes.
%

% Parameters for PD learning.
% NOTE(review): within this script numSums and sparsePrior are only used to
% build the model file name, and sparsePriorAnneal is not used at all --
% confirm whether they should also be handed to trainSPN_PD.
numGauss = 10;
numSums = 20;
batchSize = 10;
sparsePrior = 1;
sparsePriorAnneal = 10;
coarseRes = [4];
coarseResX = [];
coarseResY = [];
L1 = 40;
L2 = 11;
%numSamplesTrainPD = 2000;
%numSamplesDevPD = 2200;
numSamplesTrainPD = 200;
numSamplesDevPD = 220;

numIterPD = 20;
earlyStoppingKrel = 3;
earlyStoppingKabs = 5;

% The feature dimension has to match the declared 2D layout (L1 * L2).
% Fail early with a clear message instead of failing later inside training.
if size(TRAINX, 2) ~= L1 * L2
    error('SPN:dimMismatch', ...
        'size(TRAINX,2) is %d, but L1*L2 is %d.', size(TRAINX, 2), L1 * L2);
end
if size(VALX, 2) ~= L1 * L2
    error('SPN:dimMismatch', ...
        'size(VALX,2) is %d, but L1*L2 is %d.', size(VALX, 2), L1 * L2);
end

% Never request more samples than are available; otherwise the slicing
% below would fail with an out-of-bounds index.
numAvailTrain = min(size(TRAINX, 1), numel(TRAINL));
if numSamplesTrainPD > numAvailTrain
    warning('SPN:numSamples', ...
        'numSamplesTrainPD reduced from %d to %d (available samples).', ...
        numSamplesTrainPD, numAvailTrain);
    numSamplesTrainPD = numAvailTrain;
end
numAvailDev = min(size(VALX, 1), numel(VALL));
if numSamplesDevPD > numAvailDev
    warning('SPN:numSamples', ...
        'numSamplesDevPD reduced from %d to %d (available samples).', ...
        numSamplesDevPD, numAvailDev);
    numSamplesDevPD = numAvailDev;
end

% Prepare data for PD learning: the leading samples and their labels.
TrainDataPD = TRAINX(1:numSamplesTrainPD,:);
DevDataPD = VALX(1:numSamplesDevPD,:);

TrainLabelsPD = TRAINL(1:numSamplesTrainPD);
VALLabelsPD = VALL(1:numSamplesDevPD);



%%%%%%%%%%%%%
% EM learning
%
% numIterEM: number of iterations
% earlyStoppingK_EM: early stopping parameter
% updateWeights: update sum weights?
% updateMeans: update means?
% updateSigmas: update sigmas?
% minSigma: minimum value for Gaussian sigmas

% Iteration control for EM.
earlyStoppingK_EM = 3;
numIterEM = 30;      % can usually be left as is

% Flags selecting which parameter groups EM updates; all are switched on.
[updateWeights, updateMeans, updateSigmas] = deal(1);

% Lower bound for the Gaussian sigmas.
minSigma = 0.1;      % can usually be left as is




%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%
% Run PD algorithm
%
% Learns a model on the PD subset, stores it in modelFile and reports the
% summed log-likelihood of the true labels on the train and dev subsets.

% The model file name encodes numGauss, numSums, batchSize and sparsePrior.
% NOTE(review): numSums, sparsePrior and sparsePriorAnneal are not passed
% to trainSPN_PD below -- confirm whether that is intended.
modelFile = sprintf('%sSPN_G%d_S%d_B%d_sP%d.mod', Model_Path, numGauss, numSums, batchSize, sparsePrior);

% Labels for the training data (1) and the validation data (2).
LABELS{1} = TrainLabelsPD;
LABELS{2} = VALLabelsPD;

% stop_absLikelihoodChange = -1e20 presumably disables that stopping
% criterion, leaving early stopping to the Krel/Kabs settings -- confirm.
trainSPN_PD(TrainDataPD, L1, L2, ...
    'valData', DevDataPD, ...
    'outputLabels', LABELS,...
    'binPath', Bin_Path,...
    'numGauss', numGauss, ...
    'numIter', numIterPD, ...
    'modelFile', modelFile,...
    'stop_absLikelihoodChange', -1e20,...
    'earlyStoppingKrel', earlyStoppingKrel,...
    'earlyStoppingKabs', earlyStoppingKabs,...
    'coarseRes', coarseRes, ...
    'coarseResX', coarseResX, ...
    'coarseResY', coarseResY, ...
    'batchSize', batchSize);


LLtrain = inferLL(TrainDataPD, [], 'modelFile', modelFile, 'binPath', Bin_Path);
LLdev = inferLL(DevDataPD, [], 'modelFile', modelFile, 'binPath', Bin_Path);

% For every sample pick the entry in the column belonging to its own label
% (labels are 0-based, columns are 1-based). Using sub2ind with column
% vectors works for row- and column-oriented label vectors alike; the
% former "(1:N) + labels*N" expanded to an N-by-N index matrix (or raised
% an error on older MATLAB versions) when the labels were a column vector.
LLtrainIdx = sub2ind(size(LLtrain), (1:size(LLtrain,1))', double(TrainLabelsPD(:)) + 1);
LLdevIdx = sub2ind(size(LLdev), (1:size(LLdev,1))', double(VALLabelsPD(:)) + 1);

% %f instead of %d: the likelihood sums are generally not integers.
fprintf('LL train after PD algorithm: %f   LL dev: %f\n', sum(LLtrain(LLtrainIdx)), sum(LLdev(LLdevIdx)));
fprintf('\n');

%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Run EM algorithm
%
% Starts from the PD model in modelFile, trains on the full TRAINX / VALX
% data, stores the result in EMmodelFile and reports the summed
% log-likelihood of the true labels on both sets.

% Name of the EM model: the PD model name without its 4-character
% extension, plus the three update flags.
EMmodelFile = sprintf('%s_EM_%d_%d_%d.mod', modelFile(1:end-4), updateWeights, updateMeans, updateSigmas);

% Labels for the training data (1) and the validation data (2).
LABELS{1} = TRAINL;
LABELS{2} = VALL;


% stop_relLikelihoodChange = -1e6 presumably disables that stopping
% criterion, leaving early stopping to earlyStoppingK -- confirm.
[~, history] = trainSPN_EM(modelFile, TRAINX, VALX, EMmodelFile, ...
    'numIter', numIterEM, ...
    'outputLabels', LABELS,...
    'updateWeights', updateWeights, ...
    'updateMeans', updateMeans, ...
    'updateSigmas', updateSigmas, ...
    'minSigma', minSigma, ...
    'earlyStoppingK', earlyStoppingK_EM, ...
    'stop_relLikelihoodChange', -1e6,...
    'binPath', Bin_Path);

LLtrain = inferLL(TRAINX, [], 'modelFile', EMmodelFile, 'binPath', Bin_Path);
LLdev = inferLL(VALX, [], 'modelFile', EMmodelFile, 'binPath', Bin_Path);

% For every sample pick the entry in the column belonging to its own label
% (labels are 0-based, columns are 1-based). Using sub2ind with column
% vectors works for row- and column-oriented label vectors alike; the
% former "(1:N) + labels*N" expanded to an N-by-N index matrix (or raised
% an error on older MATLAB versions) when the labels were a column vector.
LLtrainIdx = sub2ind(size(LLtrain), (1:size(LLtrain,1))', double(TRAINL(:)) + 1);
LLdevIdx = sub2ind(size(LLdev), (1:size(LLdev,1))', double(VALL(:)) + 1);

% %f instead of %d: the likelihood sums are generally not integers.
fprintf('LL train after EM algorithm: %f   LL dev: %f\n', sum(LLtrain(LLtrainIdx)), sum(LLdev(LLdevIdx)));



