% Experiments for the paper "Joint time-vertex stationary signal processing".
% The code tests the accuracy of JPSD estimation methods (mode 'accuracy') shown 
% in Fig. 1 of the paper, as well as their computational complexity in machine 
% time (mode 'complexity') shown in Fig. 2.
% 
% The first part of the code takes a long time since it repeats all experiments. 
% To just see the results, set the variable replay to 1.
%  
% The home directory should be at the same folder as this file (with
% subdirectories results/utils/data).
% 
% Requirements:
%   The latest version of the gspbox, ltfat, and UNLocBox
%   To save the figures you will need the export_fig matlab package.
% 
% Andreas Loukas, Nathanael Perraudin
% 29 September 2016

% Start from a clean session: empty workspace, no open figures, blank
% console. gsp_reset_seed(1) is then called (presumably to seed the random
% number generator so that runs are repeatable -- confirm in the GSPBox).
clear;
close all;
clc;
gsp_reset_seed(1)

% ---- user switches --------------------------------------------------------
% replay    : 1 loads stored results, 0 recomputes everything from scratch
%             (takes a long time)
% overwrite : 1 saves (overwrites) the results / figures
replay    = 1;
overwrite = 0;

% ---- experiment selection -------------------------------------------------
datasets      = {'separable'};
distributions = {'gaussian'};   % choices: 'gaussian'
mode          = 'accuracy';     % choices: complexity / accuracy

% Experiment grid for the selected mode.
%   runs    : number of independent repetitions of every experiment
%   Nall    : graph sizes (number of vertices) that are tested
%   Tall    : numbers of time samples that are tested
%   Lall    : candidate values for param.L of the estimator
%   Fall    : candidate values for param.Nfilt of the estimator
%   methods : names of the estimators that are compared
switch mode
    case 'accuracy'
        runs = 20;
        Nall = 256;
        Tall = 128;
        Lall = 2.^(log2(2):1:log2(max(Tall)));                  % powers of two from 2 up to max(Tall)
        Fall = round(logspace(log10(2), log10(max(Nall)), 8));  % 8 log-spaced values from 2 up to max(Nall)
        methods = {'sample' 'TVA'};     % TVA corresponds to the convolutional JPSD estimator
    case 'complexity'
        runs = 10;
        Nall = 1000:2000:11000;
        Tall = 128;
        Lall = [32 64];
        Fall = [25 50];
        methods = {'TVA_eig' 'TVA'};
    otherwise
        % Fail early with a clear message instead of continuing with an
        % undefined parameter grid. An MException is thrown directly because
        % this script later creates a variable called 'error', which shadows
        % the builtin error() function.
        throw(MException('JPSD:unknownMode', ...
            'Unknown mode ''%s'' (expected ''accuracy'' or ''complexity'').', mode));
end

% fs is handed to gsp_jtv_graph when the joint time-vertex graph is built;
% R is the size of the 4th dimension of the generated noise and the divisor
% used when the signal energy is normalized.
fs = 1;
R  = 1;

% Result containers, indexed as
%   (N, T, method, run, dataset, distribution, L, F);
% PSD_real has no method / L / F dimensions.
gridSize = [numel(Nall), numel(Tall), numel(methods), runs, ...
            numel(datasets), numel(distributions), numel(Lall), numel(Fall)];

PSD_real = cell(gridSize([1 2 4 5 6]));  % true PSD of every generated process
PSD      = cell(gridSize);               % estimated PSDs
error    = nan(gridSize);                % relative estimation errors
time     = nan(gridSize);                % execution times (seconds)

clear gridSize
        
