/**
 * TollBPDirect.cpp -- C++ implementation of the BP algorithm for the BP Estimator
 * (constrained shortest path algorithm). 
 *  See the documentation in BPDirect.m for usage
 *
 * THIS IS PART OF CHIRPLAB
 */

/**
 * $RCSfile: TollBPDirect.cpp,v $
 * $Date: 2008/05/19 22:51:39 $
 * $Revision: 1.1 $
 * Copyright (c) Hannes Helgason, California Institute of Technology, 2007
 */

#include <mex.h>
#include <vector>
#include <cmath>
#include <math.h>
#include "GraphParam.hpp"
#include "Base2Utilities.hpp"
using namespace std;


/**
 * Given a GraphParam object describing a chirplet graph and the chirplet
 * coefficient table associated with that graph, computes for every
 * k = 1,...,maxLength the cost of the best (minimum cost) path that uses
 * exactly k chirplets and, optionally, the path itself.
 *
 * The cost of an arc is the negated entry of the coefficient table, so the
 * minimum cost path maximizes the sum of the table entries along the path.
 * Nodes are processed in order of increasing time index; every arc leads
 * to a strictly later time index, so a single sweep is sufficient.
 *
 * Inputs:
 *   gp                 a GraphParam structure
 *   cTable             a chirplet coefficient table (cell array) associated
 *                      with gp; the cell for scale s and dyadic time index b
 *                      is read at linear index pow2(s)+b-1
 *   maxLength          the maximum allowable number of chirplets in a path
 *                      (negative values are treated as 0)
 *   returnBestPath     set to false if only the values of the best costs
 *                      should be returned, otherwise set to true
 *
 * Outputs:
 *   optimalDistances   resized to maxLength; the cost of the best path with
 *                      exactly k chirplets is returned in
 *                      optimalDistances[k-1]. The entry is Inf if no end
 *                      node carries a label for that number of chirplets.
 *   bestPaths          ignored if returnBestPath==false. Otherwise resized
 *                      to maxLength; bestPaths[k-1] holds the k+1 node
 *                      indices (indexing starts from 0) of the best path
 *                      using exactly k chirplets, ordered from the first
 *                      node to the last. Only entries with
 *                      k >= pow2(gp.cs()) are filled in; an entry for which
 *                      no path was found is left empty.
 *
 * Calls mexErrMsgTxt if a coefficient table cell needed by the graph is
 * missing or has no real data.
 */
