% @=============================================================================
% Copyright (c)2025 University of Granada - SPAIN
% This software is distributed under the terms of the GNU General Public License
% as published by the Free Software Foundation. Further details on the GPLv3
% license can be found at http://www.gnu.org/copyleft/gpl.html
% 
% FOR RESEARCH PURPOSES ONLY. THE SOFTWARE IS PROVIDED "AS IS," AND THE
% UNIVERSITY OF GRANADA DO NOT MAKE ANY WARRANTY, EXPRESS OR IMPLIED, INCLUDING 
% BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 
% PURPOSE, NOR DO THEY ASSUME ANY LIABILITY OR RESPONSIBILITY FOR THE USE 
% OF THIS SOFTWARE.
%
% For more information see /license/license.html
% =============================================================================@
%
% Author: Juan Ruiz de Miras, 2025

clear;

% CPU and GPU seeds for random calculations and reproducibility
seed = 123456;
rng(seed);
gpurng(seed);
deep.gpu.deterministicAlgorithms(true);

% Sliding-window parameters passed to extractPatches: patch size in pixels
% and window overlap (0.65, presumably a fraction of the patch size -- see
% extractPatches)
patchSize = [64, 64];
overlap   = 0.65;

% MobileNetV2 input size (height x width x channels)
inputSize = [224 224 3];

% Class names passed to extractPatches. Named 'classNames' (not 'class') so
% the script does not shadow MATLAB's built-in class() function in the
% workspace.
classNames = {'veronese', 'non-veronese'};
veronesePaintings    = {'01','02','03','04'};
nonVeronesePaintings = {'05','06'};

% Sliding-window processing: three patch sets (color, grayscale, edge) per class
[veroneseColorPatches, veroneseGrayscalePatches, veroneseEdgePatches] = ...
    extractPatches(patchSize, overlap, classNames{1}, veronesePaintings);
[nonVeroneseColorPatches, nonVeroneseGrayscalePatches, nonVeroneseEdgePatches] = ...
    extractPatches(patchSize, overlap, classNames{2}, nonVeronesePaintings);

% Assemble the CNN input array X (patches along dimension 4, as indexed in the
% cross-validation loop) and the label vector Y from all patch sets
[X, Y] = prepareDataCNN(veroneseColorPatches, veroneseGrayscalePatches, veroneseEdgePatches, ...
                        nonVeroneseColorPatches, nonVeroneseGrayscalePatches, nonVeroneseEdgePatches, ...
                        patchSize);

% Out-of-fold results, one entry per patch, filled in fold by fold
y_pred = zeros(size(Y));     % predicted label, 0/1
scores = zeros(length(Y),1); % predicted probability of class '1'

totalPatches = length(y_pred);
numPaintings = numel(veronesePaintings) + numel(nonVeronesePaintings);
patchesPerPainting = totalPatches / numPaintings;

% Each fold is a contiguous block of patchesPerPainting patches, so the total
% must split evenly across paintings; fail with a clear message otherwise
% (a non-integer block size would only surface later as an indexing error).
% NOTE(review): folds are only painting-pure if prepareDataCNN stores patches
% grouped by painting, in the order veronesePaintings then
% nonVeronesePaintings, with the same count per painting -- confirm.
assert(mod(totalPatches, numPaintings) == 0, ...
       'Cannot split %d patches evenly into %d per-painting folds.', ...
       totalPatches, numPaintings);

% Pretrained MobileNetV2, loaded once. Each fold builds a new layer graph and
% dlnetwork from it, so every fold starts from the same pretrained weights.
net = mobilenetv2;

