% @=============================================================================
% Copyright (c)2025 University of Granada - SPAIN
% This software is distributed under the terms of the GNU General Public License
% as published by the Free Software Foundation. Further details on the GPLv3
% license can be found at http://www.gnu.org/copyleft/gpl.html
% 
% FOR RESEARCH PURPOSES ONLY. THE SOFTWARE IS PROVIDED "AS IS," AND THE
% UNIVERSITY OF GRANADA DO NOT MAKE ANY WARRANTY, EXPRESS OR IMPLIED, INCLUDING 
% BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 
% PURPOSE, NOR DO THEY ASSUME ANY LIABILITY OR RESPONSIBILITY FOR THE USE 
% OF THIS SOFTWARE.
%
% For more information see /license/license.html
% =============================================================================@
%
% Author: Juan Ruiz de Miras, 2025

clear;
close all;

% ---- network input and patch geometry --------------------------------------
inputSize = [224 224 3];   % input resolution fed to MobileNetV2 (patches are resized to it)
imageSize = 640;           % image side length passed to heatMapProbVeronese
patchSize = [64, 64];      % sliding-window patch size
overlap   = 0.65;          % sliding-window overlap, passed to extractPatches
holdout   = 0.05;          % fraction of patches held out as validation data (cvpartition HoldOut)

% ---- class labels and training sources -------------------------------------
% NOTE(review): 'class' shadows MATLAB's built-in class() function inside this
% script's workspace; the name is kept because the training loop uses it.
class = {'veronese', 'non-veronese'};
nonVeronesePaintings = {'05', '06'};

% ---- per-iteration painting splits -----------------------------------------
% Iteration 1 trains on all four Veronese paintings and tests on '07'-'11'
% (painting '11' is the painting under authentication). Iterations 2-5 are a
% leave-one-out check: Veronese painting k is dropped from the training set
% and used as the only test painting.
allVeronese = {'01', '02', '03', '04'};
numVeronese = numel(allVeronese);

leaveOneOut = cell(1, numVeronese);
for k = 1:numVeronese
    leaveOneOut{k} = allVeronese([1:k-1, k+1:numVeronese]);
end
veronesePaintingsData = [{allVeronese}, leaveOneOut];

heldOut = cellfun(@(p) {p}, allVeronese, 'UniformOutput', false);
testPaintingsData = [{{'07', '08', '09', '10', '11'}}, heldOut];

% the helper variables are not used by the rest of the script
clear allVeronese numVeronese leaveOneOut heldOut k

% Pretrained MobileNetV2 with its classification head replaced by a 2-class
% head. None of this depends on the test iteration, so it is built once here.
% Layer graphs are value objects (every edit is reassigned to lgraph), so
% training below never modifies lgraph. The learnable weights of the new fully
% connected layer are only initialized when dlnetwork(lgraph) is called inside
% the loop, after the random generators have been reseeded.
net = mobilenetv2;
lgraph = layerGraph(net);

