% Experiments for the paper "Joint time-vertex stationary signal processing".
% The code tests recovery accuracy and generates the figures in Section VI B
% of the paper.
% 
% The first part of the code takes a long time since it repeats all experiments.
% To only see the stored results, set the variable replay to 1.
% 
% By changing the variables 'dataset' and 'xaxis', one obtains all the different
% results featured in Section VI B.
% 
% The working directory should be the folder containing this file (with
% subdirectories results/, utils/ and data/).
% 
% Requirements:
%   The latest versions of the GSPBox, LTFAT, and UNLocBox.
%   To save the figures you will need the export_fig MATLAB package.
% 
% Andreas Loukas, Nathanael Perraudin
% 29 September 2016

% Start from an empty workspace, fix the random seed and put the helper
% and data folders on the path.
clear;
clc;
gsp_reset_seed(1);
addpath('utils');
addpath('data');

% Recovery methods to compare:
%   'joint' -> JWSS,  'time' -> TWSS,  'graph' -> GWSS,  'ts' -> MTWSS
methods = {'joint', 'time', 'graph', 'ts'};

% Quantity swept along the x axis:
%   'miss'  -> influence of the number of missing entries
%   'train' -> sensitivity to the size of the training set
xaxis = 'miss';

% Dataset to use: 'molene', 'pems' or 'epidemic'
dataset = 'molene';

% Run flags
replay    = 1;  % 1: load stored results and only plot, 0: recompute everything from scratch
overwrite = 0;  % 1: save (overwrite) the results / figures
verbose   = 0;  % 1: show diagnostic figures

% Masking mode for the test data (also part of the results file name)
mode = 'entries';