void findBPcostsAndPaths(GraphParam& gp, const mxArray *cTable, 
                int maxLength, vector<double>& optimalDistances,    
                bool returnBestPath,
                vector< vector<int> >& bestPaths) {
    typedef vector<double>::size_type LabelIndex;

    // A negative length would wrap around to a huge unsigned size in the
    // resize() calls below.
    if (maxLength < 0) {
        maxLength = 0;
    }
    const LabelIndex maxLabels = static_cast<LabelIndex>(maxLength);

    // graph parameters
    const int finestScale = gp.fs();
    const int coarsestScale = gp.cs();
    const int fmin = gp.fmin();
    const int fmax = gp.fmax();
    const int N = gp.n();
    const int nFreqs = gp.numFrequencies();         // no. frequency indices
    const int nTimeIndices = gp.numTimeIndices();   // total number of time indices in [0,1)
    const int nTFnodes = gp.numTimeFreqNodes();     // no. nodes where chirplets can emanate from
    const int nNodes = gp.numNodes();               // total no. of nodes in the graph
    const int nStartNodes = gp.numStartNodes();     // no. nodes where chirplet paths can start from

    const double inf = mxGetInf();

    // distLabelList[v][len] is the cost of the best path found so far that
    // reaches node v using exactly len chirplets; pred[v][len] is the node
    // preceding v on that path (only maintained if returnBestPath is set).
    vector< vector<double> > distLabelList(nNodes);
    vector< vector<int> > pred;
    if (returnBestPath) {
        pred.resize(nNodes);
    }

    // Zero-length labels: 0 for the start nodes, infinity for all others.
    for (int node = 0; node < nNodes; node++) {
        const bool isStartNode = node < nStartNodes;
        distLabelList[node].push_back(isStartNode ? 0.0 : inf);
        if (returnBestPath) {
            pred[node].push_back(isStartNode ? node : 0);
        }
    }

    // SPECIAL HANDLING FOR THE COLORED NOISE VARYING AMPLITUDE
    // CHIRPLET COEFF DATA STRUCTURE
    int cnvaCorrection = 0;     // offset applied to the frequency index
    int cnvaNoPerSlope = N;     // number of table entries per slope
    if (gp.xtType()==COLOREDNOISEVARAMPXT) {
        cnvaCorrection = -fmin;
        cnvaNoPerSlope = nFreqs;
    }

    // loop over time indices k, t_k=k/2^fs
    for (int k = 0; k < nTimeIndices; k++) {
        // find the coarsest scale that can start from this time index
        const int coarsestAllowScale = gp.coarsestAllowableScale(k);

        // loop over allowable scales
        for (int s = finestScale; s >= coarsestAllowScale; s--) {
            const int b = divPow2(k, finestScale - s); // dyadic time index

            const mxArray *cTableCellElementPtr = mxGetCell(cTable, pow2(s) + b - 1);
            const double *cTableReal =
                (cTableCellElementPtr != NULL) ? mxGetPr(cTableCellElementPtr) : NULL;
            if (cTableReal == NULL) {
                // fail with a MATLAB error instead of dereferencing NULL
                mexErrMsgTxt("Chirplet coefficient table is missing data for a cell required by the graph parameters.");
            }

            const int nSlopes = gp.nslopes(s);
            // time index at which a chirplet of scale s starting at k ends
            const int endTimeIndex = k + pow2(finestScale - s);

            // loop over frequency indices
            for (int m = fmin; m <= fmax; m++) {
                const int startNodeInd = k * nFreqs + m - fmin;

                // loop over slope indices
                for (int slopeInd = 0; slopeInd < nSlopes; slopeInd++) {
                    // find the end frequency
                    const int endfreq = m + static_cast<int>(rint(gp.dfreq(s, slopeInd)));
                    if (endfreq < fmin) {
                        continue;
                    }
                    if (endfreq > fmax) {
                        // assuming that the slopes are in an increasing order
                        // we break the loop as soon as we exceed the maximum
                        // allowed frequency
                        break;
                    }

                    // get cost, note that we have to handle the
                    // different indexing specifically for colored
                    // noise and varying amplitude
                    const int cTableInd = slopeInd * cnvaNoPerSlope + m + cnvaCorrection;
                    const double arcCost = -cTableReal[cTableInd];

                    const int endNodeInd = endTimeIndex * nFreqs + endfreq - fmin;

                    // The end node lies at a later time index than the start
                    // node, so these refer to different inner vectors and
                    // growing endLabels does not invalidate startLabels.
                    const vector<double>& startLabels = distLabelList[startNodeInd];
                    vector<double>& endLabels = distLabelList[endNodeInd];

                    // every label of the start node can be extended by this
                    // arc, as long as the path stays within maxLength
                    LabelIndex nExtend = startLabels.size();
                    if (nExtend > maxLabels) {
                        nExtend = maxLabels;
                    }

                    for (LabelIndex len = 1; len <= nExtend; len++) {
                        const double tentDistance = startLabels[len - 1] + arcCost;
                        if (endLabels.size() < len + 1) {
                            // first time the end node is reached with len arcs
                            endLabels.push_back(tentDistance);
                            if (returnBestPath) {
                                pred[endNodeInd].push_back(startNodeInd);
                            }
                        } else if (endLabels[len] > tentDistance) {
                            endLabels[len] = tentDistance;
                            if (returnBestPath) {
                                pred[endNodeInd][len] = startNodeInd;
                            }
                        }
                    }
                }
            }
        }
    }

    // All distance labels are optimal. Just need to
    // read off the shortest distances for each length
    optimalDistances.assign(maxLabels, inf);
    vector<int> optEndNodes(maxLabels, -1);  // end nodes of the best paths, -1 if none
    for (int len = 0; len < maxLength; len++) {
        // +1 because first entry corresponds to zero length
        const LabelIndex labelInd = static_cast<LabelIndex>(len) + 1;
        for (int m = 0; m < nFreqs; m++) {
            // there are nFreqs nodes at t=1 and nTFnodes+nFreqs nodes in total
            const vector<double>& endLabels = distLabelList[nTFnodes + m];
            // an end node only has a label for this length if some path with
            // that many chirplets reaches it; skip it otherwise
            if (labelInd < endLabels.size() && endLabels[labelInd] < optimalDistances[len]) {
                optimalDistances[len] = endLabels[labelInd];
                optEndNodes[len] = nTFnodes + m;
            }
        }
    }

    if (returnBestPath) {
        // The minimum possible path length depends on the coarsest
        // allowable scale. If cs is the coarsest scale, then we
        // can have chirplet paths of length 2^cs and up.
        bestPaths.resize(maxLabels);
        for (int len = pow2(coarsestScale) - 1; len < maxLength; len++) {
            vector<int>& path = bestPaths[len];
            if (optEndNodes[len] < 0) {
                // no path with len+1 chirplets exists: nothing to backtrack
                path.clear();
                continue;
            }
            path.resize(len + 2);
            path[len + 1] = optEndNodes[len];
            // follow the predecessor labels back to the first node
            for (int pos = len; pos >= 0; pos--) {
                path[pos] = pred[path[pos + 1]][pos + 1];
            }
        }
    }
}

