/****************************************************************************************
 *
 * Author:		Andrew I. Hanna
 * Date:		25/02/2006
 * Function:	MEX function that warps an RGB image with a thin plate spline
 *
 ****************************************************************************************/


/*#include <stdio.h>*/
#include <math.h>
#include "mex.h"


// Constants ****************************************************************************
#define BMP_WIDTH	1024
#define BMP_HEIGHT	1024
// Replacement for an exactly-zero pivot in ludcmp(). No trailing semicolon, so the
// macro can be used inside any expression.
#define TINY 1.0e-20

// Function declarations ****************************************************************
// Conventions used below (see the definitions): float ** matrices come from dmatrix(),
// i.e. a table of row pointers into one contiguous row-major data block; flat float *
// point lists are read as interleaved (x, y) pairs, pts[2*i] and pts[2*i+1].
// Input handling
float *GetInputImage(const mxArray *, int *, int *);
void init_image(mxArray *);
// TPS parameter fitting
void pts2TPS_param(float *, float *, int , float **, float **, float *);
void K_mat(float *, int , float *, float *);
void P_mat(float *, float *, int);
void L_mat(float *K, float *P, float **L, int Nvert);
void Y_mat(float *ipts, float **Y, int Nvert);
// Linear algebra helpers (LU decomposition / inverse / product)
void _minverse(float **m, float **invm, int *indx, int N, float *col);
void lubksb(float **a, int n, int *indx, float b[]);
void ludcmp(float **a, int n, int *indx, float *d);
void _mmult(float **mat1, float **mat2, float **z, int l, int m, int n);
void get_w(float **, float **, int );
void get_a(float **, float **, int );
// Radial basis evaluation
void matdistance(float *, float *, int, int, float *);
void U_rbs(float *r, int M, int N, float *);
// Debug printing
void print_mat(float *mat, int M, int N);
void print2dmatrix(float **A, int M, int N);
// Allocation and warping
float **dmatrix(long nrow, long ncol);
void max_dims(float *pts, int *w, int *h, int N);
void get_channel(float *I, float **C, int channel, int w, int h);
void psi_tps(float **M, float *U, float **a, float **w, float *pts, int N, int Nvert, float **R, float *r, float *K, float **Kmat, float **R2);
void warp_channel(float **RI, int irows, int icols, float **RO, int orows, int ocols, float **w, float **a, float *pts, int Nvert, float **M, float *U, float **TPS, float *r, float *K, float **Kmat, float **R2);
void M_mat(int, int, float **);
void U_mat(int, int, float *);
void round_TPS(float **TPS, float *TPSx, float *TPSy, int N);
void interp_pts(float **I, float **O, float **, int N, int r, int width, int height);
void matrix2vector(float **M, float *v, int rows, int cols, int offset);
// **************************************************************************************