if ~replay,
    
    % -------------------------------------------------------------------------
    % Data preparation phase
    % -------------------------------------------------------------------------
    % Every dataset branch defines:
    %   G             the graph (with its Fourier basis)
    %   X             the data, N x T x R_all (nodes x time samples x realizations)
    %   N, T, R_all
    %   L_all, F_all  PSD-estimation parameters swept by the experiment
    %   train_prc     fractions of the realizations used for training
    %   miss_prc      fractions of the test entries that are removed
    switch dataset,
        
        case 'epidemic',
            load 'flight_network.mat';
            
            % prepare the graph of Europe: Gaussian kernel on pairwise
            % distances, with small weights removed
            W = exp(-0.2*squareform(pdist(coords)).^2);
            W = W.*(W > 0.01);
            select = 1:600;
            W = W(select,select); coords = coords(select,:);
            
            % select the giant component
            [~, idx] = agsp_util_cc_split( W>0 );
            W = W(idx{1},idx{1}); coords = coords(idx{1},:);
            
            N = size(W,1); T = 180; 
            G = gsp_graph(W, coords);
            G = gsp_jtv_graph(G, T, 1);
            G = gsp_compute_fourier_basis(G);
            % param_plot is only defined further below, so use the default
            % plotting parameters here
            if verbose, figure; gsp_plot_graph(G); end
            
            contagion_prob    = 0.005;             % contagion probability per day
            infection_length  = 2;                 % length of infection period (days)
            params.maxTime    = T;                 % days (before cure is found)
            params.population = 50;                % number of people in each airport (used to attain higher precision than 0/1)
            params.model      = 'SIRS';            % infection model passed to agsp_SIRS
            % set on 'params' (the struct passed to agsp_SIRS), not on the
            % PSD-estimation options 'param'
            params.immunity   = 10;
            
            seeds = [12 45]; %; 0 38];
            R_all = 10; 
            
            % locate patient zero: the node closest to the seed coordinates.
            % It is the same for every realization, so compute it once.
            tmp = (squareform(pdist([seeds(1,:); G.coords])));
            tmp = tmp(1,2:end); [~,patient_zero] = min(tmp);
            
            X = nan(N, T, R_all);
            
            if verbose, figure; hold on; end
            for r = 1:R_all,
                % run the model (all inputs are identical across iterations;
                % presumably agsp_SIRS is stochastic -- confirm)
                [~, S, I, R] = agsp_SIRS(G.W, contagion_prob, infection_length, patient_zero, params);
                fprintf('infection complete!\n');
                
                if verbose, plot(1:T, sum(I), 'DisplayName', 'infected'); end
                
                % Is = 0*I; for i = 1:N, Is(i,:) = smooth(I(i,:)); end
                X(:,:,r) = I;
            end
            if verbose, xlabel('days'); ylabel('airports (approx)'); legend show; drawnow; end
            
            switch xaxis,
                case 'miss',
                    train_prc = [0.5 0.9];                  % percentage of data used for training
                    miss_prc  = linspace(0.05, 0.95, 19);   % percentage of values missing from the test data
                case 'train'
                    train_prc = 0.1:0.1:0.9;                % percentage of data used for training
                    miss_prc  = [0.1 0.3];                  % percentage of values missing from the test data
            end
            
            L_all = round(T/1); F_all = round(N/1);
            
        case 'molene'
            load('data/meteo_molene_t.mat');
            
            x = info{4}; y = info{3}; z = info{5}; coords = [x,y,5*z];
            
            % select the first 30 days (24 samples per day)
            X = value(:, 1:30*24);
            
            % Remove the mean of the data (alternatively we could remove 273)
            X = X - mean(X(:)); N = size(X,1);
            
            % cut the month into 30/days realizations of 'days' days each
            days  = 2; 
            X     = reshape(X, N, 24*days, 30/days); % make sure that 30 is divisible by days
            R_all = size(X, 3);           
            T     = size(X, 2); 
            
            param_graph.k = 5;
            param_graph.type = 'knn';
            param_graph.epsilon = sqrt(var(coords(:)))*sqrt(3);
            param_graph.sigma = var(coords(:))*sqrt(3)*0.001;
            G = gsp_nn_graph(coords, param_graph);
            G.plotting.limits = [-92, 2826, 22313, 24769];
            G = gsp_compute_fourier_basis(G);
            
            L_all = T;
            F_all = N; 
            
            switch xaxis,
                case 'miss',
                    train_prc = linspace(0.2, 0.5, 2);      % percentage of data used for training
                    miss_prc  = linspace(0.05, 0.95, 8);    % percentage of values missing from the test data
                    
                case 'train'
                    train_prc = [1:2:R_all-1]./R_all;       % percentage of data used for training
                    miss_prc  = [0.1 0.3]; %linspace(0.05, 0.95, 19);    % percentage of values missing from the test data
            end
            
        case 'pems'            
            load('data/processed_data_pems.mat');
            
            % make it symmetric
            G = gsp_graph(G.W + G.W', G.coords); 
            G = gsp_compute_fourier_basis(G);
            
            % split to days
            % NOTE(review): T is the number of samples per day only if
            % signal.samples_per_minute holds the sampling interval in
            % minutes -- confirm.
            N = G.N;
            T = 24*60/signal.samples_per_minute; 
            R_all = size(signal.X,2)/T;
            X = reshape(signal.X, G.N, T, R_all); 
            
            % skip weekend
            X = X(:,:,[1 4 5 6]); R_all = size(X,3);
            
            L_all = round(T/2); F_all = 75; 
            
            switch xaxis,
                case 'miss',
                    train_prc = [1 2 3]./4;                  % percentage of data used for training (in days)
                    miss_prc  = linspace(0.05, 0.95, 19);    % percentage of values missing from the test data
                case 'train'
                    error('not relevant for pems');
            end
    end
    
    if verbose,
        param_plot.show_edges = 0;
        param_plot.colorbar = 1;
        param_plot.colormap = brewermap(100, 'RdYlGn');
        param_plot.step = 1;
        param_plot.speed = 0.1;
        gsp_plot_jtv_signal(G, X(:,:,1), param_plot)
    end
    
    %% -------------------------------------------------------------------------
    % Recovery problem
    % -------------------------------------------------------------------------
    clc;
    
    % Results are indexed as (miss, train, method, L, F, test realization).
    % Unused slots stay NaN: each training percentage only has
    % R_all - R_train test realizations, and 'time'/'graph'/'ts' are only
    % evaluated for the first (L, F) pair.
    % NOTE: the variable 'error' shadows the builtin error() from here on;
    % the name is kept because stored result files use it.
    Z     = cell(numel(miss_prc), numel(train_prc), numel(methods), numel(L_all), numel(F_all), R_all);
    error = nan(numel(miss_prc), numel(train_prc), numel(methods), numel(L_all), numel(F_all), R_all);
    PSD   = cell(numel(train_prc), numel(methods), numel(L_all), numel(F_all));
    
    % solver options shared by all inpainting problems
    paramsolver.maxit = 1000;
    paramsolver.tol   = 1e-10;
    
    for train_prcIdx = 1:numel(train_prc),
        
        fprintf('- Training percentage %d out of %d\n', train_prcIdx, numel(train_prc));
        
        % the first R_train realizations are used for training, the rest for testing
        R_train = floor(train_prc(train_prcIdx)*R_all);
        R_test  = R_all - R_train;
        if R_train < 1 || R_test < 1,
            warning('training percentage %g leaves no training or no test data; skipping.', train_prc(train_prcIdx));
            continue;
        end
        
        X_train = X(:,:,1:R_train);
        X_test  = X(:,:,R_train+1:end);
        
        T_train = T; T_test = T; 
        
        G_train = gsp_jtv_graph(G, T_train, 1);
        G_test  = gsp_jtv_graph(G, T_test, 1);
        C_train = gsp_compute_fourier_basis(gsp_ring(T_train, 1));
        C_test  = gsp_compute_fourier_basis(gsp_ring(T_test, 1));
        
        % average energy of one training realization; numeric PSD estimates
        % are rescaled to this total energy (loop invariant, computed once)
        train_energy = 0;
        for r = 1:R_train,
            train_energy = train_energy + (trace(X_train(:,:,r)*X_train(:,:,r)'))/R_train;
        end
        
        for LIdx = 1:numel(L_all),
            for FIdx = 1:numel(F_all),
                
                % -------------------------------------------------------------
                % PSD estimation phase
                % -------------------------------------------------------------
                
                param.Nfilt = min(F_all(FIdx), N);
                param.L     = min(L_all(LIdx), T_train);
                
                for mIdx = 1:numel(methods),
                    
                    % 'graph', 'time' and 'ts' are only estimated for the
                    % first (L, F) pair; psd stays NaN otherwise
                    psd = nan;
                    switch methods{mIdx},
                        
                        case {'joint'} % JWSS 
                            param.estimator = 'TVA';
                            [psd, ft] = gsp_estimate_vertex_time_psd(G_train, X_train, param);
                            psd = gsp_jtv_filter_evaluate(psd, ft, G_train.e, G_train.jtv.omega);
                            
                        case {'graph'} % GWSS
                            if FIdx == 1 && LIdx == 1,
                                CovV = gsp_stationarity_cov(reshape(X_train, N, T_train*R_train));
                                psd = gsp_experimental_psd(G_train, CovV);
                                % psd = repmat(gsp_filter_evaluate(psd, G_train.e), 1, T_train);
                            end
                            
                        case {'time'} % TWSS
                            if FIdx == 1 && LIdx == 1,
                                CovT = gsp_stationarity_cov(reshape(permute(X_train,[2,1,3]), T_train, N*R_train));
                                psd = gsp_experimental_psd(C_train, CovT);
                                % psd = repmat(gsp_filter_evaluate(psd, C_train.e)', N, 1);
                            end
                            
                        case {'ts'} % MTWSS
                            if FIdx == 1 && LIdx == 1,
                                % Windows used to average the psd
                                w = gabwin('itersine', 5, 10); 
                                % Careful: this takes long. Only use it with small N,T
                                C_ts = covariance_matrix_estimation_stationary_process(X_train, w);
                                psd = C_ts; 
                            end
                            
                        otherwise
                            warning('unknown method.. attempting to continue.');
                    end
                    
                    % normalize numeric PSDs so that their total equals the
                    % average training energy ('ts' holds a covariance)
                    if isnumeric(psd) && ~strcmp(methods{mIdx}, 'ts')
                        
                        psd = abs(psd);
                        psd = psd * train_energy / sum(psd(:));
                        
                        if verbose,
                            figure; param_plot.logscale = 1;
                            subplot(1,2,1); gsp_plot_jft(G_train, psd, param_plot);
                            subplot(1,2,2); gsp_plot_jft( G_train, gsp_jft(G_train, X_train(:,:,1)).^2, param_plot);
                            drawnow;
                        end
                    end
                    
                    PSD{train_prcIdx, mIdx, LIdx, FIdx} = psd;
                end
            end
        end
        
        %% ---------------------------------------------------------
        % Recovery phase
        % ---------------------------------------------------------        
        for miss_prcIdx = 1:numel(miss_prc),
            imiss_prc = miss_prc(miss_prcIdx);
            
            fprintf('  - Miss percentage %d out of %d\n', miss_prcIdx, numel(miss_prc));
            
            % set up the recovery problem: M is the observation mask (1 = observed)
            M = ones(N, T_test, R_test);
            for r = 1:R_test
                switch mode,
                    case 'entries'
                        % keep each entry independently with probability 1 - imiss_prc
                        M(:,:,r) = rand(N,T_test) <= (1-imiss_prc);
                end
            end
            sigma = 0.0; psdnoise = sigma.^2;
            
            Y_test = M.*X_test + sigma*randn(N, T_test, R_test);
            
            for LIdx = 1:numel(L_all),
            for FIdx = 1:numel(F_all),
            for mIdx = 1:numel(methods),
                
                psd = PSD{train_prcIdx, mIdx, LIdx, FIdx};
                
                switch methods{mIdx},
                    
                    case 'time'
                        if FIdx == 1 && LIdx == 1,
                            for r = 1:R_test,
                                % transpose so that time is the first dimension
                                iZ = real(gsp_wiener_inpainting(C_test, Y_test(:,:,r)', M(:,:,r)', psd, psdnoise, paramsolver))';
                                
                                error(miss_prcIdx, train_prcIdx, mIdx, LIdx, FIdx, r) = norm(iZ - X_test(:,:,r), 'fro') / norm(X_test(:,:,r), 'fro');
                                Z{miss_prcIdx, train_prcIdx, mIdx, LIdx, FIdx, r} = iZ;
                            end
                        end
                        
                    case 'graph'
                        if FIdx == 1 && LIdx == 1,
                            % psd_test = gsp_jtv_interpolate_psd(psd, G_test.jtv.T);
                            for r = 1:R_test,
                                
                                % iZ = real(gsp_jtv_wiener_inpainting(G_test, Y_test(:,:,r), M(:,:,r), psd_test, psdnoise, paramsolver));
                                iZ = real(gsp_wiener_inpainting(G_test, Y_test(:,:,r), M(:,:,r), psd, psdnoise, paramsolver));
                                
                                error(miss_prcIdx, train_prcIdx, mIdx, LIdx, FIdx, r) = norm(iZ - X_test(:,:,r), 'fro') / norm(X_test(:,:,r), 'fro');
                                Z{miss_prcIdx, train_prcIdx, mIdx, LIdx, FIdx, r} = iZ;
                            end
                        end
                        
                    case 'ts'
                        C_ts = psd;
                        if FIdx == 1 && LIdx == 1,
                            for r = 1:R_test,
                                iZ = grm_js_estimator(C_ts, M(:,:,r), Y_test(:,:,r), psdnoise);
                                
                                error(miss_prcIdx, train_prcIdx, mIdx, LIdx, FIdx, r) = norm(iZ - X_test(:,:,r), 'fro') / norm(X_test(:,:,r), 'fro');
                                Z{miss_prcIdx, train_prcIdx, mIdx, LIdx, FIdx, r} = iZ;
                            end
                        end
                        
                    case 'joint'
                        
                        % psd_test = gsp_jtv_interpolate_psd(psd, G_test.jtv.T);
                        psd_test = psd;
                        
                        for r = 1:R_test,
                            
                            iZ = real(gsp_jtv_wiener_inpainting(G_test, Y_test(:,:,r), M(:,:,r), psd_test, psdnoise, paramsolver));
                            
                            error(miss_prcIdx, train_prcIdx, mIdx, LIdx, FIdx, r) = norm(iZ - X_test(:,:,r), 'fro') / norm(X_test(:,:,r), 'fro');
                            Z{miss_prcIdx, train_prcIdx, mIdx, LIdx, FIdx, r} = iZ;
                        end
                end
            end
            end
            end
        end
    end
    
    if overwrite,    
        save(sprintf('results/recovery_%s_%s_%s.mat', dataset, mode, xaxis));
    end
else
    % The results file holds the whole workspace saved above, including the
    % value 'overwrite' had when the results were computed. Keep the
    % user's current choice instead of the stored one.
    user_overwrite = overwrite;
    load(sprintf('results/recovery_%s_%s_%s.mat', dataset, mode, xaxis));
    overwrite = user_overwrite;
end

%% -------------------------------------------------------------------------
% Visualization
% -------------------------------------------------------------------------
figure; set(gcf, 'Color', [1 1 1], 'Position', [100 200 750 230]); hold on; 

colors  = brewermap(numel(methods),'Set1')*0.9;
markers = {'-o' '--s' '--d' '-^'};
fsize   = 12;
% legend names, looked up by method key so they stay correct if 'methods'
% is reordered; unknown keys are shown as-is
methodNames = struct('joint', 'joint', 'time', 'time', 'graph', 'graph', 'ts', 'time (full)');

% number of test realizations for each training percentage (same formula as
% in the recovery phase); 'error' entries beyond it are unused (NaN, or 0 in
% result files produced by older versions of this script)
R_test_all = R_all - floor(train_prc*R_all);

switch xaxis,
    case 'train'
        for miss_prcIdx = 2:numel(miss_prc),
        
%             subplot(numel(miss_prc),1,miss_prcIdx); 
            hold on; set(gca, 'YScale', 'log', 'Fontsize', fsize, 'YTick', 10.^[-2:1:3], 'XTick', [0.1:0.1:1]); 
            for LIdx = 1:numel(L_all),
                for FIdx = 1:numel(F_all),
                    
                    for mIdx = 1:numel(methods),
                        % mean error over the valid test realizations of
                        % each training percentage
                        ierror = nan(numel(train_prc), 1);
                        for train_prcIdx = 1:numel(train_prc),
                            ierr_r = error(miss_prcIdx, train_prcIdx, mIdx, LIdx, FIdx, 1:R_test_all(train_prcIdx));
                            ierror(train_prcIdx) = mean(ierr_r(:), 'omitnan');
                        end
                        % skip methods that were not evaluated for this (L, F)
                        if all(isnan(ierror) | ierror == 0), continue; end
                        
                        if isfield(methodNames, methods{mIdx}), name = methodNames.(methods{mIdx}); else name = methods{mIdx}; end
                        
                        plot(train_prc, ierror, markers{mIdx}, 'DisplayName', name, 'Color', colors(mIdx,:), 'MarkerFaceColor', colors(mIdx,:))
                    end
                end
            end
            
            xlim([0 1]); ylim([0.005 1])
            xlabel('percentage of data used for training'); ylabel('error')
            legend1 = legend('show');
            set(legend1,'EdgeColor',[1 1 1], 'Fontsize', fsize, 'Location', 'Best');
        end
        
    case 'miss'
        for train_prcIdx = 1, %numel(train_prc),
            
            % subplot(numel(train_prc),1,train_prcIdx); 
            hold on; set(gca, 'YScale', 'linear', 'Fontsize', fsize); 
            nTest = R_test_all(train_prcIdx);
            
            for LIdx = 1:numel(L_all),
                for FIdx = 1:numel(F_all),
                    
                    for mIdx = 1:numel(methods),
                        % mean error over the valid test realizations
                        ierror = error(1:numel(miss_prc), train_prcIdx, mIdx, LIdx, FIdx, 1:nTest);
                        ierror = mean(reshape(ierror, numel(miss_prc), nTest), 2, 'omitnan');
                        % skip methods that were not evaluated for this (L, F)
                        if all(isnan(ierror) | ierror == 0), continue; end
                        
                        if isfield(methodNames, methods{mIdx}), name = methodNames.(methods{mIdx}); else name = methods{mIdx}; end
                        
                        plot(miss_prc, ierror, markers{mIdx}, 'DisplayName', name, 'Color', colors(mIdx,:), 'MarkerFaceColor', colors(mIdx,:), 'MarkerSize', 5)
                        
                    end
                end
            end
            set(gca, 'XTick', [0:0.1:1]);
            
            % title(sprintf('%2d percent of data used in training', round(train_prc(train_prcIdx)*100)));
            xlabel(['percentage of missing ' mode]); ylabel('error')
            legend1 = legend('show'); 
            switch dataset
                case 'molene'
                    ylim([0 1]); set(gca, 'YTick', [0.0:0.2:1]); 
                case 'epidemic'
                    ylim([0 0.65]); set(gca, 'YTick', [0.0:0.2:1]); 
                case 'pems'
                    ylim([0 0.65]);
            end
            set(legend1,'EdgeColor',[1 1 1], 'Fontsize', fsize, 'Location', 'NorthWest');
        end
end

if overwrite, 
    fig_file = sprintf('results/recovery_%s_%s_%s.pdf', dataset, mode, xaxis);
    export_fig(fig_file);
end