% last layer must have only 2 classes (0 and 1)
newFCLayer = fullyConnectedLayer(2, 'Name', 'new_fc', ...
                                'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10);
newSoftmaxLayer = softmaxLayer('Name', 'new_softmax');
lgraph = replaceLayer(lgraph, 'Logits', newFCLayer);
lgraph = replaceLayer(lgraph, 'ClassificationLayer_Logits', newSoftmaxLayer);

% One iteration per entry of veronesePaintingsData: fine-tune the network on
% those Veronese paintings plus the non-Veronese ones, then score every patch
% of each painting in the matching testPaintingsData entry.
for t = 1:size(veronesePaintingsData,2)
    % CPU and GPU seed for random calculations and reproducibility
    % (reseeded every iteration so each iteration is reproducible on its own)
    seed = 123456;
    rng(seed);
    gpurng(seed);
    deep.gpu.deterministicAlgorithms(true);
    veronesePaintings = veronesePaintingsData{t};
    testPaintings = testPaintingsData{t};

    % sliding-window processing
    [patchesColorVeronese, patchesGrayscaleVeronese, patchesEdgesVeronese] = extractPatches(patchSize, overlap, class{1}, veronesePaintings);
    [patchesColorNonVeronese, patchesGrayscaleNonVeronese, patchesEdgesNonVeronese] = extractPatches(patchSize, overlap, class{2}, nonVeronesePaintings);
    [X, Y] = prepareDataCNN(patchesColorVeronese, patchesGrayscaleVeronese, patchesEdgesVeronese, ...
                            patchesColorNonVeronese, patchesGrayscaleNonVeronese, patchesEdgesNonVeronese, ...
                            patchSize);

    % holdout partition over the patches (4th dimension of X)
    cv = cvpartition(size(X,4), 'HoldOut', holdout);

    % MobileNetV2 takes 3 channels: keep only channels 1:3 (RGB)
    XTrain = X(:, :, 1:3, training(cv));
    YTrain = Y(training(cv));
    XVal = X(:, :, 1:3, test(cv));
    YVal = Y(test(cv));

    % labels as column vectors for augmentedImageDatastore
    if isrow(YTrain), YTrain = YTrain'; end
    if isrow(YVal), YVal = YVal'; end
    % 64 x 64 -> 224 x 224 (resized on the fly by the datastore)
    dsTrain = augmentedImageDatastore(inputSize, XTrain, YTrain, 'ColorPreprocessing', 'gray2rgb');
    dsVal = augmentedImageDatastore(inputSize, XVal, YVal, 'ColorPreprocessing', 'gray2rgb');

    % dlnetwork is needed for trainnet; this also initializes the new FC layer
    % weights from the freshly reseeded RNG
    dlnet = dlnetwork(lgraph);

    % training parameters
    miniBatchSize = 16;
    numTrainImages = size(XTrain,4);
    % trainnet discards the final incomplete mini-batch of each epoch, so one
    % epoch is floor(N/miniBatchSize) iterations. Using ceil made validation
    % drift one iteration later every epoch. max(1, ...) keeps the frequency
    % a valid positive integer for very small training sets.
    iterationsPerEpoch = max(1, floor(numTrainImages / miniBatchSize));

    options = trainingOptions('adam', ...
                              'MaxEpochs', 6, ...
                              'InitialLearnRate', 1e-4, ...
                              'MiniBatchSize', miniBatchSize, ...
                              'Shuffle', 'every-epoch', ...
                              'ValidationData', dsVal, ...
                              'ValidationFrequency', iterationsPerEpoch, ...
                              'ValidationPatience', 3, ...
                              'Metrics','accuracy', ...
                              'Verbose', true, ...
                              'Plots', 'none');

    % training
    disp('Training MobileNetV2...');
    netTransfer = trainnet(dsTrain, dlnet, "crossentropy", options);

    disp('Predicting probabilities for each patch in test paintings...');
    for i=1:length(testPaintings)
        [patchesColorTest, patchesGrayscaleTest, patchesEdgesTest] = extractPatches(patchSize, overlap, 'test', testPaintings(i));
        [XTest, ~] = prepareDataCNN(patchesColorTest, patchesGrayscaleTest, patchesEdgesTest, ...
                                    [], [], [], ...
                                    patchSize);
        % only 3 channels (RGB) and 64 x 64 -> 224 x 224
        XTest_RGB = XTest(:,:,1:3,:);
        dsTest = augmentedImageDatastore(inputSize, XTest_RGB, 'ColorPreprocessing', 'gray2rgb');

        % computing probabilities for each patch (one row per patch, one
        % column per class)
        score = minibatchpredict(netTransfer, dsTest);
        [~, idxMax] = max(score, [], 2);
        predictedClass = idxMax - 1; % 1 -> 0, 2 -> 1

        % statistics
        % NOTE(review): assumes score column 2 (label 1) is the Veronese class,
        % i.e. that prepareDataCNN orders the class categories that way; confirm.
        probs_veronese = score(:, 2);
        mu = mean(probs_veronese);
        sigma = std(probs_veronese);
        N = length(probs_veronese);
        % NOTE(review): the sliding windows overlap (overlap = 0.65), so the
        % patches are presumably not independent samples. The SEM-based
        % interval below is then likely narrower than the true uncertainty.
        SEM = sigma / sqrt(N); % standard error of the mean
        CI_lower = mu - (1.96 * SEM); % confidence interval at 95% (Z-score = 1.96)
        CI_upper = mu + (1.96 * SEM);

        % show result
        numPatchesVeronese = sum(predictedClass == 1);
        numPatchesNonVeronese = sum(predictedClass == 0);

        disp(['Results for painting: ' testPaintings{i}]);
        disp(['Number of patches classified as Veronese = ' num2str(numPatchesVeronese)]);
        disp(['Number of patches classified as non-Veronese = ' num2str(numPatchesNonVeronese)]);
        disp(['Average Veronese probability = ' num2str(mu * 100, '%3.1f') '%']);
        fprintf('Standard Deviation (SD): %.2f\n', sigma);
        fprintf('Confidence Interval (95%%): [%.2f%% - %.2f%%]\n', CI_lower * 100, CI_upper * 100);
        disp(' ');

        % show Veronese probability heatmap
        heatMapProbVeronese(imageSize, patchSize(1,1), overlap, score, testPaintings{i});
    end
end