% Either recompute every experiment (replay == 0) or load stored results
% (replay == 1). Results are written into PSD_real, PSD, error and time.
if ~replay

    for distribution = 1:numel(distributions)
    for dataset      = 1:numel(datasets)
    for run          = 1:runs

        for NIdx = 1:numel(Nall)
        for TIdx = 1:numel(Tall)
            N = Nall(NIdx); T = Tall(TIdx);

            % -------------------------------------------------------------
            % PREPARE DATA
            % -------------------------------------------------------------
            G = gsp_sensor(N);
            G = gsp_estimate_lmax(G);
            G = gsp_jtv_graph(G, T, fs);
            Geig = gsp_compute_fourier_basis(G);

            % desired joint frequency response
            switch datasets{dataset}
                case 'separable'
                    g        = @(lambda,omega) exp(-1*(lambda/Geig.lmax)).* exp(-5*(omega.^2));
                    ft       = 'js';
                    psd_real = abs(gsp_jtv_filter_evaluate(g, ft, Geig.e, Geig.jtv.omega)).^2;
                otherwise
                    % 'error' is a variable in this script, so the builtin
                    % error() cannot be called; throw the exception directly.
                    throw(MException('JPSD:unknownDataset', ...
                        'Unknown dataset ''%s''.', datasets{dataset}));
            end

            PSD_real{NIdx,TIdx,run,dataset,distribution} = psd_real;

            % generate joint process
            switch distributions{distribution}
                case 'gaussian'
                    E = randn(N,T,1,R);
                otherwise
                    throw(MException('JPSD:unknownDistribution', ...
                        'Unknown distribution ''%s''.', distributions{distribution}));
            end

            X = gsp_jtv_filter_synthesis(Geig, g, ft, E);

            % Signal energy used to normalize every estimate. It is the same
            % for all (L, F, method) combinations, so it is computed once
            % here. sum(abs(X(:)).^2) equals trace(X*X') for a matrix X but
            % does not form the N-by-N product.
            signal_energy = sum(abs(X(:)).^2) / R;

            % figure; gsp_plot_jft(G, gsp_jft(G, X).^2);
            % (sum(psd_real(:)) - (trace(X*X')/R))

            % -------------------------------------------------------------
            % PSD ESTIMATION
            % -------------------------------------------------------------
            fprintf('distr:%s, data:%s, run:%d, N:%d, T:%d\n', distributions{distribution}, datasets{dataset}, run, N, T);

            for LIdx = 1:numel(Lall)
            for FIdx = 1:numel(Fall)

                param = struct;
                param.Nfilt = min(Fall(FIdx), Nall(NIdx));
                param.L     = min(Lall(LIdx), Tall(TIdx));
                param.Nrandom = 100;
                msg = sprintf('\t(L:%3d, F:%3d)', Lall(LIdx), Fall(FIdx));

                for mIdx = 1:numel(methods)

                    % methods whose name contains '_eig' receive the graph
                    % with the precomputed Fourier basis
                    uses_eig = ~isempty(strfind(methods{mIdx}, '_eig'));
                    if uses_eig
                        iG = Geig;
                    else
                        iG = G;
                    end

                    tstart = tic;
                    switch methods{mIdx}

                        case 'sample'
                            % The sample estimate does not depend on L or F:
                            % reuse the one stored for the first grid point
                            % when it is available (PSD is only filled in
                            % 'accuracy' mode).
                            cached = PSD{NIdx,TIdx,mIdx,run,dataset,distribution,1,1};
                            if (FIdx > 1 || LIdx > 1) && ~isempty(cached)
                                psd = cached;
                            else
                                % squared magnitude of the joint Fourier
                                % transform; abs() is applied before the
                                % average so that phases cannot cancel
                                psd = mean(abs(gsp_jft(Geig, X)).^2, 3);
                            end

                        case {'TVA' 'TVA_eig'}

                            param.estimator = 'TVA';
                            % keep the estimator's filter type separate from
                            % 'ft', which describes the synthesis filter g
                            [psd, psd_ft] = gsp_estimate_vertex_time_psd(iG, X, param);

                        otherwise
                            warning('unknown method.. attempting to continue.');
                            % nothing was estimated: skip this method and
                            % leave its error / time entries at NaN
                            continue;
                    end
                    fprintf(['Method: ',methods{mIdx},' -- time: %f \n'],toc(tstart))
                    if ~isnumeric(psd)
                        % the estimator returned a filter rather than
                        % values: evaluate it on the joint spectrum
                        psd = gsp_jtv_filter_evaluate(psd, psd_ft, Geig.e, Geig.jtv.omega);
                    end

                    % normalize energy
                    psd = abs(psd);
                    psd = psd * ( signal_energy / sum(psd(:)) );

                    % take into account the eigenvalue computation in the execution time
                    if strcmp(mode, 'complexity') && uses_eig
                        % two outputs are requested so that eigenvectors are
                        % computed too; the results themselves are discarded
                        tstart2 = tic; [~, ~] = eig(full(G.L)); fprintf('Time for diagonalization: %f \n', toc(tstart2))
                    end

                    exec_time = toc(tstart);
                    if strcmp(mode, 'accuracy')
                        PSD{NIdx,TIdx,mIdx,run,dataset,distribution, LIdx, FIdx} = psd;
                    end
                    % estimation error
                    error(NIdx, TIdx, mIdx, run, dataset, distribution, LIdx, FIdx) = ...
                        norm(psd - psd_real, 'fro') ./ norm(psd_real, 'fro');

                    % execution time
                    time(NIdx, TIdx, mIdx, run, dataset, distribution, LIdx, FIdx) = exec_time;

                    % update the console message
                    msg = sprintf('%s %s %1.3f |', msg, methods{mIdx}, error(NIdx, TIdx, mIdx, run, dataset, distribution, LIdx, FIdx));
                end
                fprintf('%s\n', msg);
            end
            end
            clear G iG Geig
        end
        end
        % checkpoint the timings / errors after every run
        save('timing.mat','time','error');

    end
    end
    end

    if overwrite
        % the whole workspace is stored, so make sure the folder exists
        if ~exist('results', 'dir'), mkdir('results'); end
        save(sprintf('results/JPSD_%s.mat', mode));
    end
else
    % The results file holds a complete workspace, including the values that
    % 'replay' and 'overwrite' had when it was written. Restore the user's
    % current choices after loading so that a file saved with overwrite = 1
    % cannot silently re-enable writing of results / figures.
    userReplay = replay; userOverwrite = overwrite;
    load(sprintf('results/JPSD_%s.mat', mode));
    replay = userReplay; overwrite = userOverwrite;
    clear userReplay userOverwrite
end

%% Visualization code 
% Plots the stored results: PSD inspection and error / bias / std / time
% maps in 'accuracy' mode, execution time versus graph size in 'complexity'
% mode.
switch mode

    case 'accuracy'
        %% Visualization 1: show PSD / print errors
        % this is only for visual inspection and did not appear in the paper
        clc; close all;

        % Grid point that is inspected. The L / F indices are clamped so that
        % they stay valid when the parameter grids have fewer than 5 entries.
        run = 1; NIdx = 1; TIdx = 1;
        LIdx = min(5, numel(Lall));
        FIdx = min(5, numel(Fall));

        % dummy graph (only handed to gsp_plot_jft below)
        N = Nall(NIdx); T = Tall(TIdx);
        G = gsp_sensor(N);
        G = gsp_compute_fourier_basis(G);
        G = gsp_jtv_graph(G, T, fs);

        for distribution = 1:numel(distributions)
            for dataset = 1:numel(datasets)

                psd_real = PSD_real{NIdx,TIdx,run,dataset, distribution};
                % estimates of all methods, concatenated by cell2mat
                psd      = cell2mat(PSD(NIdx, TIdx, :, run, dataset, distribution, LIdx, FIdx));
                % common colour limits for the true and the estimated PSDs
                clim = [min([psd(:); psd_real(:)]) max([psd(:); psd_real(:)])];
                name = sprintf('%9s, %9s, L:%d, F:%d', distributions{distribution}, datasets{dataset}, Lall(LIdx), Fall(FIdx));

                % nothing to show when the estimates sum to zero
                if ~(sum(psd(:))), continue; end

                figure; set(gcf, 'Color', [1 1 1], 'Name', name)
                msg = sprintf('%9s, %9s, L:%3d, F:%3d | ', distributions{distribution}, datasets{dataset}, Lall(LIdx), Fall(FIdx));

                param_plot.logscale = 1;
                param_plot.dB = inf;

                % the real PSD
                subplot(2,3,1);
                gsp_plot_jft(G, psd_real, param_plot);
                title('real'); caxis(clim);

                for mIdx = 1:size(PSD,3)
                    % estimated PSD
                    subplot(2,3,mIdx+1);
                    gsp_plot_jft(G, psd(:,:,mIdx), param_plot);
                    title(methods{mIdx}); caxis(clim);

                    % error of this method averaged over all runs
                    msg = sprintf('%s %s %1.3f |', msg, methods{mIdx}, mean(error(NIdx, TIdx, mIdx, :, dataset, distribution, LIdx, FIdx)));
                end

                fprintf('%s\n', msg);
            end
        end

        %% Visualization 2: plot variance/bias trade-off for all parameters
        % Fig. 1 in the paper.

        % NOTE: overwrite is reset to 0 here, so the export_fig calls in this
        % section are skipped regardless of the value chosen at the top of
        % the script.
        NIdx = 1; TIdx = 1; distribution = 1; dataset = 1; overwrite = 0;
        psd_real = vec(PSD_real{NIdx, TIdx, 1, dataset, distribution});

        % TVA
        bias_TVA = nan(numel(Lall), numel(Fall));
        var_TVA  = nan(numel(Lall), numel(Fall));
        time_TVA = nan(numel(Lall), numel(Fall));
        e_TVA    = nan(numel(Lall), numel(Fall));
        for LIdx = 1:numel(Lall)
            for FIdx = 1:numel(Fall)

                % one vectorized TVA estimate per column (one column per run)
                psd = zeros(Nall(NIdx)*Tall(TIdx), runs);
                for run = 1:runs
                    psd(:,run) = vec(squeeze(PSD{NIdx, TIdx, strcmp(methods,'TVA'), run, dataset, distribution, LIdx, FIdx}));
                end
                psd_mean = mean(psd, 2);
                % relative distance between the true PSD and the mean estimate
                bias_TVA(LIdx,FIdx) = norm(psd_real - psd_mean)/norm(psd_real);
                % mean squared deviation from the mean estimate, relative to
                % the energy of the mean estimate
                tmp = psd - repmat(psd_mean, 1, runs);
                var_TVA(LIdx,FIdx) = trace(tmp'*tmp)/runs/norm(psd_mean,2)^2; %/(Nall(NIdx)*Tall(TIdx));

                time_TVA(LIdx,FIdx) = mean(time( NIdx, TIdx, strcmp(methods,'TVA'), :, dataset, distribution, LIdx, FIdx));
                e_TVA(LIdx,FIdx)    = mean(error(NIdx, TIdx, strcmp(methods,'TVA'), :, dataset, distribution, LIdx, FIdx));
            end
        end
        std_TVA = (var_TVA.^(0.5));

        % plot settings shared by the four maps
        [iX, iY] = meshgrid(Lall, Fall);
        xtick =  Lall(1:2:end); ytick =  Fall(1:2:end); fsize = 12;
        width = 330; height = 230;
        view_angle = [0 90];                        % look at the surfaces from above
        cmap = brewermap(100,'*RdYlGn')*0.95;

        % --- map 1: mean error ---
        figure; set(gcf, 'Color', [1 1 1], 'Position', [100 200 width height]); hold on; set(gca, 'FontSize', fsize);
        set(gca, 'XScale', 'log', 'YScale', 'log', 'ZScale', 'linear', 'XTick', xtick, 'YTick', ytick, 'ZTick', [0:0.2:0.6]);
        surf(iX, iY, e_TVA');
        xlabel('L'); ylabel('F'); %, 'Position', [42 0.2 -0.9]);
        zlabel('error'); view(view_angle);
        xlim([Lall(1)/1.0 Lall(end)*1.0]); ylim([Fall(1)/1.0 Fall(end)*1.0]); grid on;
        zlim([0 0.6]);
        colorbar;
        colormap(cmap);
        caxis([min(e_TVA(:)) 0.7*max(e_TVA(:))]);
        if overwrite
            export_fig(sprintf('results/JPSD_accuracy_%d_%s_variance_bias_1.pdf', Nall(NIdx), datasets{dataset}));
        end

        % --- map 2: bias ---
        figure; set(gcf, 'Color', [1 1 1], 'Position', [100 200 width height]); hold on; set(gca, 'FontSize', fsize);
        set(gca, 'XScale', 'log', 'YScale', 'log', 'ZScale', 'linear', 'XTick', xtick, 'YTick', ytick, 'ZTick', [0:0.2:0.6]);
        surf(iX, iY, bias_TVA');
        xlabel('L'); ylabel('F'); %, 'Position', [42 0.2 -0.9]);
        zlabel('bias'); view(view_angle);
        xlim([Lall(1)/1.0 Lall(end)*1.0]); ylim([Fall(1)/1.0 Fall(end)*1.0]); grid on;
        zlim([0 0.6]);
        colorbar;
        colormap(cmap);
        % NOTE(review): the colour limits come from e_TVA, not bias_TVA. This
        % gives the bias map the same colour scale as the error map; confirm
        % that this is intended and not a copy-paste leftover.
        caxis([min(e_TVA(:)) 0.7*max(e_TVA(:))]);
        if overwrite
            export_fig(sprintf('results/JPSD_accuracy_%d_%s_variance_bias_2.pdf', Nall(NIdx), datasets{dataset}));
        end

        % --- map 3: standard deviation ---
        figure; set(gcf, 'Color', [1 1 1], 'Position', [100 200 width height]); hold on; set(gca, 'FontSize', fsize);
        set(gca, 'XScale', 'log', 'YScale', 'log', 'ZScale', 'linear', 'XTick', xtick, 'YTick', ytick, 'ZTick', [0.04:0.02:0.12]);
        surf(iX, iY, std_TVA');
        xlabel('L'); ylabel('F'); %, 'Position', [42 0.2 -16]);
        zlabel('std. dev.'); view(view_angle);
        xlim([Lall(1)/1.0 Lall(end)*1.0]); ylim([Fall(1)/1.0 Fall(end)*1.0]); grid on; zlim([0.04 0.12]);
        colorbar;
        colormap(cmap);
        caxis([min(std_TVA(:)) 0.9*max(std_TVA(:))]);
        if overwrite
            export_fig(sprintf('results/JPSD_accuracy_%d_%s_variance_bias_3.pdf', Nall(NIdx), datasets{dataset}));
        end

        % --- map 4: execution time ---
        figure; set(gcf, 'Color', [1 1 1], 'Position', [100 200 width height]); hold on; set(gca, 'FontSize', fsize);
        set(gca, 'XScale', 'log', 'YScale', 'log', 'ZScale', 'linear', 'XTick', xtick, 'YTick', ytick); %, 'ZTick', [10.^(-2:1:4)]);
        surf(iX, iY, time_TVA');
        xlabel('L'); ylabel('F'); %, 'Position', [42 0.2 -31]);
        zlabel('execution (sec)'); view(view_angle);
        xlim([Lall(1)/1.0 Lall(end)*1.0]); ylim([Fall(1)/1.0 Fall(end)*1.0]); grid on;
        colormap(cmap);
        colorbar;
        caxis([min(time_TVA(:)) 0.5*max(time_TVA(:))]);
        zlim([0 20]);

        if overwrite
            export_fig(sprintf('results/JPSD_accuracy_%d_%s_variance_bias_4.pdf', Nall(NIdx), datasets{dataset}));
        end

    case 'complexity'

        %% Visualization of scalability
        % Fig. 2 in the paper

        TIdx = 1; distribution = 1; dataset = 1; LIdx = 2;

        colors      = brewermap(4,'Set1')*0.9;
        markers     = {':o' '-s'};              % one line style per method
        fsize       = 12;
        methodNames = {'normal,' 'fast,'};      % legend prefix per method

        figure; set(gcf, 'Color', [1 1 1], 'Position', [100 200 750 230]); hold on;
        set(gca, 'FontSize', fsize, 'YScale', 'linear', 'XScale', 'linear');

        for mIdx = 1:numel(methods)
            for FIdx = 1:numel(Fall)

                % execution times (graph size x run); the largest graph size
                % is left out of the plot
                itime  = squeeze(time(1:end-1, TIdx, mIdx, :, dataset, distribution, LIdx, FIdx));
                % 0th / 50th / 100th percentile over the runs
                l = prctile(itime', 0);
                m = prctile(itime', 50);
                u = prctile(itime', 100) ;

                if mIdx == 2
                    % mean error difference between method 1 and method 2;
                    % there is no semicolon, so the value is echoed to the
                    % console
                    ierror1 = squeeze(error(1:end-1, TIdx, 1, :, dataset, distribution, LIdx, FIdx));
                    ierror2 = squeeze(error(1:end-1, TIdx, 2, :, dataset, distribution, LIdx, FIdx));
                    mean(ierror1(:) - ierror2(:))
                end

                % 50th percentile with bars reaching the 0th and 100th
                errorbar(Nall(1:end-1), m, m-l, u-m, ['' markers{mIdx}], 'DisplayName', [methodNames{mIdx}, ' F=', num2str(Fall(FIdx))], 'Color', colors(FIdx,:), 'MarkerFaceColor', colors(FIdx,:));
            end
        end
        xlabel('vertices'); ylabel('computation time (sec)')
        xlim([500 9500]);
        ylim([0 250]);

        legend1 = legend('show');
        set(legend1, 'FontSize', fsize, ...
            'Orientation','vertical', 'Location', 'NorthWest',...
            'EdgeColor',[1 1 1]);

        if overwrite
            export_fig(sprintf('results/JPSD_complexity_%s_variance_bias.pdf', datasets{dataset}));
        end

end