/**
 * CTCNVAstub.cpp 
 * A C++ implementation of a matrix multiplication in the chirplet transform
 * for colored noise and varying amplitude.
 * 
 * See also documentation in ChirpletTrans/ColNoiseVarAmpWeightMatrix.m
 * and ChirpletTrans/CTCNVA.m
*/


/**
 * $RCSfile: CTCNVAstub.cpp,v $
 * $Date: 2007/05/23 07:30:41 $
 * $Revision: 1.2 $
 * Copyright (c) Hannes Helgason, California Institute of Technology, 2006
 */

#include <mex.h>
#include <valarray>
#include <math.h>

// Real part of the product x*y*z of three complex numbers
// x = xr + i*xi, y = yr + i*yi, z = zr + i*zi.
// Expanding the product, the real part collects the single term with no
// imaginary factor and the three terms with exactly two imaginary factors
// (each of which picks up i*i = -1).
inline double multiplyThreeCmplx(double xr, double xi, double yr, double yi,
                          double zr, double zi) {
    const double noImag = xr*yr*zr;   // no imaginary factors
    const double imagYZ = xr*yi*zi;   // imaginary parts of y and z
    const double imagXZ = xi*yr*zi;   // imaginary parts of x and z
    const double imagXY = xi*yi*zr;   // imaginary parts of x and y
    return ((noImag - imagYZ) - imagXZ) - imagXY;
}

// Returns sqrt(|x^H * B * x|): the norm of the complex vector x = xr + i*xi
// weighted by the complex conjugate symmetric (Hermitian) nRows-by-nRows
// matrix B = Br + i*Bi.
//
//   Br, Bi - real and imaginary parts of B in column-major order, i.e.
//            B(k,n) is stored at index n*nRows + k. Bi may be NULL, in which
//            case B is treated as real.
//   xr, xi - real and imaginary parts of x, each with nRows entries.
//   nRows  - dimension of B and x.
//
// Only the diagonal and the strictly upper triangle of B are read; the lower
// triangle is implied by conjugate symmetry, B(n,k) = conj(B(k,n)).
inline double conjSymMatrixMult(const double Br[], const double Bi[],
                                const double xr[], const double xi[],
                                int nRows) {
    double result = 0;

    // Diagonal terms B(k,k)*|x_k|^2. The diagonal of a Hermitian matrix is
    // real, so Bi is not needed here.
    for (int k = 0; k < nRows; k++) {
        result += Br[k*nRows+k]*(xr[k]*xr[k]+xi[k]*xi[k]);
    }

    // Off-diagonal terms: conj(x_k)*B(k,n)*x_n and conj(x_n)*B(n,k)*x_k are
    // complex conjugates of each other, so each pair with k < n contributes
    // 2*Re(B(k,n)*conj(x_k)*x_n).
    if (Bi != NULL) {
        for (int k = 0; k < nRows; k++) {
            for (int n = k+1; n < nRows; n++) {
                result += 2*multiplyThreeCmplx(Br[n*nRows+k], Bi[n*nRows+k],
                                               xr[k], -xi[k], xr[n], xi[n]);
            }
        }
    } else {
        // Real B: Re(conj(x_k)*x_n) = xr_k*xr_n + xi_k*xi_n.
        for (int k = 0; k < nRows; k++) {
            for (int n = k+1; n < nRows; n++) {
                result += 2*Br[n*nRows+k]*(xr[k]*xr[n]+xi[k]*xi[n]);
            }
        }
    }

    // Taking the absolute value is strictly speaking not needed since the
    // matrix is positive definite; it only keeps rounding noise from feeding
    // a negative value to sqrt. fabs (declared by <math.h>, which this file
    // includes) is used instead of std::abs, whose double overload is only
    // guaranteed by <cmath>.
    return sqrt(fabs(result));
}

/** 
 *  Gateway function 
 *
 *  Inputs:
 *	prhs[0] - a cell array where cell (k,m) is weighting matrix for 
 *            the k-th frequency index and m-th slope index.
 *	prhs[1] - a cell array with correlations for colored noise weighted and
 *            amplitude modulated chirplets for a particular dyadic 
 *            interval. The n-th entry corresponds to correlations for the 
 *            n-th polynomial in the varying amplitude basis. 
 *  prhs[2] - number of frequency indices
 *	prhs[3] - number of slope indices
 *	prhs[4] - dimension of the varying amplitude basis 
 *
 *  Outputs:
 *	plhs[0] - chirplet coefficients at dyadic interval the chirplet 
 *            correlations in prhs[1] correspond to	
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs,
                 const mxArray *prhs[]) 
{ 
    double *resultMatrix;   // matrix for storing final chirplet coeff.
    int nFreqs;             // number of frequencies
    int nSlopes;            // number of frequencies
    int nRows;              // dimension of varying amplitude basis,
                            // equal to the number of rows in the
                            // weighting matrix
    int fmin;
    int fmax;
    int k,n,m;
    
    // pointer to cell array for weight matrices and chirplet coeff
    const mxArray *BcellElementPtr;
    const mxArray *xcellElementPtr;
    int weightMatrixIndex;
    int cCoeffTableIndex;
    double *Br;
    double *Bi;
    double *coeffr;
    double *coeffi;
    int startFreq;
    
    // read input variables
    int N = (int) mxGetScalar(prhs[2]);
    nSlopes = (int) mxGetScalar(prhs[3]);
    nRows = (int) mxGetScalar(prhs[4]);
    fmin = (int) mxGetScalar(prhs[5]);
    fmax = (int) mxGetScalar(prhs[6]);
    nFreqs = fmax-fmin+1;
    
    // initialize x-vector
    std::valarray<double> xr(nRows);
    std::valarray<double> xi(nRows);
    
    // prepare output
    plhs[0] = mxCreateDoubleMatrix(nFreqs,nSlopes,mxREAL);
    resultMatrix = mxGetPr(plhs[0]);
    
    // loop over frequencies
    for (k=0; k<nFreqs; k++) {
        // loop over slopes
        for (n=0; n<nSlopes; n++) {
            // get the weighting matrix
            weightMatrixIndex = n*nFreqs + k;
            cCoeffTableIndex = n*N + (k+fmin);
            BcellElementPtr = mxGetCell(prhs[0], weightMatrixIndex);
            Br = mxGetPr(BcellElementPtr);
            Bi = mxGetPi(BcellElementPtr);
            
            // fill in the x-vector
            for (m=0; m<nRows; m++) {
                xcellElementPtr = mxGetCell(prhs[1], m);
                coeffr = mxGetPr(xcellElementPtr);
                coeffi = mxGetPi(xcellElementPtr);
                xr[m] = coeffr[cCoeffTableIndex];
                if (coeffi!=NULL) {
                    xi[m] = coeffi[cCoeffTableIndex];
                } else {
                    xi[m]=0;
                }
            }
            // do matrix multiplication
            resultMatrix[weightMatrixIndex] = 
                                   conjSymMatrixMult(Br,Bi,&xr[0],&xi[0],nRows);
        }
    }

}