% Leave-one-painting-out cross-validation: train on the patches of all
% paintings except one and predict the patches of the held-out painting
for i = 1:numPaintings
    fprintf('\n=== Fold %d/%d ===\n', i, numPaintings);

    % Patches of painting i form the test fold
    testIdx = false(1, totalPatches);
    iniPatch = (i - 1) * patchesPerPainting + 1;
    endPatch = i * patchesPerPainting;
    testIdx(iniPatch:endPatch) = true;
    trainIdx = ~testIdx; % train patches are the patches that are not test patches

    XTrain_Fold = X(:,:,:,trainIdx);
    YTrain_Fold = Y(trainIdx);
    XVal_Fold   = X(:,:,:,testIdx);
    YVal_Fold   = Y(testIdx);

    % MobileNetV2 input has 3 channels: keep channels 1:3 of X
    % (presumably RGB -- confirm the channel layout in prepareDataCNN)
    XTrain_Fold = XTrain_Fold(:,:,1:3,:);
    XVal_Fold = XVal_Fold(:,:,1:3,:);

    % Labels as column vectors; the datastores resize the 64 x 64 patches to
    % inputSize (224 x 224)
    if isrow(YTrain_Fold), YTrain_Fold = YTrain_Fold'; end
    if isrow(YVal_Fold), YVal_Fold = YVal_Fold'; end
    dsTrain = augmentedImageDatastore(inputSize, XTrain_Fold, YTrain_Fold, 'ColorPreprocessing', 'gray2rgb');
    dsVal   = augmentedImageDatastore(inputSize, XVal_Fold, YVal_Fold, 'ColorPreprocessing', 'gray2rgb');

    lgraph = layerGraph(net);

    % Last learnable layer must have only 2 outputs (classes 0 and 1)
    newFCLayer = fullyConnectedLayer(2, 'Name', 'new_fc', ...
                                     'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10);
    lgraph = replaceLayer(lgraph, 'Logits', newFCLayer);

    % MobileNetV2 already has a softmax layer between 'Logits' and the
    % classification output layer. dlnetwork/trainnet take no output layer,
    % so the output layer is simply removed; putting another softmaxLayer in
    % its place would apply softmax twice and compress the probabilities.
    lgraph = removeLayers(lgraph, 'ClassificationLayer_Logits');

    % dlnetwork is needed for trainnet
    dlnet = dlnetwork(lgraph);

    % Training parameters. The held-out painting is used as ValidationData for
    % monitoring only: 'OutputNetwork','last-iteration' prevents trainnet from
    % returning the best-validation network (the "auto" default when
    % ValidationData is given), which would select the model on the test fold.
    options = trainingOptions('adam', ...
        'MiniBatchSize', 16, ...
        'MaxEpochs', 6, ...
        'InitialLearnRate', 1e-4, ...
        'Shuffle', 'every-epoch', ...
        'ValidationData', dsVal, ...
        'OutputNetwork', 'last-iteration', ...
        'Verbose', true, ...
        'Plots', 'none');

    % training
    netTransfer = trainnet(dsTrain, dlnet, "crossentropy", options);

    % Predictions for the held-out painting: column 2 of the network output is
    % stored as the class '1' score; the argmax column index minus 1 gives the
    % 0/1 label
    score_fold = minibatchpredict(netTransfer, dsVal);
    scores(testIdx) = score_fold(:, 2);
    [~, idxMax] = max(score_fold, [], 2);
    y_pred_fold = idxMax - 1;
    y_pred(testIdx) = y_pred_fold;
end

% Confusion matrix of the out-of-fold predictions (rows = true class,
% columns = predicted class). 'Order', [0 1] pins index 1 to class 0 and
% index 2 to class 1, which the TP/TN/FP/FN indexing below relies on.
confMat = confusionmat(str2double(string(Y)), y_pred, 'Order', [0 1]);
TP = confMat(2,2);
TN = confMat(1,1);
FP = confMat(1,2);
FN = confMat(2,1);

% Performance metrics with class 1 as the positive class. A ratio whose
% denominator is zero evaluates to NaN and is printed as NaN.
fprintf('Performance metrics for %d-fold cross-validation\n', numPaintings);
accuracy = (TP + TN) / sum(confMat(:));
precision = TP / (TP + FP);
sensitivity = TP / (TP + FN);
specificity = TN / (TN + FP);
f1_score = 2 * (precision * sensitivity) / (precision + sensitivity);

% Geometric mean of sensitivity and specificity
Gmean = sqrt(sensitivity * specificity);

% Cohen's kappa: observed agreement po versus chance agreement pe, where pe
% is computed from the predicted (column) and true (row) class totals
po = accuracy;
pe = ((TP + FP)*(TP + FN) + (FN + TN)*(FP + TN)) / sum(confMat(:))^2;
kappa = (po - pe) / (1 - pe);

fprintf('Accuracy     : %.2f%%\n', accuracy * 100);
fprintf('Precision    : %.2f%%\n', precision * 100);
fprintf('F1-Score     : %.2f%%\n', f1_score * 100);
fprintf('Sensitivity  : %.2f%%\n', sensitivity * 100);
fprintf('Specificity  : %.2f%%\n', specificity * 100);
fprintf('G-mean       : %.2f%%\n', Gmean * 100);
fprintf('Kappa        : %.2f \n', kappa);

% ROC curve of the out-of-fold class '1' scores against the true labels,
% plotted next to the chance-level diagonal; the AUC is shown in the title
% and printed to the console.
[Xroc, Yroc, T, AUC] = perfcurve(Y, scores, '1');
rocTitle = ['ROC Curve (AUC = ' num2str(AUC, '%.2f') ')'];

figure;
plot(Xroc, Yroc, 'b-', 'LineWidth', 2);
hold('on');
plot([0 1], [0 1], 'r--');   % random-classifier diagonal
xlabel('False Positive Rate (1 - Specificity)');
ylabel('True Positive Rate (Sensitivity)');
title(rocTitle);
legend('CNN Model', 'Random Classifier', 'Location', 'southeast');
axis('square');
grid('on');
hold('off');

fprintf('AUC-ROC      : %.3f\n', AUC);
