%% Denoising of dynamic mesh using joint time-vertex Tikhonov and comparison with time and graph denoising
%
%   In this demo, we perform denoising of a dynamic mesh representing a dog
%   walking using a joint Tikhonov approach and we compare with time- and
%   graph-only denoising. Best regularization parameters are obtained
%   through exhaustive search.
%   Dataset can be found at http://research.microsoft.com/en-us/um/redmond/events/geometrycompression/data/default.html
%

% Author: Francesco Grassi
% Date: April 2017

% Start from an empty workspace, no open figures and a clean console
clear
close all
clc

% Toolbox start-up scripts (presumably GSPBox and UNLocBoX initialisation)
gsp_start
init_unlocbox

% The script relies on this file providing the data array X
load('dog.mat')

%% Signal
% Keep the first three pages of X and, on each page, subtract from every
% column its mean taken along the first dimension.
xyz = X(:,:,1:3);
X = bsxfun(@minus, xyz, mean(xyz,1));

%% Parameters
% Number of passes of the main loop (one fresh noise realisation per pass)
itermax = 20;

% Verbosity flag forwarded to solvep through param_solver
verbose = 1;
param_solver.verbose = verbose;

% Options handed to gsp_nn_graph and gsp_jtv_graph when the graph is built
param.transform = 'dct';
param.k = 5;

% Relative error of y with respect to the reference x, computed on the
% column-stacked arrays
err = @(x,y) norm(x(:)-y(:),'fro')/norm(x(:),'fro');

%%