/*
 * MEX entry point: warp an RGB image with a thin plate spline.
 *   prhs[0]  input image, single precision, rows x cols x 3
 *   prhs[1]  output (target) control points, single precision
 *   prhs[2]  input (source) control points, single precision, same count as prhs[1]
 *   prhs[3]  regularization term, single precision; K_mat multiplies the kernel by it
 *   plhs[0]  warped image, single precision, 3 planes; its size is taken from the
 *            largest output-control-point coordinates (see max_dims)
 * Control points are read as interleaved (x, y) pairs: pts[2*i], pts[2*i+1].
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    int iwidth, iheight, npts, owidth, oheight, i;
    mwSize dims[3];   /* mxCreateNumericArray expects mwSize, not int */
    float *ivert, *overt, **w, **a, **CI, **CO, *I, *pix, **M, *U, **TPS, *r, *K, **Kmat, **R2;
    float *regterm;
    mxArray *oimg;

    /* prhs[0..3] are all read below, so four inputs are required. */
    if(nrhs < 4)
        mexErrMsgTxt("Must give the input image, the output and input vertices and the regularization term!");

    // Get the texture image **************************************************************
    pix = GetInputImage(prhs[0], &iwidth, &iheight);
    mexPrintf("Input image\n***********\n");
    mexPrintf("Height: %d, Width: %d\n\n", iheight, iwidth);

    // Get regularization term ************************************************************
    /* The data is reinterpreted as float, so only single precision is valid. */
    if(!mxIsSingle(prhs[3]) || mxIsEmpty(prhs[3]))
        mexErrMsgTxt("\nRegularization term must be a non-empty single precision value\n");
    regterm = (float *)mxGetData(prhs[3]);
    mexPrintf("Regularization term: %f\n", *regterm);

    // Check the image size < bitmap
    if(iwidth > BMP_WIDTH || iheight > BMP_HEIGHT)
        mexErrMsgTxt("\nImage is larger than the rendering window\n");

    // Get input vertices *****************************************************************
    if(!mxIsNumeric(prhs[2]) || mxIsEmpty(prhs[2]))
        mexErrMsgTxt("Input vertices must not be empty\n");
    if(!mxIsSingle(prhs[2]) || mxGetNumberOfElements(prhs[2]) % 2 != 0)
        mexErrMsgTxt("Input vertices must be single precision (x, y) pairs\n");
    ivert = (float*)mxGetData(prhs[2]);

    // The number of vertices
    npts = (int)(mxGetNumberOfElements(prhs[2]) / 2);
    mexPrintf("Number of vertices: %d\n", npts);

    // Get output vertices ****************************************************************
    if(!mxIsNumeric(prhs[1]) || mxIsEmpty(prhs[1]))
        mexErrMsgTxt("Output vertices must not be empty\n");
    if(!mxIsSingle(prhs[1]))
        mexErrMsgTxt("Output vertices must be single precision (x, y) pairs\n");
    overt = (float*)mxGetData(prhs[1]);

    /* prhs[2] holds exactly 2*npts values, so equal element counts means equal vertex counts. */
    if(mxGetNumberOfElements(prhs[1]) != mxGetNumberOfElements(prhs[2]))
        mexErrMsgTxt("Number of input and output vertices must match");

    // Setup the return image *************************************************************
    max_dims(overt, &oheight, &owidth, npts);
    mexPrintf("Output image\n***********\n");
    mexPrintf("Height: %d, Width: %d\n\n", oheight, owidth);
    if(owidth < 1 || oheight < 1)
        mexErrMsgTxt("Output vertices must have positive coordinates\n");

    dims[0] = (mwSize)owidth;
    dims[1] = (mwSize)oheight;
    dims[2] = 3;
    if ((oimg = mxCreateNumericArray(3, dims, mxSINGLE_CLASS, mxREAL)) == NULL)
        mexErrMsgTxt("Unable to allocate output matrix\n");

    /* Set up all work buffers once; they are reused for every channel. */
    I = (float*)mxGetData(oimg);
    w = dmatrix(npts,2);
    a = dmatrix(3,2);
    CI = dmatrix(iwidth,iheight);
    CO = dmatrix(owidth, oheight);
    M = dmatrix(owidth, 3);
    U = mxCalloc(2*owidth, sizeof(float));
    TPS = dmatrix(owidth, 2);
    r = mxCalloc(owidth*npts, sizeof(float));
    K = mxCalloc(owidth*npts, sizeof(float));
    Kmat = dmatrix(owidth, npts);
    R2 = dmatrix(owidth, 2);

    pts2TPS_param(overt, ivert, npts, w, a, regterm);
    mexPrintf("height: %d, width: %d\n", iheight, iwidth);
    for(i=0; i<3; i++){
        get_channel(pix, CI, i, iwidth, iheight);
        get_channel(I, CO, i, owidth, oheight);
        warp_channel(CI, iwidth, iheight, CO, owidth, oheight, w, a, overt, npts, M, U, TPS, r, K, Kmat, R2);
        matrix2vector(CO, I, owidth, oheight, i);
    }
    /* I is oimg's own data buffer, so no mxSetData call is needed. */
    plhs[0] = oimg;

    /* Release work buffers. dmatrix() makes two allocations (the row table
     * and the data block at m[0]), so both are freed. */
    mxFree(w[0]);    mxFree(w);
    mxFree(a[0]);    mxFree(a);
    mxFree(CI[0]);   mxFree(CI);
    mxFree(CO[0]);   mxFree(CO);
    mxFree(M[0]);    mxFree(M);
    mxFree(TPS[0]);  mxFree(TPS);
    mxFree(Kmat[0]); mxFree(Kmat);
    mxFree(R2[0]);   mxFree(R2);
    mxFree(U);
    mxFree(r);
    mxFree(K);
}