/**
 * A mex interface for the findBP routine. 
 * Usage
 *    [costs,paths] = BPDirect(cc,param,maxLength)
 *    costs = BPDirect(cc,param,maxLength)
 *  Inputs
 *    cc      chirplet coefficients as returned by ChirpletTransform,
 *            the same graph parameters as in param have to be used
 *            for the transform
 *    param   chirplet graph parameters as returned by GetChirpletGraphParam.
 *    maxLength	maximum number of arcs that the best path is allowed to
 *               have.
 *  Ouput
 *    costs   a vector of length maxLength where costs(k) is the value of
 *            the best path with number of chirplets equal to k.
 *    paths	  a 1d cell array of length maxLength. Entry k in the
 *            cell array corresponds to the best path with k chirplets.
 *
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray
                 *prhs[]) {
    GraphParam graphParam(prhs[1]);
    int maxL = (int) mxGetScalar(prhs[2]);
    int minL; // the minimum possible path length in the graph
    vector<double> optDistVector;
    vector< vector<int> > bestPaths;
    mxArray *bestPath;
    double *path;
    mxArray *bestPathsCell;

    // The minimum possible path length depends on the coarsest
    // allowable scale. If cs is the coarsest scale, then we
    // can have chirplet paths of length 2^cs and up.
    minL = pow2(graphParam.cs());

    if (minL>maxL) {
        mexErrMsgTxt("Index for coarsest scale should be set to a smaller number.\nNo path with length maxLength or smaller exists in the graph.");
    }
    
    // note that prhs[0] is the chirplet coeff cell structure
    if (nlhs<=1) {
        // only get the values of the best paths
        findBPcostsAndPaths(graphParam,prhs[0],maxL,optDistVector,false,bestPaths);
    } else if (nlhs>1) {
        // get both the best paths and their values
        findBPcostsAndPaths(graphParam,prhs[0],maxL,optDistVector,true,bestPaths);
    }
    
    // Write results to the left-hand side
    
    // values of the best paths
    double *optDist;
    plhs[0] = mxCreateDoubleMatrix(1,optDistVector.size(),mxREAL);
    optDist = mxGetPr(plhs[0]);
    for (int k=0; k<optDistVector.size(); k++) {
        optDist[k] = optDistVector[k];
    }

    if (nlhs>1) {
        // the best paths
        bestPathsCell = mxCreateCellMatrix(1,maxL);
        for (int len=minL-1; len<maxL; len++) {
            bestPath = mxCreateDoubleMatrix(1,bestPaths[len].size(), mxREAL);
            path = mxGetPr(bestPath);
            for (int k=0; k<bestPaths[len].size(); k++) {
                path[k] = bestPaths[len][k]+1; // +1 because different indexing in Matlab code, PEND: This could cause bugs in the future
            }
            mxSetCell(bestPathsCell,len,mxDuplicateArray(bestPath));
            mxDestroyArray(bestPath);
        }
        plhs[1] = bestPathsCell;
    }
}