for iter = 1:itermax
    % Draw a fresh noise realisation and rescale it so that its norm is
    % 0.2 times the norm of the clean signal
    noise = randn(size(X));
    noise = 0.2 * noise * norm(X(:)) / norm(noise(:));
    Xn = X + noise;

    % Graph: nearest-neighbour graph built from the noisy data averaged
    % along its second dimension, then turned into a joint time-vertex
    % graph whose time length is size(X,2)
    x0 = squeeze(mean(Xn,2));
    G = gsp_nn_graph(x0,param);
    G = gsp_compute_fourier_basis(G);
    G = gsp_jtv_graph(G,size(X,2),[],param);

    % --- Objective terms, written once and reused by the three experiments.
    % Each factory returns a struct in the form consumed by solvep
    % (beta is presumably the Lipschitz constant of the gradient).

    % Data fidelity ||x - y||_F^2 for an observation y
    fidelity = @(y) struct( ...
        'beta', 2, ...
        'grad', @(x) 2*(x-y), ...
        'eval', @(x) norm(x-y,'fro').^2);

    % Regulariser along the second dimension, weight t: t*trace(x*LT*x')
    time_reg = @(t) struct( ...
        'prox', @(x,T) prox_l2grad(x',T*t)', ...
        'grad', @(x) t*2*x*G.jtv.LT', ...
        'beta', 2*4*t, ...
        'eval', @(x) t*trace(x*G.jtv.LT*x'));

    % Regulariser along the first dimension, weight t: t*trace(x'*L*x)
    graph_reg = @(t) struct( ...
        'grad', @(x) t*2*G.L*x, ...
        'beta', 2*G.lmax*t, ...
        'eval', @(x) t*trace(x'*G.L*x));

    % Solve one page of Xn (initial point = the noisy page itself) and
    % stack the three solutions, in page order 1, 2, 3
    denoise_page = @(c,regs,opts) ...
        solvep(Xn(:,:,c), [{fidelity(Xn(:,:,c))}, regs], opts);
    denoise = @(regs,opts) cat(3, ...
        denoise_page(1,regs,opts), ...
        denoise_page(2,regs,opts), ...
        denoise_page(3,regs,opts));

    % --- Time only: sweep the time regularisation weight
    TAUt = linspace(1,6,10);
    param_solver.maxit = 100;

    for ii = 1:numel(TAUt)
        Y_time = denoise({time_reg(TAUt(ii))}, param_solver);
        errtime{iter}(ii) = err(X,Y_time);
    end

    % --- Graph only: sweep the graph regularisation weight
    TAUg = linspace(1,20,10);
    param_solver.maxit = 100;

    for ii = 1:numel(TAUg)
        Y_graph = denoise({graph_reg(TAUg(ii))}, param_solver);
        errgraph{iter}(ii) = err(X,Y_graph);
    end

    % --- Time-vertex: sweep both weights on a 2-D grid, with a larger
    % iteration budget for the solver
    TAUj1 = linspace(0.1,3,7);
    TAUj2 = linspace(0.1,5,7);
    param_solver.maxit = 200;

    for ii = 1:numel(TAUj1)
        for jj = 1:numel(TAUj2)
            joint_regs = {time_reg(TAUj1(ii)), graph_reg(TAUj2(jj))};
            Y_joint = denoise(joint_regs, param_solver);

            % Rows follow TAUj1 (time weight), columns follow TAUj2
            % (graph weight)
            errjoint{iter}(ii,jj) = err(X,Y_joint);
        end
    end

end


%% Show results
% The reconstruction displayed is the one left by the very last joint solve
% (last noise realisation, last (tau_1,tau_2) pair of the grid); it is not
% necessarily the best-performing one.
Y = Y_joint;
t0 = 1;   % index along the second dimension of the frame being displayed
figure(   'Position',[81         400        1099         184]);

param.view = [0 90];

% Top row: original, noisy and denoised point clouds at frame t0, each
% with axis limits fitted to its own coordinate ranges
clouds = {X, Xn, Y};
for pp = 1:numel(clouds)
    P = clouds{pp};
    px = vec(P(:,t0,1));
    py = vec(P(:,t0,2));
    pz = vec(P(:,t0,3));

    subplot(2,3,pp)
    scatter3(px,py,pz,'k.')
    axis on
    axis([min(px) max(px) min(py) max(py) min(pz) max(pz)])
    view(param.view)
end

% Best (minimum) relative error of each method for every noise realisation:
% one row per method, one column per realisation
errresult = [errtime;errgraph;errjoint];
minerr = cellfun(@(x) min(vec(x)),errresult);

% Joint error surface averaged over the noise realisations; rows follow
% TAUj1 and columns follow TAUj2, as filled in by errjoint{iter}(ii,jj)
averagejoint = mean(cat(3,errjoint{:}),3);

subplot(234)
% imagesc lays the COLUMNS of its matrix along the x-axis and the ROWS along
% the y-axis. averagejoint is indexed (tau_1,tau_2), so it has to be
% transposed for tau_1 to run along x, as the ticks and labels state.
imagesc(TAUj1,TAUj2,averagejoint')
caxis([min(averagejoint(:)) max(averagejoint(:))])
xlabel('\tau_1','Fontsize',11)
ylabel('\tau_2','Fontsize',11)
axis xy
colormap gray


subplot(236)
% boxplot draws one box per column, hence the transpose (realisations x methods)
boxplot(minerr','colors',[0 0 0],'sym','k*')
set(gca,'XTickLabel',{'Time-only','Graph-only','Joint'})
ylabel('Minimum relative error','Fontsize',11)

%% Summary statistics
% Distribution of the relative error over the whole parameter grid of each
% method, for the first noise realisation only.
figure

nTime  = numel(TAUt);
nGraph = numel(TAUg);
nJoint = numel(TAUj1)*numel(TAUj2);

% All errors stacked in one column, with a matching group label per entry
% (1 = time-only, 2 = graph-only, 3 = joint)
allErrors = [errtime{1}'; errgraph{1}'; vec(errjoint{1})];
groupId   = [ones(nTime,1); 2*ones(nGraph,1); 3*ones(nJoint,1)];

boxplot(allErrors,groupId,'colors',[0 0 0],'sym','k*')

ax = gca;
set(ax,'XTickLabel',{'Time-only','Graph-only','Joint'})
ylabel('Relative error','Fontsize',11)
set(ax,'Fontsize',11)