/*
 * Scatter the rows x cols row-pointer matrix M into plane `offset` of the
 * column-major buffer v, i.e. v[row + col*rows + rows*cols*offset] = M[row][col].
 */
void matrix2vector(float **M, float *v, int rows, int cols, int offset){
    float *plane = v + rows*cols*offset;
    int col, row;
    for(col = 0; col < cols; col++){
        float *dst = plane + col*rows;   /* start of this column in v */
        for(row = 0; row < rows; row++)
            dst[row] = M[row][col];
    }
}

/*
 * Warp one channel: for each output column, build the evaluation points for
 * all orows pixels of that column (1-based column index), map them through the
 * TPS (psi_tps) and sample the input channel RI into column `col` of RO.
 * M, U, TPS, r, K, Kmat and R2 are caller-provided scratch buffers.
 */
void warp_channel(float **RI, int irows, int icols, float **RO, int orows, int ocols, float **w, float **a, float *pts, int Nvert, float **M, float *U, float **TPS, float *r, float *K, float **Kmat, float **R2){
    int col = 0;
    while(col < ocols){
        const int one_based = col + 1;
        M_mat(orows, one_based, M);
        U_mat(orows, one_based, U);
        psi_tps(M, U, a, w, pts, orows, Nvert, TPS, r, K, Kmat, R2);
        interp_pts(RI, RO, TPS, orows, col, irows, icols);
        col++;
    }
}
/*
 * Nearest-pixel sampling: for each of the N mapped points TPS[i] = (x, y),
 * 1-based, copy I[y-1][x-1] into O[i][r]. Points outside 1..height (x) or
 * 1..width (y) leave O[i][r] unchanged.
 *   I       input channel with `width` rows and `height` columns
 *   O       output channel; column r is written
 */
void interp_pts(float **I, float **O, float **TPS, int N, int r, int width, int height){
    int i, xind, yind;
    for(i=0; i<N; i++){
        const float x = TPS[i][0];
        const float y = TPS[i][1];
        /* Range-check in floating point before converting: casting NaN or an
         * out-of-range float to int is undefined behaviour. Accepting
         * [1, n+1) keeps the truncating (int) semantics of the cast, and the
         * upper bound is inclusive of the last row/column (1-based index n),
         * which the previous `<` test wrongly excluded. */
        if (!(x >= 1.0f && (double)x < (double)height + 1.0 &&
              y >= 1.0f && (double)y < (double)width + 1.0))
            continue;
        xind = (int)x;
        yind = (int)y;
        O[i][r] = I[yind-1][xind-1];
    }
}
void round_TPS(float **TPS, float *TPSx, float *TPSy, int N){
    int i;
    for(i=0; i<N; i++){
        TPSx[i] = (float)ceil(TPS[i][0]);
        TPSy[i] = (float)ceil(TPS[i][1]);
    }
}
/* Debug helper: print the M x N row-pointer matrix A, one numbered row per line. */
void print2dmatrix(float **A, int M, int N){
    int row = 0;
    while(row < M){
        int col;
        mexPrintf("\n%d) ", row);
        for(col = 0; col < N; col++)
            mexPrintf("%f ", A[row][col]);
        row++;
    }
    mexPrintf("\n");
}
/*
 * Fill the N x 3 matrix M with rows [1, r, i+1] for i = 0..N-1 (the
 * homogeneous evaluation points of one output column, r being 1-based).
 */
void M_mat(int N, int r, float **M){
    const float colval = (float)r;
    int i;
    for(i = 0; i < N; i++){
        float *row = M[i];
        row[0] = 1.0f;
        row[1] = colval;
        row[2] = (float)i + 1.0f;
    }
}
/* Fill M with N interleaved pairs (r, i+1), i = 0..N-1. */
void U_mat(int N, int r, float *M){
    float *p = M;
    int i;
    for(i = 0; i < N; i++){
        *p++ = (float)r;
        *p++ = (float)i + 1.0f;
    }
}
/*
 * Evaluate the thin plate spline at N points and round up the result.
 *   M      N x 3 rows built by M_mat
 *   U      N interleaved points built by U_mat
 *   a, w   affine (3 x 2) and radial (Nvert x 2) coefficients
 *   pts    Nvert interleaved control points
 *   R      output, N x 2: ceil(M*a + Kmat*w)
 *   r, K, Kmat, R2  caller-provided scratch buffers
 */
