function [ T_infect, S, I, R ] = agsp_SIRS( A, p, T, seeds, params )
%AGSP_SIRS Simulate epidemic spread on a graph with SI, SIR or SIRS dynamics.
%   [T_INFECT, S, I, R] = AGSP_SIRS(A, P, T, SEEDS, PARAMS) runs a
%   discrete-time stochastic epidemic on the graph with adjacency matrix A.
%
%   Inputs:
%     A      - n-by-n adjacency matrix of the graph.
%     p      - contagion probability. At each step an individual that is
%              neither infected nor recovered becomes infected when p times
%              its weighted number of infected neighbours exceeds a uniform
%              random draw.
%     T      - length of the infectious period: an individual infected at
%              step t0 recovers at the first step t with t - t0 > T.
%     seeds  - indices of the vertices infected at time zero.
%     params - (optional) struct with fields
%                population - individuals per vertex (default 1). The
%                             individual-level adjacency matrix is
%                             kron(A, ones(population)).
%                maxTime    - maximum number of time steps (default 1000).
%                model      - 'SI', 'SIR' (default) or 'SIRS'. Under 'SI'
%                             an individual returns to the susceptible
%                             state after its infectious period (the R
%                             compartment stays empty).
%                immunity   - 'SIRS' only: a recovered individual becomes
%                             susceptible again at the first step t with
%                             t - t_recover > immunity (default 1).
%
%   Outputs (per vertex, averaged over its population individuals):
%     T_infect - n-by-1 mean infection time. The step of the most recent
%                infection is used (individuals can be reinfected under
%                'SI'/'SIRS'). Seeds count as 0, and individuals never
%                infected count as maxTime.
%     S, I, R  - n-by-K fractions susceptible / infected / recovered, where
%                column k is the state at time k-1. K = maxTime + 1 unless
%                the run stopped early.
%
%   The run stops early once every individual has been infected at least
%   once, or once every individual is infected or recovered.

n = size(A, 1);

if nargin < 5 || isempty(params), params = struct(); end
if ~isfield(params, 'population'), params.population = 1;     end
if ~isfield(params, 'maxTime'),    params.maxTime    = 1000;  end
if ~isfield(params, 'model'),      params.model      = 'SIR'; end
if ~isfield(params, 'immunity'),   params.immunity   = 1;     end

% fail fast on an unsupported model instead of hitting an undefined variable
if ~any(strcmp(params.model, {'SI', 'SIR', 'SIRS'}))
    error('agsp_SIRS:unknownModel', ...
        'Unknown model ''%s''; expected ''SI'', ''SIR'' or ''SIRS''.', ...
        params.model);
end

pop    = params.population;
nSteps = params.maxTime;

% vertex-level initial state; column k holds the state at time k-1
I = zeros(n, nSteps+1); I(seeds, 1) = 1;
S = double(~I);
R = zeros(n, nSteps+1);

% scale up: each vertex is replicated into `pop` individuals
As = kron(A, ones(pop));
ns = n * pop;
Ss = kron(S, ones(pop, 1));
Is = kron(I, ones(pop, 1));
Rs = kron(R, ones(pop, 1));

% per-individual step of the most recent infection / recovery (NaN = never)
Ts_infect  = nan(ns, 1); Ts_infect(Is(:, 1) ~= 0) = 0;
Ts_recover = nan(ns, 1);
Tr = params.immunity;

tEnd = 1;   % index of the last valid state column
for t = 1:nSteps
    % one uniform draw per individual per step
    contact = (As * Is(:, t) * p > rand(ns, 1));
    notI    = (Is(:, t) == 0);
    notR    = (Rs(:, t) == 0);
    % infected individuals whose infectious period has elapsed
    Rnew = ((t - Ts_infect) > T) .* (Is(:, t) == 1) .* notR;

    switch params.model
        case 'SI'
            % recovered individuals go straight back to susceptible
            Inew = contact .* notI;
            Ss(:, t+1) = Ss(:, t) - Inew + Rnew;
            Is(:, t+1) = Is(:, t) + Inew - Rnew;

        case 'SIR'
            Inew = contact .* notI .* notR;
            Ss(:, t+1) = Ss(:, t) - Inew;
            Is(:, t+1) = Is(:, t) + Inew - Rnew;
            Rs(:, t+1) = Rs(:, t) + Rnew;

        case 'SIRS'
            Inew = contact .* notI .* notR;
            % recovered individuals whose immunity has expired
            Snew = ((t - Ts_recover) > Tr) .* (Rs(:, t) == 1);
            Ss(:, t+1) = Ss(:, t) - Inew + Snew;
            Is(:, t+1) = Is(:, t) + Inew - Rnew;
            Rs(:, t+1) = Rs(:, t) + Rnew - Snew;
    end

    Ts_infect(Inew ~= 0)  = t;
    Ts_recover(Rnew ~= 0) = t;
    tEnd = t + 1;

    % validate the state that was just computed
    if sum(Ss(:, tEnd) + Is(:, tEnd) + Rs(:, tEnd)) ~= ns
        error('invalid state');
    end

    % everyone infected at least once, or everyone infected/recovered
    if ~any(isnan(Ts_infect)) || sum(Is(:, tEnd) + Rs(:, tEnd)) == ns
        break;
    end
end

% keep every computed state (column tEnd included), drop unused columns
Ss = Ss(:, 1:tEnd);
Is = Is(:, 1:tEnd);
Rs = Rs(:, 1:tEnd);

Ts_infect(isnan(Ts_infect)) = nSteps;

% scale down: average over each vertex's individuals; sparse avoids a
% dense n-by-ns averaging matrix
M = kron(speye(n), ones(1, pop)) / pop;
S        = full(M * Ss);
I        = full(M * Is);
R        = full(M * Rs);
T_infect = full(M * Ts_infect);

end