/* $Id: shoot_dartel.c 4875 2012-08-30 20:04:30Z john $ */ /* (c) John Ashburner (2011) */ #include #include #include #include "shoot_optim3d.h" #include "shoot_diffeo3d.h" #include "shoot_regularisers.h" #include "shoot_boundary.h" extern double log(double x); extern double exp(double x); #define LOG(x) (((x)>0) ? log(x+0.001): -6.9078) /* * In place Cholesky decomposition */ void chol3(mwSize m, float A[]) { float *p00 = A, *p11 = A+m, *p22 = A+m*2, *p01 = A+m*3, *p02 = A+m*4, *p12 = A+m*5; double a00, a11, a22, a01, a02, a12; double s; mwSignedIndex i; for(i=0; i1) { return(smalldef_objfun2(dm, f, g, v, jd, sc, b, A)); } j = 0; for(j2=0; j21) { return(initialise_objfun2(dm, f, g, t0, J0, jd, b, A)); } for(j=0; j0) { m1 = 30*m; if (code==1) m1 += 9*m; m2 = 9*m+fmg3_scratchsize(dm,1); if (m1>m2) return(m1); else return(m2); } else { m1 = 9*m; if (code==1) m1 += 6*m; m2 = 9*m + fmg3_scratchsize(dm,1); if (m1>m2) return(m1); else return(m2); } } void iteration(mwSize dm[], int k, float v[], float g[], float f[], float jd[], double param0[], double lmreg0, int cycles, int its, int code, float ov[], double ll[], float *buf) { float *sbuf; float *b, *A; double ssl, ssp, sc; static double param[] = {1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0}; mwSignedIndex m = dm[0]*dm[1]*dm[2]; mwSignedIndex j; /* Allocate memory. 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 [ A A A A A A t t t J J J J J J J J J t t t J J J J J J J J J] for computing derivatives */ b = ov; A = buf; sbuf = buf + 6*m; if(k>0) { float *t0, *t1, *J0, *J1; t0 = buf + 6*m; J0 = buf + 9*m; t1 = buf + 18*m; J1 = buf + 21*m; sc = 1.0/pow2(k); expdef(dm, k, 1.0, v, t0, t1, J0, J1); jac_div_smalldef(dm, sc, v, J0); if (code==2) ssl = initialise_objfun_mn(dm, f, g, t0, J0, jd, b, A); else ssl = initialise_objfun(dm, f, g, t0, J0, jd, b, A); smalldef_jac(dm, -sc, v, t0, J0); squaring(dm, k, code==1, b, A, t0, t1, J0, J1); if (code==1) { float *b1, *A1; A1 = buf + 30*m; b1 = buf + 36*m; jac_div_smalldef(dm, -sc, v, J0); ssl += initialise_objfun(dm, g, f, t0, J0, (float *)0, b1, A1); smalldef_jac(dm, sc, v, t0, J0); squaring(dm, k, 0, b1, A1, t0, t1, J0, J1); for(j=0; j0.0) param[3] = param[3] + lmreg0; fmg3(dm, A, b, param, cycles, its, sbuf, sbuf+3*m); for(j=0; j2) mexErrMsgTxt("Incorrect usage"); for(i=0; i<3; i++) if (!mxIsNumeric(prhs[i]) || mxIsComplex(prhs[i]) || mxIsSparse(prhs[i]) || !mxIsSingle(prhs[i])) mexErrMsgTxt("Data must be numeric, real, full and single"); if (!mxIsNumeric(prhs[3]) || mxIsComplex(prhs[3]) || mxIsSparse(prhs[3]) || !mxIsDouble(prhs[3])) mexErrMsgTxt("Data must be numeric, real, full and double"); if (mxGetNumberOfDimensions(prhs[0])!=4) mexErrMsgTxt("Wrong number of dimensions."); if (mxGetNumberOfDimensions(prhs[1])>4) mexErrMsgTxt("Wrong number of dimensions."); if (mxGetNumberOfDimensions(prhs[2])!=mxGetNumberOfDimensions(prhs[1])) mexErrMsgTxt("Incompatible number of dimensions."); dm[0] = mxGetDimensions(prhs[0])[0]; dm[1] = mxGetDimensions(prhs[0])[1]; dm[2] = mxGetDimensions(prhs[0])[2]; dm[3] = mxGetDimensions(prhs[0])[3]; if (dm[3]!=3) mexErrMsgTxt("4th dimension of 1st arg must be 3."); if (mxGetDimensions(prhs[1])[0] != dm[0]) mexErrMsgTxt("Incompatible 1st dimension."); if (mxGetDimensions(prhs[1])[1] != dm[1]) mexErrMsgTxt("Incompatible 2nd dimension."); if (mxGetNumberOfDimensions(prhs[1])>=3 && mxGetDimensions(prhs[1])[2] != dm[2]) mexErrMsgTxt("Incompatible 3rd dimension."); if (mxGetDimensions(prhs[2])[0] != dm[0]) mexErrMsgTxt("Incompatible 1st dimension."); if (mxGetDimensions(prhs[2])[1] != dm[1]) mexErrMsgTxt("Incompatible 2nd dimension."); if (mxGetNumberOfDimensions(prhs[2])>=3 && mxGetDimensions(prhs[2])[2] != dm[2]) mexErrMsgTxt("Incompatible 3rd dimension."); if (nrhs>=5) { if (!mxIsNumeric(prhs[4]) || mxIsComplex(prhs[4]) || mxIsSparse(prhs[4]) || !mxIsSingle(prhs[4])) mexErrMsgTxt("Data must be numeric, real, full and single"); if (mxGetNumberOfDimensions(prhs[4])!=3) mexErrMsgTxt("Wrong number of dimensions."); if (mxGetDimensions(prhs[4])[0] != dm[0]) mexErrMsgTxt("Incompatible 1st dimension."); if (mxGetDimensions(prhs[4])[1] != dm[1]) mexErrMsgTxt("Incompatible 2nd dimension."); if (mxGetDimensions(prhs[4])[2] != dm[2]) mexErrMsgTxt("Incompatible 3rd dimension."); jd = (float *)mxGetPr(prhs[4]); } if (mxGetNumberOfElements(prhs[3]) >10) mexErrMsgTxt("Fourth argument should contain param1, param2, param3, param4, param5, LMreg, ncycles, nits, nsamps and code."); if (mxGetNumberOfElements(prhs[3]) >=1) param[3] = mxGetPr(prhs[3])[0]; if (mxGetNumberOfElements(prhs[3]) >=2) param[4] = mxGetPr(prhs[3])[1]; if (mxGetNumberOfElements(prhs[3]) >=3) param[5] = mxGetPr(prhs[3])[2]; if (mxGetNumberOfElements(prhs[3]) >=4) param[6] = mxGetPr(prhs[3])[3]; if (mxGetNumberOfElements(prhs[3]) >=5) param[7] = mxGetPr(prhs[3])[4]; if (mxGetNumberOfElements(prhs[3]) >=6) lmreg0 = mxGetPr(prhs[3])[5]; if (mxGetNumberOfElements(prhs[3]) >=7) cycles = mxGetPr(prhs[3])[6]; if (mxGetNumberOfElements(prhs[3]) >=8) its = mxGetPr(prhs[3])[7]; if (mxGetNumberOfElements(prhs[3]) >=9) k = mxGetPr(prhs[3])[8]; if (mxGetNumberOfElements(prhs[3]) >=10) code = mxGetPr(prhs[3])[9]; plhs[0] = mxCreateNumericArray(4,dm, mxSINGLE_CLASS, mxREAL); plhs[1] = mxCreateNumericArray(2,nll, mxDOUBLE_CLASS, mxREAL); v = (float *)mxGetPr(prhs[0]); g = (float *)mxGetPr(prhs[1]); f = (float *)mxGetPr(prhs[2]); ov = (float *)mxGetPr(plhs[0]); ll = (double*)mxGetPr(plhs[1]); scratch = (float *)mxCalloc(iteration_scratchsize((mwSize *)dm, code,k),sizeof(float)); dm[3] = 1; if (mxGetNumberOfDimensions(prhs[1])>=4) dm[3] = mxGetDimensions(prhs[1])[3]; /* set_bound(0); */ iteration(dm, k, v, g, f, jd, param, lmreg0, cycles, its, code, ov, ll, scratch); mxFree((void *)scratch); } void exp_mexFunction(mwSize nlhs, mxArray *plhs[], mwSize nrhs, const mxArray *prhs[]) { int k=6; mwSize nd; const mwSize *dm; float *v, *t, *t1; double sc = 1.0; int flg = 0; if (((nrhs != 1) && (nrhs != 2)) || (nlhs>2)) mexErrMsgTxt("Incorrect usage."); if (!mxIsNumeric(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]) || !mxIsSingle(prhs[0])) mexErrMsgTxt("Data must be numeric, real, full and single"); nd = mxGetNumberOfDimensions(prhs[0]); if (nd!=4) mexErrMsgTxt("Wrong number of dimensions."); dm = mxGetDimensions(prhs[0]); if (dm[3]!=3) mexErrMsgTxt("4th dimension must be 3."); if (nrhs>1) { if (!mxIsNumeric(prhs[1]) || mxIsComplex(prhs[1]) || mxIsSparse(prhs[1]) || !mxIsDouble(prhs[1])) mexErrMsgTxt("Data must be numeric, real, full and double"); if (mxGetNumberOfElements(prhs[1]) > 3) mexErrMsgTxt("Params must contain one to three elements"); if (mxGetNumberOfElements(prhs[1]) >= 1) k = (int)(mxGetPr(prhs[1])[0]); if (mxGetNumberOfElements(prhs[1]) >= 2) sc = (float)(mxGetPr(prhs[1])[1]); if (mxGetNumberOfElements(prhs[1]) >= 3) flg = (int)(mxGetPr(prhs[1])[2]); } v = (float *)mxGetPr(prhs[0]); plhs[0] = mxCreateNumericArray(nd,dm, mxSINGLE_CLASS, mxREAL); t = (float *)mxGetPr(plhs[0]); t1 = mxCalloc(dm[0]*dm[1]*dm[2]*3,sizeof(float)); /* set_bound(0); */ if (nlhs < 2) { expdef((mwSize *)dm, k, sc, v, t, t1, (float *)0, (float *)0); } else { float *J, *J1; mwSize dmj[5]; dmj[0] = dm[0]; dmj[1] = dm[1]; dmj[2] = dm[2]; if (flg==0) { dmj[3] = 3; dmj[4] = 3; plhs[1] = mxCreateNumericArray(5,dmj, mxSINGLE_CLASS, mxREAL); J = (float *)mxGetPr(plhs[1]); J1 = mxCalloc(dm[0]*dm[1]*dm[2]*3*3,sizeof(float)); expdef((mwSize *)dm, k, sc, v, t, t1, J, J1); } else { plhs[1] = mxCreateNumericArray(3,dmj, mxSINGLE_CLASS, mxREAL); J = (float *)mxGetPr(plhs[1]); J1 = mxCalloc(dm[0]*dm[1]*dm[2],sizeof(float)); expdefdet((mwSize *)dm, k, sc, v, t, t1, J, J1); } mxFree((void *)J1); } unwrap((mwSize *)dm, t); mxFree((void *)t1); }