void psi_tps(float **M, float *U, float **a, float **w, float *pts, int N, int Nvert, float **R, float *r, float *K, float **Kmat, float **R2){
    int p, q;

    /* Affine part: R = M * a. */
    _mmult(M, a, R, N, 3, 2);

    /* Radial part: kernel values between evaluation and control points,
     * stored row-major in K and copied into the row-pointer matrix Kmat. */
    matdistance(U, pts, N, Nvert, r);
    U_rbs(r, N, Nvert, K);
    for(p = 0; p < N; p++){
        const float *src = K + p*Nvert;
        float *dst = Kmat[p];
        for(q = 0; q < Nvert; q++)
            dst[q] = src[q];
    }
    _mmult(Kmat, w, R2, N, Nvert, 2);

    /* Combine both parts and round up. */
    for(p = 0; p < N; p++){
        float *out = R[p];
        const float *rad = R2[p];
        out[0] = (float)ceil(out[0] + rad[0]);
        out[1] = (float)ceil(out[1] + rad[1]);
    }
}
/*
 * Copy plane `channel` of the column-major (MATLAB-layout) array I, whose
 * planes are rows x cols, into the row-pointer matrix C so that
 * C[row][col] = I[row + col*rows + rows*cols*channel].
 * The data is copied directly; the previous version staged it in an
 * mxMalloc buffer that was never freed.
 */
void get_channel(float *I, float **C, int channel, int rows, int cols){
    const size_t plane = (size_t)rows * (size_t)cols;
    const float *src = I + plane * (size_t)channel;
    int i, ii;
    for(i=0; i<cols; i++){
        for(ii=0; ii<rows; ii++){
            C[ii][i] = src[(size_t)ii + (size_t)i*(size_t)rows];
        }
    }
}
/*
 * Round up the largest x and y over N interleaved points pts[2*i], pts[2*i+1].
 *   w  receives ceil(max x)
 *   h  receives ceil(max y)
 * For N <= 0 both outputs are 0. The previous "-1 sentinel" version returned
 * the last value instead of the maximum whenever the maximum was negative.
 */
void max_dims(float *pts, int *w, int *h, int N){
    float maxx, maxy;
    int i;

    mexPrintf("Function : max_dims(...)\n");
    if(N <= 0){
        *w = 0;
        *h = 0;
        return;
    }
    maxx = pts[0];
    maxy = pts[1];
    for(i=1; i<N; i++){
        if(pts[2*i] > maxx)
            maxx = pts[2*i];
        if(pts[2*i+1] > maxy)
            maxy = pts[2*i+1];
    }
    (*w) = (int)ceil(maxx);
    (*h) = (int)ceil(maxy);
}
/*
 * Fit the thin plate spline mapping the Nvert control points `opts` onto
 * `ipts` (both interleaved (x, y) pairs).
 * Builds L = [K P; P' 0] (K_mat/P_mat/L_mat) and Y = [ipts; 0] (Y_mat),
 * computes Q = inv(L) * Y and splits Q into the radial weights w (Nvert x 2)
 * and the affine part a (3 x 2). regterm is passed to K_mat.
 * Raises a MATLAB error if ludcmp reports L as singular.
 */
