%% tetR data correction for Vipul Singhal's Thesis
% Perform Calibration-Correction on the constitutive GFP data and tetR
% repression data. 
%
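% MODELS: 
% Characterization / calibration step model (constitutive GFP,
% constitutive_gfp3):
% D_G + E <-> D_G:E -> D_G + E + G 
%
% Test circuit / correction step model (tetR repression, tetR_repression):
% D_T + E <-> D_T:E -> D_T + E + T
% D_G + E <-> D_G:E -> D_G + E + G
% 2 T <-> T2
% D_G + T2 <-> D_G:T2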
%
% (c) Vipul Singhal, Caltech 2018

close all
clear all
clc


%% Initialize path and set simulation options

% Add the folder containing this script (and all of its subfolders) to the
% MATLAB path so the model, likelihood and MCMC helper files are found.
% fileparts is used instead of splitting on '/' so that the folder is also
% found on Windows, where the file separator is '\'.
st = dbstack('-completenames');
if isempty(st)
    % dbstack returns nothing when this code is evaluated outside a file
    % (e.g. pasted at the command line); fall back to the current folder.
    projdir = pwd;
else
    projdir = fileparts(st(1).file);
end
addpath(genpath(projdir));

% ---- Run switches --------------------------------------------------------
% usesaved: meant to choose between reusing saved results for the plots and
% re-running the estimation. NOTE(review): not read anywhere in this file.
usesaved = false;
% parallalflag: forwarded to gwmcmc_vse as its 'Parallel' option.
parallalflag = false;

% ---- MCMC sampler settings -----------------------------------------------
nsteps = 6e4;     % passed to gwmcmc_vse for the burn-in and for every iteration
thinning = 10;    % passed to gwmcmc_vse as 'ThinChain'
nIter = 10;       % number of consecutive MCMC runs after the burn-in
nW = 600;         % number of walkers
stepsz = 1.5;     % passed to gwmcmc_vse as 'StepSize'

% ---- Parameter bounds (log space, shared by all parameters) --------------
lb = -8;
ub = 8;

% ---- Noise model -----------------------------------------------------------
% The likelihoods use a standard deviation of stdev/tightening.
stdev = 10;
tightening = 2;

%% Collect experimental data
% The constitutive GFP data is used in the calibration step and the tetR
% repression data in the correction step. The time vector returned with the
% constitutive data is discarded; tvec from the tetR import is used for
% both data sets.
% NOTE(review): this assumes both experiments share the same sampling
% times over the first 13 points used below - confirm.
[~ , calibration_data, dosevals_calib] = import_ptetconstitutive;
[tvec, correction_data, dosevals_corr] = import_tetR_repression; 

%% Set up calibration step model

%%%%%%%%%% CALIBRATION STEP MODEL %%%%%%%%%%%%
% A function that takes a parameter point (logp), an initial species
% concentration vector (sp0) and a vector of time points (tspan), and
% returns the ode15s solution of the constitutive_gfp3 model.
model_calib = @(logp, sp0, tspan)...
    ode15s(@(t,sp) constitutive_gfp3(t,sp,logp), tspan, sp0);

% The pmap_calib cell array contains vectors of indices of the extract
% specific parameters, the extract specific species initial concentrations
% (another type of extract specific parameter), and the circuit specific
% parameters. These are indices into the 'logp' and 'sp0' arrays of the
% model_calib function. 

% extract specific parameter (ESP)
espix = 1; % kc rate parameter in logp in model_calib (ie, constitutive_gfp3)

% circuit specific parameters (CSP)
cspix = 2:3; % kfP and krP rate parameters in logp

% Extract specific species initial concentration parameter (ESSP)
esspix = 2; % the enzyme initial concentration (species "P" in the model)

% A cell array of these index vectors, ordered {ESP, ESSP, CSP}. 
pmap_calib = {espix, esspix, cspix };

% number of extracts. In this case nEnv = 3. 
nEnv = size(calibration_data, 4);

% Number of calibration parameters (the ESPs and ESSPs of every extract,
% plus the shared CSPs). This sets the length of logpjoint (which, in the
% mcmc_simbio toolbox, is called the masterVector - this code was written
% before that toolbox, and indeed inspired that toolbox).
% 
% Note the structure of logpjoint: 
% 
% logpjoint = [E1 ESPs, E1 ESSPs, E2 ESPs, E2 ESSPs, E3 ESPs, E3 ESSPs, CSPs]
nparam_calib = nEnv*(length(espix) + length(esspix)) + length(cspix); 

% The elements of pmap_calib (espix, esspix, cspix) give the indices that
% distribute the elements of logpjoint to the respective models: 
% 
% MODEL for circuit in E1: 
% 
% logp(espix) = logpjoint(1:length(espix));
%
% total_non_CSPs = nEnv*(length(espix)+ length(esspix));
% logp(cspix) = logpjoint((total_non_CSPs+1):nparam_calib);
% 
% startESSP1 = (length(espix)+1);
% endESSP1 = (length(espix)+ length(esspix));
% sp0(esspix) = logpjoint(startESSP1:endESSP1);
%
% With similar prescriptions for the model in the remaining extracts. 
% NOTE(review): the plotting code later in this file uses
% exp(logpjoint(...)) for the ESSPs, i.e. treats them as log values;
% confirm log_likelihood_sharedCSP does the same.
% 
% The remaining species initial concentrations are either set by the
% dosing (described below) or are set to 0. 

% Set the species to be dosed and to be measured. 

nSp_calib = 4; % total number of species in the model

% index (in sp0) of the species to be dosed. This is the initial DNA conc
% in our case. 
dosemap_calib = 1; % the 1st species is GFP dna, D_G. This gets dosed. 

% index (in sp0) of the species to compare to experimental data. 
idMS_calib = 4; % the GFP species, G is the 4th species in the model. 

%% Define likelihood and prior functions for calibration step
% 
% Gaussian log-density of residuals res with standard deviation sig:
%   log N(res; 0, sig^2) = -(1/2)*(res/sig)^2 - log(sqrt(2*pi)*sig)
% The previous form, -(res./sig).^2 - log(sqrt(2*pi)).*sig, dropped the 1/2
% factor, which silently shrank the effective noise standard deviation by
% a factor sqrt(2). It also misplaced the parenthesis in the normalizing
% term. To reproduce results obtained with the old form, divide sig by
% sqrt(2).
lognormvec = @(res,sig) -0.5*(res./sig).^2 - log(sqrt(2*pi).*sig);

% The function that computes the log likelihood of the model parameters
% given the model, a parameter point to evaluate the model at, and the
% corresponding data set. Only the first 13 time points are used, and the
% data are scaled by 1000 (the plots later label the scaled GFP data in nM).
loglike_calib = @(logpjoint) log_likelihood_sharedCSP(model_calib,...
    nSp_calib, logpjoint, pmap_calib, tvec(1:13),...
    1000*calibration_data(1:13,:,:,:), dosevals_calib,...
    dosemap_calib, idMS_calib,lognormvec, stdev/tightening);

% The prior: a uniform box in log space. Returns a logical that is true
% when every entry of logp lies strictly inside (lb, ub).
% NOTE(review): relies on gwmcmc_vse treating a logical prior as a
% 0 / -Inf log-prior - confirm.
logprior = @(logp) all(lb < logp) && all(logp < ub); 

%% Perform the calibration step Bayesian inference
% All the ESPs and CSPs in the calibration models get estimated, with the
% CSPs shared across the models.
% Initialize the walkers with a Latin hypercube sample spanning the
% parameter box [lb, ub] in every dimension.
mdpts = (lb+ub)/2;
% Full width of the box. (The previous abs(lb)+abs(ub) only equals this
% when lb <= 0 <= ub.)
width = ub - lb;
lhsamp = width*(lhsdesign(nW, nparam_calib)-0.5); % nW x nparam_calib
minit = bsxfun(@plus, mdpts, lhsamp'); % nparam_calib x nW initial points
% run the initial burn-in simulation
tic
[m, lPburnin] = gwmcmc_vse(minit,{logprior loglike_calib},nsteps,...
    'StepSize',stepsz , ...
    'ThinChain',thinning,...
    'Parallel', parallalflag);
toc
% Timestamp used to name the .mat files saved by the MCMC iterations below.
datestring = datestr(now, 'yyyymmdd_HHMMSS');
%%
% Run the actual MCMC simulation. The sampling is split into nIter (= 10)
% consecutive calls of gwmcmc_vse, each called with nsteps. Each call
% starts from the final walker positions of the previous call (the burn-in
% call for the first one), so that together they form one continuous MCMC
% run. calib_SD collects the names of the .mat files saved after each call.
calib_SD  = cell(nIter,1);
for i = 1:nIter
% last slice along dimension 3 of the previous samples: one column per
% walker, the same layout as the initial points passed to gwmcmc_vse
minit = m(:,:,end);
clear m

tic
[m, lp]=gwmcmc_vse(minit,{logprior loglike_calib},nsteps,...
    'ThinChain',thinning,...
    'Parallel', parallalflag,...
    'StepSize', stepsz);
toc
% Save the data in a .mat file named after the datestring of the simulation
% run and the iteration number. save(svstr) with no variable list writes
% the ENTIRE workspace, including this iteration's m and lp.
% Because calib_SD{i} is assigned after the save, file i contains the
% names of files 1..i-1 only.

svstr = ['t015_calib_' datestring '_' num2str(i) ];
save(svstr);
 
 calib_SD{i} = svstr;
 % wait 450 s before the next iteration (the reason is not stated here)
 pause(450)
end

%% Set up correction step 1 model
%%%%%%%%%%% TEST (CORRECTION STEP) MODEL %%%%%%%%%%%%%
% Same calling convention as the calibration model, but simulating the
% tetR repression circuit (tetR_repression) with ode15s.
model_corr = @(logp, sp0, tspan) ...
    ode15s(@(t,sp) tetR_repression(t,sp,logp), tspan, sp0);

% As for the calibration model, the ESPs, ESSPs and CSPs are given as
% indices into logp / sp0 of model_corr:
espix = 1;      % ESP index in logp
cspix = 2:9;    % CSP indices in logp (8 circuit specific parameters)
esspix = 2;     % ESSP index in sp0
nEnv = size(correction_data, 4);
% length of logpjoint for a joint fit over all nEnv extracts
nparam_corr = nEnv*(length(espix)+ length(esspix))+ length(cspix); 

pmap_corr = {espix, esspix, cspix};
nSp_corr = 9;           % total number of species in the model
dosemap_corr = [6,1];   % sp0 indices set from the rows of dosevals_corr
idMS_corr = 9;          % sp0 index of the measured species
%% Define likelihood and prior functions for correction step 1
% Gaussian log-density of residuals res with standard deviation sig:
%   -(1/2)*(res/sig)^2 - log(sqrt(2*pi)*sig)
% The previous form, -(res./sig).^2 - log(sqrt(2*pi)).*sig, dropped the
% 1/2 factor (an effective sd smaller by sqrt(2)) and misplaced the
% parenthesis in the normalizing term.
lognormvec = @(res,sig) -0.5*(res./sig).^2 - log(sqrt(2*pi).*sig);

% Joint log-likelihood over all extracts in correction_data.
% NOTE(review): loglike_corr is not used below; the correction step 1
% inference uses loglike_corr1 (one extract, with its ESP and ESSP fixed).
loglike_corr = @(logpjoint) log_likelihood_sharedCSP(model_corr,...
    nSp_corr, logpjoint, pmap_corr, tvec(1:13),...
    1000*correction_data(1:13,:,:,:), dosevals_corr,...
    dosemap_corr, idMS_corr,lognormvec, stdev/tightening);

% Uniform box prior in log space (true iff every entry lies in (lb, ub)).
logprior = @(logp) all(lb < logp) && all(logp < ub); 

%% Perform the correction step 1 Bayesian inference
% ie, fix the ESP and the ESSP in the correction model to values from the
% calibration step and estimate the CSPs of the correction model.
% We do the estimation using the data and ESP/ESSP values of extract
% number 3, eSG (the candidate extract).
envcandID = 3;

% Load the walker positions saved after the final calibration iteration.
load(['t015_calib_' datestring '_' num2str(nIter)], 'm');
% One row per stored sample, one column per logpjoint entry.
mstacked_calib = m(:,:)';

% Select the sample at the median of the first logpjoint entry.
% NOTE(review): the median is taken over logpjoint(1), the ESP of extract
% 1, while the selected point supplies extract envcandID's values below -
% confirm this is intended.
[msorted_calib, sortIX]=sort(mstacked_calib(:, 1));
% note that msorted_calib = mstacked_calib(sortIX, 1)
medianIX = sortIX(round(length(sortIX)/2));
calib_point = mstacked_calib(medianIX, :);

% Pick extract envcandID's ESP and ESSP out of the calibration logpjoint,
% laid out as [E1 ESPs, E1 ESSPs, E2 ESPs, E2 ESSPs, ..., CSPs]. This
% assumes a single ESP and a single ESSP per extract (as set in
% pmap_calib), which gives indices 5 and 6 for extract 3.
nESPSP_calib = length(pmap_calib{1}) + length(pmap_calib{2});
kc_calib = calib_point((envcandID-1)*nESPSP_calib + 1);
P_calib = calib_point((envcandID-1)*nESPSP_calib + length(pmap_calib{1}) + 1);

% The correction step 1 model sets the ESP and ESSP of the candidate
% extract to the values obtained from the calibration step, and estimates
% the CSPs against that extract's data only.
loglike_corr1 = @(logp_corrcsp) log_likelihood_sharedCSP(model_corr,...
    nSp_corr, [kc_calib; P_calib; logp_corrcsp], pmap_corr, tvec(1:13), ...
    1000*correction_data(1:13,:,:,envcandID), dosevals_corr, dosemap_corr, idMS_corr,...
    lognormvec, stdev/tightening);


% The parameters estimated in the first correction step are the CSPs of
% model_corr (which calls tetR_repression.m).
nparam_corr1 = length(cspix);
lhsamp = width*(lhsdesign(nW, nparam_corr1)-0.5); 
minit=bsxfun(@plus,mdpts,lhsamp');
%%
% burn in phase: start from the Latin hypercube points minit and sample the
% CSP posterior of correction step 1 (likelihood loglike_corr1).
tic
[m, lP] =gwmcmc_vse(minit,{logprior loglike_corr1},nsteps,...
    'StepSize',stepsz , ...
    'ThinChain',thinning, 'Parallel', parallalflag );%
toc

% names of the .mat files saved after each correction step 1 iteration
corr_SD = cell(nIter,1);


% Run the actual MCMC simulation: nIter consecutive calls, each starting
% from the final walker positions of the previous call.
for i = 1:nIter
    
minit = m(:,:,end);
clear m

tic
[m, lp]=gwmcmc_vse(minit,{logprior loglike_corr1},nsteps,...
    'ThinChain',thinning,...
    'Parallel', parallalflag,...
    'StepSize', stepsz);
toc
% File names reuse the calibration run's datestring. save(svstr) with no
% variable list writes the ENTIRE workspace; corr_SD{i} is assigned after
% the save, so file i contains the names of files 1..i-1 only.
svstr = ['t015_corr1_' datestring '_' num2str(i) ];
save(svstr);
corr_SD{i} = svstr;
% wait 450 s before the next iteration (the reason is not stated here)
pause(450)

end

%% Plot things (chains, posterior distributions, fits and predictions)

% test015_plot_everything

%% Helper script for test015_estimation_v1.m
% (c) Vipul Singhal, 
% California Institute of Technology, 2018
% 

% Add the folder containing this script (and its subfolders) to the path.
% fileparts finds the folder on Windows ('\') as well as on Unix-like
% systems ('/'). Fall back to the current folder when dbstack returns
% nothing (code evaluated outside a file).
st = dbstack('-completenames');
if isempty(st)
    projdir = pwd;
else
    projdir = fileparts(st(1).file);
end
addpath(genpath(projdir));

% Axis labels for the calibration parameters, in logpjoint order: the ESP
% (kc) and ESSP (P) of extracts 1-3, followed by the two CSPs.
legends = {'kc1', 'P1', ...
           'kc2', 'P2', ...
           'kc3', 'P3', ...
           'kfP', 'krP'};

% Concatenate the parameter samples stored in the calibration .mat files
% whose names are listed in calib_SD.
mcat = catMC(calib_SD);

% Chain trace plots, using every 10th walker for easier visualization.
plotChains(mcat(:, 1:10:end, :), nW, legends);

%%
% Pairwise projections (corner plot) of the joint posterior distribution,
% using the samples beyond the first two thirds along dimension 3 of mcat.
figure
ecornerplot_vse(mcat(:, :, round(end/1.5):end), ...
    'scatter', true, ...
    'transparency', 0.025, ...
    'color', [.6 .35 .3], ...
    'names', legends);

%% Visualize the resulting 'fits' of the model to the data
% Draw up to 500 samples from the calibration posterior, simulate each of
% them, and summarize the trajectories by their mean and standard deviation.
%
% mstacked_calib holds the samples of the final calibration MCMC
% iteration, i.e. the most converged set of samples available.

nptstotal = size(mstacked_calib, 1); 
npts = min(500, round(nptstotal/5));

% The point whose extract-3 ESP and ESSP were used in correction step 1.
calib_point = mstacked_calib(medianIX, :);

% Seed the generator so that the random selections below are reproducible.
rng(42)

% A single arbitrary posterior sample.
arbitrary_point_IX = randperm(nptstotal, 1);
arbitrary_point = mstacked_calib(arbitrary_point_IX, :);

% npts distinct samples used to generate the calibration fit trajectories.
paramid = randperm(nptstotal, npts);
params_to_use_calib = mstacked_calib(paramid, :);
% %%%
% Simulation and plotting set-up for the calibration fits.
envname = {'VS', 'MP', 'SG'};   % extract names used in the subplot titles

nMS = size(calibration_data, 2);   % number of measured species (1: only GFP)
nICs = size(dosevals_calib, 2);    % number of doses; GFP DNA = [1 2 5 10 20] (nM)
nEnv = size(calibration_data, 4);  % number of extracts

% Index vectors into the model's logp / sp0 (not into logpjoint) and
% the number of entries in each.
espIX = pmap_calib{1};
esspIX = pmap_calib{2};
cspIX = pmap_calib{3};
nESP = length(espIX);
nESSP = length(esspIX);
nCSP = length(cspIX);

% initial condition vector, filled in for every simulation
icvec = zeros(nSp_calib, 1);

% simulated measured species: time x species x dose x sample x extract
simulatedtraj = zeros(length(tvec(1:13)), nMS, nICs, npts, nEnv);
maxGFP_sd = 0;   % NOTE(review): not used below
% %%%
% Simulate the calibration model at every selected posterior sample, in
% every extract and at every dose.
for kk = 1:npts
    logpjoint = params_to_use_calib(kk, :);

    % The CSPs are shared by all extracts and sit after the per-extract
    % [ESP, ESSP] blocks of logpjoint.
    paramvec = zeros(nESP+nCSP, 1);
    cspindices = ((nESSP + nESP)*nEnv+1):length(logpjoint);
    logpcsp = logpjoint(cspindices);
    paramvec(cspIX) = logpcsp;

    for envid = 1:nEnv
        % Extract envid's block of logpjoint: its ESPs, then its ESSPs.
        espindices = (envid-1)*(nESP+nESSP) + (1:nESP);
        esspindices = ((envid-1)*(nESP+nESSP) + nESP + 1):envid*(nESP+nESSP);

        logpesp = logpjoint(espindices);
        paramvec(espIX) = logpesp;

        % The ESSPs are exponentiated into initial concentrations.
        icvec(esspIX) = exp(logpjoint(esspindices));

        for doseID = 1:nICs
            icvec(dosemap_calib) = dosevals_calib(:, doseID);
            [~, simudata] = model_calib(paramvec, icvec, tvec(1:13));
            for msid = 1:nMS
                simulatedtraj(:, msid, doseID, kk, envid) = ...
                    simudata(:, idMS_calib(msid));
            end
        end
    end
end

% %%%
% Mean and standard deviation over the posterior samples (dimension 4).
meanvals = mean(simulatedtraj, 4); 
sdvals = std(simulatedtraj, 0, 4);
% Largest mean + sd over time, doses and extracts (1 x nMS); used as the
% upper y-axis limit.
maxvals = squeeze(max(max(max(meanvals+sdvals, [], 1), [], 3), [], 5)); 

lineStyles = linspecer(nICs, 'sequential');   % one color per dose
hd = zeros(nICs, 1);    % data trajectory handles
hm = zeros(nICs, 1);    % model fit mean trajectory handles
hsd = zeros(nICs, 1);   % model fit sd trajectory handles (patch objects)

% Legend labels built from the actual doses rather than hard-coded, so the
% legend stays correct if the dose set changes.
doselabels = arrayfun(@(d) sprintf('DNA = %gnM', d), dosevals_calib(1, :), ...
    'UniformOutput', false);

for msid = 1:nMS
    figure
    ss = get(0, 'screensize');
    set(gcf, 'Position', [50 100 ss(3)/1.6 ss(4)/2.3]);
    
    for j = 1:nEnv
        subplot(1, nEnv, j);
        for i = 1:nICs
            % experimental data (scaled by 1000, plotted in nM)
            hd(i) = plot(tvec(1:13)/3600, 1000*calibration_data(1:13, msid, i, j), ...
                'color', lineStyles(i, :), 'linewidth', 1.4);
            hold on
            % model mean with a +/- 1 sd band
            [hm(i), hsd(i)] = boundedline(tvec(1:13)/3600, ...
                meanvals(:, msid, i, 1, j), sdvals(:, msid, i, 1, j));
            set(hsd(i), 'FaceColor', lineStyles(i, :).^4, 'FaceAlpha', 0.1);
            set(hm(i), 'Color', lineStyles(i, :).^4, 'LineStyle', ':', ...
                'LineWidth', 0.8);
        end
        set(gca, 'Ylim', [0, round(maxvals(msid))])
        title(sprintf('GFP, e%s', envname{j}), 'FontSize', 16)
        xlabel('time, hours')
        ylabel('GFP, nM')
    end
    legend(hd, doselabels, 'Location', 'NorthWest')
end

% %%% The test circuit (Correction Step)
% We start with a circuit description
% tetR_repression: tet repression model, single step, first 3 hours. 
% 
% D_T + P <-> D_T:P -> D_T + P + T
% 
% D_G + P <-> D_G:P -> D_G + P + G
% 
% 2 T <-> T2
% 
% D_G + T2 <-> D_G:T2

%%
% Next we visualize the Markov chains and posterior distributions for
% the test circuit. 

nW = 600;   % number of walkers, as in the estimation settings
% axis labels for the 8 CSPs estimated in correction step 1
legends = {'kfPT' ,...
    'krPT'  'kfPG'    'krPG'   'kfdim', 'krdim', 'kfrep', 'krrep'};

mcat = catMC(corr_SD);
% Second half of the correction step 1 iterations (from iteration
% round(nIter/2) on), treated as converged.
corr_SD2 = corr_SD(round(end/2):end);
% %%%
% Plot the chains of all iterations for every 30th walker (20 walkers when
% nW = 600), for easier visualization.
plotChains(mcat(:,1:30:end, :), nW, legends ); 

% %%% 
% Scatter (corner) plot of the converged iterations, thinned by 10 along
% dimension 3.
mcat_converged= catMC(corr_SD2);
figure
ecornerplot_vse(mcat_converged(:,:, 1:10:end),...
    'scatter', true,'transparency',0.25, 'color',[.6 .35 .3], ...
    'names', legends);


% %%% Correction Demo Figure
% Next we create the correction demo figure. This figure is arranged into 4
% columns and one row per plotted dose (initial condition). Within each
% row, the first column has the test circuit behavior in the two
% environments of interest, the candidate environment eSG and the
% reference environment (eVS). The second column has the same two
% trajectories, but in addition has the model fit to the candidate
% environment data. The third and fourth columns have the 'corrected'
% behavior, along with the two data trajectories. The third column uses
% the reference extract's ESP/ESSP from one arbitrary calibration sample.
% The fourth uses those of the calibration point used in correction step 1.

envrefID = 1; % can be changed to 2 to generate the correction from 3 to 2. 
envcandID = 3; 

m_rearranged = mcat_converged(:,:)';   % one row per sample
nptstotal = size(m_rearranged, 1);

% npts (set for the calibration fits) random samples of the converged
% correction step 1 posterior
paramid = randperm(nptstotal, npts);
params_to_use_corr = m_rearranged(paramid, :);
envname = {'VS', 'MP', 'SG'};

% These variables are already in the workspace when this runs after the
% estimation above. Load them from a previously saved run only if some are
% missing. This way a stale file cannot silently replace the data and
% settings of the current run (whose samples are plotted), and the script
% does not fail when that file is absent.
corrvars = {'tvec', 'nW', 'model_corr', 'dosevals_corr', 'dosemap_corr', ...
    'correction_data', 'pmap_corr', 'nSp_corr', 'idMS_corr'};
if ~all(ismember(corrvars, who))
    load('t015_corr1_20171023_151627_11_MBP', corrvars{:})
end

nMS = size(correction_data, 2);    % number of measured species (1: only GFP)
nICs = size(dosevals_corr, 2);     % number of doses; tetR DNA = [0 0.25 0.5 0.75 1 2 5 10] (nM)

nEnv_total = size(correction_data, 4);   % number of extracts with data
nEnv_used = 2;        % extracts shown in the correction demo (candidate and reference)
nEnv_estimated = 1;   % extracts used in the correction step 1 estimation

% Index vectors into the model's logp / sp0 (not into logpjoint) and
% the number of entries in each.
espIX = pmap_corr{1};
esspIX = pmap_corr{2};
cspIX = pmap_corr{3};
nESP = length(espIX);
nESSP = length(esspIX);
nCSP = length(cspIX);

icvec = zeros(nSp_corr, 1);

% %%%
% Simulated measured species (time x species x dose x sample):
%  - correction step 1: candidate extract ESP/ESSP with the sampled CSPs;
%  - correction step 2: reference extract ESP/ESSP with the sampled CSPs;
%  - correction step 2 variant using calib_point's reference ESP/ESSP.
simulatedtraj_corrstep1 = zeros(length(tvec(1:13)), nMS, nICs, npts); 
simulatedtraj_corrstep2 = zeros(length(tvec(1:13)), nMS, nICs, npts);

simtraj_cs2_cspfix = zeros(length(tvec(1:13)), nMS, nICs, npts);
% %%%
% Correction step 1: simulate the correction model at every selected
% posterior sample, with the ESP and ESSP fixed to the candidate extract's
% calibration values kc_calib and P_calib. These are the values the CSP
% samples were estimated with. (They were previously hard-coded as -0.2821
% and 1.3714, numbers from one specific run that do not match the CSP
% samples of any other run.)
for kk = 1:npts
    logpjoint_corr1 = [kc_calib P_calib params_to_use_corr(kk, :)];
    % the CSPs follow the single estimated extract's [ESP, ESSP] block
    cspindices = ((nESSP + nESP)*nEnv_estimated+1):length(logpjoint_corr1);
    paramvec = zeros(nESP+nCSP, 1);
    logpcsp = logpjoint_corr1(cspindices);
    paramvec(cspIX) = logpcsp;
    logpesp = logpjoint_corr1(1:nESP);
    paramvec(espIX) = logpesp;
    esspindices = (nESP + 1):(nESP+nESSP);
    
    % set the initial condition vector from the (log-valued) ESSPs
    icvec(esspIX) = exp(logpjoint_corr1(esspindices));
    
    for doseID = 1:nICs
        icvec(dosemap_corr) = dosevals_corr(:, doseID);
        % simulate the model
        [~, simudata] = model_corr(paramvec, icvec, tvec(1:13));
        for msid = 1:nMS
            simulatedtraj_corrstep1(:,msid,...
                doseID, kk) = simudata(:, idMS_corr(msid));
        end
    end
end

% %%%
% Compute the mean and standard deviations for correction step 1
meanvals_corrstep1 = mean(simulatedtraj_corrstep1, 4); 
sdvals_corrstep1= std(simulatedtraj_corrstep1,0, 4);
maxvals_corrstep1 = squeeze(max(max(max(meanvals_corrstep1+sdvals_corrstep1,...
    [], 1), [], 3), [], 5)); % 1 by nMS array. 
% %%%
% Correction step 2: simulate the circuit in the reference extract. The
% reference extract's ESP and ESSP come from a single arbitrary calibration
% sample (arbitrary_point). They are combined with every selected CSP
% sample from correction step 1.
%
% [refsID, refeID] delimit the reference extract's [ESP, ESSP] block in the
% calibration logpjoint. NOTE(review): nESP and nESSP come from pmap_corr
% here; this indexing is valid because the calibration model also has one
% ESP and one ESSP per extract.
refsID = ((envrefID-1)*(nESSP + nESP)+1); % reference block start index
refeID = (envrefID*(nESSP + nESP));       % reference block end index

refarbitrarypt = arbitrary_point(:, refsID:refeID);
ref_ESPs_sharedCSPs = calib_point(refsID:refeID);

% %%%
% (An alternative, drawing the reference ESP/ESSP per sample from
% params_to_use_calib(kk, refsID:refeID), is flagged as buggy and unused.)
for kk = 1:npts
    logpjoint_corrstep2 = [refarbitrarypt  params_to_use_corr(kk, :)];

    % rate parameters: sampled CSPs, then the reference extract's ESP
    paramvec = zeros(nESP+nCSP, 1);
    cspindices = ((nESSP + nESP)+1):length(logpjoint_corrstep2);
    logpcsp = logpjoint_corrstep2(cspindices);
    paramvec(cspIX) = logpcsp;
    logpesp = logpjoint_corrstep2(1:nESP);
    paramvec(espIX) = logpesp;

    % the reference extract's ESSP, exponentiated into the initial conditions
    esspindices = (nESP + 1):(nESP+nESSP);
    icvec(esspIX) = exp(logpjoint_corrstep2(esspindices));

    for doseID = 1:nICs
        icvec(dosemap_corr) = dosevals_corr(:, doseID);
        [~, simudata] = model_corr(paramvec, icvec, tvec(1:13));
        for msid = 1:nMS
            simulatedtraj_corrstep2(:, msid, doseID, kk) = ...
                simudata(:, idMS_corr(msid));
        end
    end
end

% Mean, standard deviation and plot upper bound for correction step 2.
meanvals_corrstep2 = mean(simulatedtraj_corrstep2, 4); 
sdvals_corrstep2 = std(simulatedtraj_corrstep2, 0, 4);
maxvals_corrstep2 = squeeze(max(max(max(meanvals_corrstep2 + sdvals_corrstep2, ...
    [], 1), [], 3), [], 5)); % 1 x nMS

% %%%
% Correction step 2, second variant ("cspfix"): same as above, but the
% reference extract's ESP and ESSP are taken from calib_point, the
% calibration point whose candidate-extract values were used in correction
% step 1. They are combined with every selected CSP sample.
for kk = 1:npts
    logpjoint_corrstep2_cspfix = ...
        [ref_ESPs_sharedCSPs params_to_use_corr(kk, :)];

    % rate parameters: sampled CSPs, then the reference extract's ESP
    paramvec = zeros(nESP+nCSP, 1);
    cspindices = ((nESSP + nESP)+1):length(logpjoint_corrstep2_cspfix);
    logpcsp = logpjoint_corrstep2_cspfix(cspindices);
    paramvec(cspIX) = logpcsp;
    logpesp = logpjoint_corrstep2_cspfix(1:nESP);
    paramvec(espIX) = logpesp;

    % the reference extract's ESSP, exponentiated into the initial conditions
    esspindices = (nESP + 1):(nESP+nESSP);
    icvec(esspIX) = exp(logpjoint_corrstep2_cspfix(esspindices));

    for doseID = 1:nICs
        icvec(dosemap_corr) = dosevals_corr(:, doseID);
        [~, simudata] = model_corr(paramvec, icvec, tvec(1:13));
        for msid = 1:nMS
            simtraj_cs2_cspfix(:, msid, doseID, kk) = ...
                simudata(:, idMS_corr(msid));
        end
    end
end

% Mean, standard deviation and plot upper bound for this variant.
meanvals_corrstep2_cs2_cspfix = mean(simtraj_cs2_cspfix, 4); 
sdvals_corrstep2_cs2_cspfix = std(simtraj_cs2_cspfix, 0, 4);
maxvals_corrstep2_cs2_cspfix = ...
    squeeze(max(max(max(meanvals_corrstep2_cs2_cspfix + ...
    sdvals_corrstep2_cs2_cspfix, ...
    [], 1), [], 3), [], 5)); % 1 x nMS

% Common axis maximum over correction step 1 and both step 2 variants.
maxvals_corr = max([maxvals_corrstep1; maxvals_corrstep2; ...
    maxvals_corrstep2_cs2_cspfix], [], 1);

% %%%
% Color palette of 2*nICs colors, with nICs = number of doses in the
% correction data. Reference extract curves use lineStyles(i,:); candidate
% extract curves use lineStyles(nICs+i,:), where nICs is the number of
% plotted doses set just below.
lineStyles = linspecer(2*nICs,'sequential');

% Only the first 4 doses are shown in the correction demo figure.
nICs = 4;
% Columns: (1) experimental data, (2) correction step 1 fit,
% (3) correction step 2 with the arbitrary point's reference ESP/ESSP,
% (4) correction step 2 with calib_point's reference ESP/ESSP.
ncols = 4;

% Graphics handle arrays, sized for every row and column that is drawn.
hd_cand = zeros(nICs, ncols); % candidate extract data trajectory handles
hd_ref = zeros(nICs, ncols);  % reference extract data trajectory handles
hm_cand = zeros(nICs, 1);     % correction step 1 fit: mean line handles
hsd_cand = zeros(nICs, 1);    % correction step 1 fit: sd band (patch) handles
hm_ref1 = zeros(nICs, 1);     % column 3 corrected trajectories: mean line handles
hsd_ref1 = zeros(nICs, 1);    % column 3 corrected trajectories: sd band handles
hm_ref2 = zeros(nICs, 1);     % column 4 corrected trajectories: mean line handles
hsd_ref2 = zeros(nICs, 1);    % column 4 corrected trajectories: sd band handles

thours = tvec(1:13)/3600;     % time axis in hours

for msid = 1:nMS
    % fixed common upper y-limit for every panel (overrides the computed max)
    maxvals_corr(msid) = 1500;
    figure
    ss = get(0, 'screensize');
    set(gcf, 'Position', [50 100 ss(3)/1.6 ss(4)/1.4]);

    % one row per initial condition (dose)
    for i = 1:nICs
        refdata = 1000*correction_data(1:13, msid, i, envrefID);
        canddata = 1000*correction_data(1:13, msid, i, envcandID);

        for col = 1:ncols
            subplot(nICs, ncols, ncols*(i-1) + col);

            % every column shows the reference and candidate data
            hd_ref(i, col) = plot(thours, refdata, ...
                'color', lineStyles(i, :), 'linewidth', 0.8);
            hold on
            hd_cand(i, col) = plot(thours, canddata, ...
                'color', lineStyles(nICs+i, :), 'linewidth', 0.8);

            switch col
                case 1
                    % experimental data only
                    hleg = [hd_ref(i, col), hd_cand(i, col)];
                    legtext = {'Reference Extract', 'Candidate Extract'};
                    titletext = 'Experimental data';
                case 2
                    % correction step 1: model fit to the candidate data
                    [hm_cand(i), hsd_cand(i)] = boundedline(thours, ...
                        meanvals_corrstep1(:, msid, i, 1), sdvals_corrstep1(:, msid, i, 1));
                    set(hsd_cand(i), 'FaceColor', lineStyles(nICs+i, :).^4, 'FaceAlpha', 0.1);
                    set(hm_cand(i), 'Color', lineStyles(nICs+i, :).^4, 'LineStyle', ':', ...
                        'LineWidth', 1);
                    hleg = [hd_ref(i, col), hd_cand(i, col), hm_cand(i)];
                    legtext = {'Reference Extract', 'Candidate Extract', 'Model Fit (mean, sd)'};
                    titletext = 'Correction Step 1';
                case 3
                    % correction step 2 ("corrected behavior"), arbitrary point
                    [hm_ref1(i), hsd_ref1(i)] = boundedline(thours, ...
                        meanvals_corrstep2(:, msid, i, 1), sdvals_corrstep2(:, msid, i, 1));
                    set(hsd_ref1(i), 'FaceColor', lineStyles(i, :).^4, 'FaceAlpha', 0.1);
                    set(hm_ref1(i), 'Color', lineStyles(i, :).^4, 'LineStyle', ':', ...
                        'LineWidth', 1);
                    hleg = [hd_ref(i, col), hd_cand(i, col), hm_ref1(i)];
                    legtext = {'Reference Extract', 'Candidate Extract', ...
                        '''Corrected'' Trajectories (mean, sd)'};
                    titletext = 'Correction Step 2';
                case 4
                    % correction step 2 ("corrected behavior"), calib_point
                    [hm_ref2(i), hsd_ref2(i)] = boundedline(thours, ...
                        meanvals_corrstep2_cs2_cspfix(:, msid, i, 1), ...
                        sdvals_corrstep2_cs2_cspfix(:, msid, i, 1));
                    set(hsd_ref2(i), 'FaceColor', lineStyles(i, :).^4, 'FaceAlpha', 0.1);
                    set(hm_ref2(i), 'Color', lineStyles(i, :).^4, 'LineStyle', ':', ...
                        'LineWidth', 1);
                    hleg = [hd_ref(i, col), hd_cand(i, col), hm_ref2(i)];
                    legtext = {'Reference Extract', 'Candidate Extract', ...
                        '''Corrected'' Trajectories (mean, sd)'};
                    titletext = 'Correction Step 2';
            end

            set(gca, 'Ylim', [0, round(maxvals_corr(msid))])
            set(gca, 'Xlim', [0, 1.6])
            title(sprintf('%s, tetR DNA = %0.2g', titletext, dosevals_corr(2, i)), ...
                'FontSize', 12)
            xlabel('time, hours')
            if col == 1
                ylabel('GFP, nM')
            end
            legend(hleg, legtext, 'Location', 'NorthWest')
        end
    end
end

% %%% compute the % correction
% Distances between trajectories over the first 13 time points, for the
% first 4 doses (those shown in the correction demo figure).
nICs = 4;
normorig = zeros(nICs, 1);
normreduced = zeros(nICs, 1);
alpha = zeros(nICs, 1);
normreduced_cspfix = zeros(nICs, 1);
alpha_cspfix = zeros(nICs, 1);
% NOTE(review): the per-dose arrays are overwritten for every measured
% species, so with nMS > 1 only the last species enters the metrics
% (nMS = 1 for this data).
for msid = 1:nMS
    % for each initial condition (dose)
    for i = 1:nICs
        refdata = 1000*correction_data(1:13, msid, i, envrefID);
        canddata = 1000*correction_data(1:13, msid, i, envcandID);

        % distance between the reference and candidate experimental data
        normorig(i) = norm(refdata - canddata);

        % distance between the reference data and the mean corrected
        % trajectory (arbitrary-point and calib_point variants)
        normreduced(i) = norm(refdata - meanvals_corrstep2(:, msid, i, 1));
        normreduced_cspfix(i) = norm(refdata - meanvals_corrstep2_cs2_cspfix(:, msid, i, 1));

        % Spread of the corrected trajectories relative to the spread of
        % the correction step 1 fit (ratio of maximum standard deviations).
        % Indexed by msid; this was previously fixed to species 1.
        alpha(i) = ...
            max(sdvals_corrstep2(:, msid, i, 1))/max(sdvals_corrstep1(:, msid, i, 1));
        alpha_cspfix(i) = ...
            max(sdvals_corrstep2_cs2_cspfix(:, msid, i, 1))/max(sdvals_corrstep1(:, msid, i, 1));
    end
end

% Correction metrics. In the unweighted ratios, values below 1 mean the
% corrected trajectories are closer to the reference data than the
% candidate data were.
% RR1x: summed distance after correction / summed original distance.
% RR2x: mean over doses of the per-dose distance ratio.
% x = 1: arbitrary-point correction; 2: the same, weighted by alpha;
%     3: calib_point (cspfix) correction; 4: the same, weighted by alpha_cspfix.
RR11 = sum(normreduced)/sum(normorig);
RR12 = sum(alpha.*normreduced)/sum(normorig);
RR13 = sum(normreduced_cspfix)/sum(normorig);
RR14 = sum(alpha_cspfix.*normreduced_cspfix)/sum(normorig);

RR21 = sum(normreduced./normorig)/nICs;
RR22 = sum(alpha.*normreduced./normorig)/nICs;
RR23 = sum(normreduced_cspfix./normorig)/nICs;
RR24 = sum(alpha_cspfix.*normreduced_cspfix./normorig)/nICs;

% no semicolon: display the metrics
metricvals = [RR11 RR12 RR13 RR14 RR21 RR22 RR23 RR24]