void pts2TPS_param(float *opts, float *ipts, int Nvert, float **w, float **a, float *regterm){
    const int n = Nvert + 3;   /* size of the bordered system L */
    float *P, *col, *K;
    float **invL, **Q;
    int *indx;
    float d;
    float **L, **Y;

    K = mxCalloc(Nvert*Nvert, sizeof(float));
    L = dmatrix(n, n);
    Y = dmatrix(n, 2);

    mexPrintf("function : pts2TPS_param\n");
    K_mat(opts, Nvert,  K, regterm);
    P = mxCalloc(Nvert*3, sizeof(float));
    P_mat(opts, P, Nvert);
    L_mat(K, P, L, Nvert);
    Y_mat(ipts, Y, Nvert);
    invL = dmatrix(n, n);
    indx = (int *)mxMalloc(n*sizeof(int));
    col = (float *)mxMalloc(n*sizeof(float));
    ludcmp(L, n, indx, &d);
    /* ludcmp reports a singular matrix by setting d = 0 and returning early;
     * solving with that factorisation would be meaningless. */
    if (d == 0)
        mexErrMsgTxt("Thin plate spline system is singular; check the control points\n");
    _minverse(L, invL, indx, n, col);
    Q = dmatrix(n, 2);
    _mmult(invL, Y, Q, n, n, 2);
    get_w(Q, w, Nvert);
    get_a(Q, a, Nvert);

    mxFree(P);
    mxFree(col);
    mxFree(K);
    mxFree(indx);
    /* dmatrix() makes two allocations: the data block at m[0] and the row table. */
    mxFree(Q[0]);    mxFree(Q);
    mxFree(invL[0]); mxFree(invL);
    mxFree(L[0]);    mxFree(L);
    mxFree(Y[0]);    mxFree(Y);
}
/* Copy the first N rows (both columns) of the solution matrix A into w. */
void get_w(float **A, float **w, int N){
    int k = 0;
    while(k < N){
        w[k][0] = A[k][0];
        w[k][1] = A[k][1];
        ++k;
    }
}
/* Copy rows N..N+2 (both columns) of the solution matrix A into the 3 x 2 matrix a. */
void get_a(float **A, float **a, int N){
    int k;
    for(k = 0; k < 3; k++){
        const float *src = A[N + k];
        a[k][0] = src[0];
        a[k][1] = src[1];
    }
}
/*
 * Build the (Nvert+3) x 2 right-hand side Y: the first Nvert rows are the
 * interleaved points ipts, the last three rows are zero.
 */
void Y_mat(float *ipts, float **Y, int Nvert){
    int row;
    for(row = 0; row < Nvert + 3; row++){
        if(row < Nvert){
            Y[row][0] = ipts[2*row];
            Y[row][1] = ipts[2*row + 1];
        } else {
            Y[row][0] = 0;
            Y[row][1] = 0;
        }
    }
}
/*
 * Assemble the (Nvert+3) x (Nvert+3) TPS system matrix
 *     L = [ K   P ]
 *         [ P'  0 ]
 * K is Nvert x Nvert row-major (K[i*Nvert + j]); P is Nvert x 3 stored
 * column by column (P[i + Nvert*c]).
 */
void L_mat(float *K, float *P, float **L, int Nvert){
    const int n = Nvert + 3;
    int row, col;
    for(row = 0; row < n; row++){
        for(col = 0; col < n; col++){
            float val;
            if(row < Nvert && col < Nvert)
                val = K[row*Nvert + col];
            else if(row < Nvert)
                val = P[row + Nvert*(col - Nvert)];   /* upper-right: P */
            else if(col < Nvert)
                val = P[col + Nvert*(row - Nvert)];   /* lower-left: P' */
            else
                val = 0;
            L[row][col] = val;
        }
    }
}
/*
 * Build the Nvert x 3 matrix P = [1, x, y] stored column by column:
 * P[0..Nvert) = 1, P[Nvert..2*Nvert) = x, P[2*Nvert..3*Nvert) = y,
 * reading x and y from the interleaved points ipts.
 */
void P_mat(float *ipts, float *P, int Nvert){
    float *ones = P;
    float *xs = P + Nvert;
    float *ys = P + 2*Nvert;
    int k;
    for(k = 0; k < Nvert; k++){
        ones[k] = 1;
        xs[k] = ipts[2*k];
        ys[k] = ipts[2*k + 1];
    }
}

/* Debug helper: print the row-major M x N matrix `mat` with a size header. */
void print_mat(float *mat, int M, int N){
    const float *p = mat;
    int row, col;
    mexPrintf("Size: %dx%d\n", M, N);
    for(row = 0; row < M; row++){
        mexPrintf("\n%d) ", row);
        for(col = 0; col < N; col++)
            mexPrintf("%f ", *p++);
    }
    mexPrintf("\n");
}

/*
 * Build the Nvert x Nvert kernel matrix K (row-major, K[i*Nvert + j]):
 * U_rbs applied to the distance between every pair of control points in
 * `opts`, with each entry multiplied by *regterm.
 */
void K_mat(float *opts, int Nvert, float *K, float *regterm){
    float *r;
    int i;
    r = mxCalloc(Nvert*Nvert, sizeof(float));
    matdistance(opts, opts, Nvert, Nvert, r);
    U_rbs(r, Nvert, Nvert, K);
    mxFree(r);   /* distance scratch buffer; previously leaked */
    for(i=0; i<(Nvert*Nvert); i++){
        K[i] = K[i]*(*regterm);
    }
}


void matdistance(float *opts, float *ipts, int Nverto, int Nverti, float *r){
    int i, ii;
    for(i=0; i<Nverto; i++){
        for(ii=0; ii<Nverti; ii++){
            r[ii+Nverti*i] = (float)sqrt((opts[2*i] - ipts[2*ii])*(opts[2*i] - ipts[2*ii]) + (opts[2*i+1] - ipts[2*ii+1])*(opts[2*i+1] - ipts[2*ii+1]));
        }
    }
}

void U_rbs(float *r, int M, int N, float *Ur){
    int i;
    for(i=0; i<M*N; i++){
        Ur[i] = r[i]*r[i];
        if(Ur[i] == 0)
            Ur[i] = 1;
        Ur[i] = Ur[i]*(float)log(Ur[i]);
    }
}

/*
 * Zero the first dim(1) x dim(2) elements (the first plane) of a
 * single-precision mxArray. No caller is visible in this file.
 */
void init_image(mxArray *img){
    /* mxGetDimensions returns mwSize; reading it through an int pointer
     * gives wrong sizes on 64-bit (largeArrayDims) builds. */
    const mwSize *dim = mxGetDimensions(img);
    const size_t count = (size_t)dim[0] * (size_t)dim[1];
    float *I = (float*)mxGetData(img);
    size_t i;

    for (i=0; i<count; i++){
        I[i] = 0;
    }
    /* I is the array's own buffer, so no mxSetData call is needed. */
}



float *GetInputImage(const mxArray *img, int *width, int *height)
{
    int	ndim, nx, ny;
    const int *dim;
    
    if(!mxIsNumeric(img) || mxIsEmpty(img) || !mxIsSingle(img))
        mexErrMsgTxt("Input image must be float 2D or 3D numeric matrix");
    
    ndim = mxGetNumberOfDimensions(img);
    if(ndim != 3)
        mexErrMsgTxt("Input image must be (nx * ny * 3) 24-bit RGB");
    
    dim = mxGetDimensions(img);
    *width = nx = dim[0];
    *height = ny = dim[1];
    
    if(dim[2] > 3)
        mexWarnMsgTxt("Only using first 3 values of 3rd dimension");
    else if(dim[2] < 3)
        mexErrMsgTxt("Third dimension must have size = 3");
    
    return((float*)mxGetData(img));
}

/*
 * Allocate an nrow x ncol float matrix as a table of row pointers into one
 * contiguous row-major data block (m[i] == m[0] + i*ncol).
 * Free with mxFree(m[0]) followed by mxFree(m).
 * On failure or negative sizes a MATLAB error is raised instead of returning
 * a pointer that would be dereferenced; the previous version only printed a
 * message and then wrote through the NULL pointer.
 */
float **dmatrix(long nrow, long ncol){
    long i;
    float **m;
    size_t nptr, nelem;

    if (nrow < 0 || ncol < 0)
        mexErrMsgTxt("error: could not make matrix\n");

    /* Always allocate at least one row pointer and one element so that m and
     * m[0] are valid even for an empty (0 x n or n x 0) matrix. */
    nptr = nrow > 0 ? (size_t)nrow : 1;
    nelem = (nrow > 0 && ncol > 0) ? (size_t)nrow * (size_t)ncol : 1;

    m = (float **)mxMalloc(nptr * sizeof(float *));
    if (!m)
        mexErrMsgTxt("error: could not make matrix\n");
    m[0] = (float *)mxMalloc(nelem * sizeof(float));
    if (!m[0])
        mexErrMsgTxt("error: could not make matrix\n");
    for(i=1; i<nrow; i++)
        m[i] = m[i-1] + ncol;
    return m;
}
/*
 * Invert an N x N matrix from its LU factors (m, indx as left by ludcmp):
 * solve m * x = e_c for every unit vector e_c and store x as column c of invm.
 * col is an N-element scratch vector.
 */
void _minverse(float **m, float **invm, int *indx, int N, float *col){
    int c, k;
    for(c = 0; c < N; c++){
        for(k = 0; k < N; k++)
            col[k] = (k == c) ? 1.0f : 0.0f;
        lubksb(m, N, indx, col);
        for(k = 0; k < N; k++)
            invm[k][c] = col[k];
    }
}
/*
 * Solve A x = b using the in-place LU factors produced by ludcmp()
 * (Numerical Recipes style, 0-based indexing).
 *   a     n x n LU factors from ludcmp(); not modified
 *   n     system size
 *   indx  row permutation recorded by ludcmp()
 *   b     on entry the right-hand side; on exit the solution (overwritten)
 */
void lubksb(float **a, int n, int *indx, float b[])
{
    int i,ii=0,ip,j;
    float sum;
    /* Forward substitution with the unit-diagonal lower factor, applying the
     * row permutation on the fly. ii is 1 + the index of the first non-zero
     * entry seen so far (0 = none yet), so leading zeros are skipped. */
    for (i=0;i<n;i++) {
        ip=indx[i];
        sum=b[ip];
        b[ip]=b[i];
        if (ii)
            for (j=ii-1;j<=i-1;j++) sum -= a[i][j]*b[j];
        else if (sum) ii=i+1;
        b[i]=sum;
    }
    /* Back substitution with the upper factor (diagonal stored in a[i][i]). */
    for (i=n-1;i>=0;i--) {
        sum=b[i];
        for (j=i+1;j<n;j++) sum -= a[i][j]*b[j];
        b[i]=sum/a[i][i];
    }
}

/*
 * In-place LU decomposition with implicit partial pivoting
 * (Numerical Recipes style, 0-based indexing).
 *   a     n x n matrix; overwritten by the lower factor (unit diagonal, stored
 *         below the diagonal) and the upper factor (on and above it)
 *   n     matrix size
 *   indx  receives the row permutation for lubksb()
 *   d     receives +1 or -1 (parity of row interchanges), or 0 when a row of
 *         a is entirely zero; in that case a is left unmodified and indx is
 *         set to the identity permutation
 * Exactly-zero pivots are replaced by TINY so the decomposition continues.
 */
void ludcmp(float **a, int n, int *indx, float *d)
{
    int i,imax,j,k;
    float big,dum,sum,temp;
    float *vv;   /* per-row scale factor: 1 / largest |element| of the row */

    vv = (float *)mxMalloc(n*sizeof(float));

    (*d)=1.0;
    for (i=0;i<n;i++) {
        big=0.0;
        for (j=0;j<n;j++)
            if ((temp=(float)fabs(a[i][j])) > big) big=temp;
        if (big == 0.0){
            /* Singular. Give indx a defined value so a caller that ignores
             * *d does not read uninitialised memory, and release vv. */
            for (k=0;k<n;k++) indx[k]=k;
            *d = 0;
            mxFree(vv);
            return;
        }
        vv[i]=(float)1.0/big;
    }
    for (j=0;j<n;j++) {
        for (i=0;i<j;i++) {
            sum=a[i][j];
            for (k=0;k<i;k++) sum -= a[i][k]*a[k][j];
            a[i][j]=sum;
        }
        big=0.0;
        imax=j;   /* defined even if no candidate compares >= big (NaN input) */
        for (i=j;i<n;i++) {
            sum=a[i][j];
            for (k=0;k<j;k++) sum -= a[i][k]*a[k][j];
            a[i][j]=sum;
            if ( (dum=vv[i]*(float)fabs(sum)) >= big) {
                big=dum;
                imax=i;
            }
        }
        if (j != imax) {
            for (k=0;k<n;k++) {
                dum=a[imax][k];
                a[imax][k]=a[j][k];
                a[j][k]=dum;
            }
            *d = -(*d);
            vv[imax]=vv[j];
        }
        indx[j]=imax;
        if (a[j][j] == 0.0) a[j][j]=(float)TINY;
        if (j != n-1) {
            dum=(float)1.0/(a[j][j]);
            for (i=j+1;i<n;i++) a[i][j] *= dum;
        }
    }
    mxFree(vv);
}
/* Matrix product z = mat1 * mat2, with mat1 l x m, mat2 m x n and z l x n. */
void _mmult(float **mat1, float **mat2, float **z, int l, int m, int n) {
    int row, col, k;
    for(row = 0; row < l; row++){
        float *zrow = z[row];
        const float *arow = mat1[row];
        for(col = 0; col < n; col++){
            zrow[col] = 0;
            for(k = 0; k < m; k++)
                zrow[col] += arow[k]*mat2[k][col];
        }
    }
}

