/*
 * Copyright (c) 2020 Wellcome Centre for Human Neuroimaging
 * John Ashburner, Mikael Brudfors & Yael Balbastre
 * $Id: gmmlib.c 8065 2021-02-15 12:42:04Z john $
 */

#include "spm_mex.h"
/* NOTE(review): the header names of the three directives below are missing in
 * this copy of the file. They are reproduced as found and must be restored
 * from the upstream source before the file can build. */
#include #include #include
#define EXP(x) fastexp(x)

/* Three pointers named s2, s1 and s0.
 * NOTE(review): presumably second, first and zeroth order sufficient
 * statistics - confirm against the code that fills them. */
typedef struct
{
    double *s2;
    double *s1;
    double *s0;
} SStype;

typedef struct
{
    mwSize  P;
    double *mu;
    double *b;
    double *W;
    double *nu;
    double *gam;
    double *conN;
    double *conT;
} GMMtype;

typedef struct
{
    mwSize  Po;
    mwSize  Pm;
    unsigned char *observed;
    double *Wt;
    double *L_mm;
    double *m_m;
    double *m_o;
    unsigned char *obs;
} MissInfType;

static const double pi = 3.1415926535897931;

#define MaxChan ((mwSize)50)

/* largest integer valued float is 2^52 */
#define Undefined ((mwSize)0xFFFFFFFFFFFFF)

/* A (hopefully) faster approximation to exp.
 * Note that some precision is lost for values
 * of x further from integers.
 *
 * exp(i+r) = exp(i)*exp(r), where:
 *     exp(i) (where i is an integer) is from the lookup table;
 *     exp(r) (residual) is from a generalised continued fraction
 *     https://en.wikipedia.org/wiki/Exponential_function#Continued_fractions_for_ex
 *
 * Should not encounter values more extreme than -128 or 127,
 * particularly as the upper limit of x will be 0 and values
 * of x below log(eps)=-36.04 should be numerically equivalent.
 *
 * x - Exponent.
 *
 * Returns the approximation to exp(x). The integer part of x is clamped to
 * [-128,127]. NaN is propagated and -Inf gives exactly 0.
 */
static double fastexp(double x)
{
    /* exp(i) for i in -128..127, filled in lazily (0.0 marks an empty slot).
     * exp_lkp points at the middle of the storage so it can be indexed
     * directly with negative i. */
    static double lkp_mem[256], *exp_lkp = lkp_mem+128;
    double xi, r, rr;
    mwSignedIndex i;

    /* Without this guard the residual r would be -Inf and the
     * continued fraction below would evaluate to NaN. */
    if (x == -HUGE_VAL) return 0.0;

    /* Clamp while still in floating point: converting a double that is NaN
     * or outside the range of the integer type is undefined behaviour.
     * The first test is written so that it is also true for NaN. */
    xi = rint(x);
    if (!(xi >= -128.0)) xi = -128.0;
    if (xi > 127.0)      xi =  127.0;
    i  = (mwSignedIndex)xi;

    if (exp_lkp[i]==0.0) exp_lkp[i] = exp((double)i);

    r  = x - (double)i;
    rr = r*r;

    /* Deeper truncations of the continued fraction, kept for reference:
     * return exp_lkp[i] * (1.0+2.0*r/(2.0-r+rr/(6.0+rr/(10.0+rr/14.0))));
     * return exp_lkp[i] * (1.0+2.0*r/(2.0-r+rr/(6.0+rr/(10.0))));
     */
    return exp_lkp[i] * (1.0+2.0*r/(2.0-r+rr/6.0));
}

/* Extract one bit of a bit-coded pattern.
 *
 * code - Bit pattern.
 * i    - Bit position (0 is the least significant bit).
 *
 * Returns 1 if bit i of code is set and 0 otherwise.
 */
static mwSize is_observed(mwSize code, mwSize i)
{
    return (code>>i) & (mwSize)1;
}

/* NOTE(review): everything on the following line is truncated in this copy of
 * the file (text between a '<' and the next '>' appears to have been removed).
 * It is reproduced verbatim and must be restored from the upstream source. */
static mwSize num_observed(mwSize code, mwSize P) { mwSize i, Po; for(i=0, Po=0; imx) mx = q[k]; } for(k=0, s=0.0; kmx) mx = q[k]; for(k=0, s=0.0; kmx) mx = q[k]; for(k=0, 
s=EXP(-mx); k=0; k--) sm -= a[i*n+k] * a[j*n+k]; if(i==j) { if(sm <= sm0) sm = sm0; p[i] = sqrt(sm); } else a[j*n+i] = sm / p[i]; } } } /* Solve a least squares problem with the results from a * Cholesky decomposition * * n - Dimension of matrix and data. * a & p - Cholesky decomposed matrix. * b - Vector of input data. * x - Vector or outputs. */ static void cholls(mwSize n, const double a[], const double p[], const double b[], /*@out@*/ double x[]) { mwSignedIndex i, k; double sm; for(i=0; i<(mwSignedIndex)n; i++) { sm = b[i]; for(k=i-1; k>=0; k--) sm -= a[i*n+k]*x[k]; x[i] = sm/p[i]; } for(i=(mwSignedIndex)n-1; i>=0; i--) { sm = x[i]; for(k=i+1; k<(mwSignedIndex)n; k++) sm -= a[k*n+i]*x[k]; x[i] = sm/p[i]; } } /* n! */ static mwSize factorial(mwSize n) { static mwSize products[21]; if (products[0]==0) { mwSize i; products[0] = 1; for(i=1; i<21; i++) products[i] = products[i-1]*i; } return products[n]; } /* Compute space required for storing sufficient statistics. * * P - Number of image volumes. * K - Number of tissue classes. * *m0, *m1 & *m2 - Space needed for the zeroeth, * first and second moments. */ void space_needed(mwSize P, mwSize K, mwSize *m0, mwSize *m1, mwSize *m2) { mwSize m; for(m=0, *m0=0, *m1=0, *m2=0; m<=P; m++) { mwSize nel; nel = K*factorial(P)/(factorial(m)*factorial(P - m)); *m0 += nel; *m1 += nel*m; *m2 += nel*m*m; } } /* Allocate memory for a data structure for representing * GMMs with missing data * * P - Number of images/channels. 
* K - Number of Gaussians */ static /*@null@*/ GMMtype *allocate_gmm(mwSize P, mwSize K) { mwSize o, code, i, n0=0,n1=0,n2=0; double *buf; unsigned char *bytes; GMMtype /*@NULL@*/ *gmm; space_needed(P, K, &n0, &n1, &n2); o = ((mwSize)1<nf[2]) n2 = nf[2]; n1 = nm[1]/skip[1]; if (n1>nf[1]) n1 = nf[1]; n0 = nm[0]/skip[0]; if (n0>nf[0]) n0 = nf[0]; for(i2=0; i20 && logpriors(Nm, lp+im, K, lkp, p)!=0) { mwSize j, j1, k, Po; double *s0, *s1, *s2; Nloglikelihoods(K, gmm, code, mx, vx, p); if (label!=NULL) { mwSize labi = (mwSize)(label[im]); Dloglikelihoods(labi, K, lnP, p); ll += softmax1(K,p,p); for(k=0; k=MaxChan || K>=128) return NAN; if ((gmm = sub_gmm(P, K, mu, b, W, nu, gam))==NULL) return NAN; if ((suffstat = suffstat_pointers(P, K, s0_ptr, s1_ptr, s2_ptr)) == NULL) { (void)free((void *)gmm); return NAN; } ll = suffstats_missing(nf, mf, vf, label, K, gmm, lnP, nm, skip, lkp, lp, suffstat, H); (void)free((void *)gmm); (void)free((void *)suffstat); return ll; } /* Compute responsibilities in a way that handles missing data. * Responsibilities used for fitting the GMM are constructed * from a VB GMM, whereas those not used ar constructed from * a VB mixture of Student's T distributions. * * nf - Vector of dimensions (n_x, n_y, n_z, P). * mf - E[f], dimensions nf. * vf - Var[f], dimensions nf. * gmm - Gaussian mixture model data structure. * nm - Dimensions of log tissue priors (4 elements). * skip - Sampling density for GMM vs TMM (in x, y and z). * lkp - Lookup table relating Gaussians to tissue classes. * lp - Log tissue priors * r - Responsibilities (n_x, n_y, n_z, max(lkp)). 
*/ static double responsibilities(mwSize nf[], mwSize skip[], float mf[], float vf[], unsigned char label[], mwSize K, GMMtype *gmm, double lnP[], mwSize K1, mwSize lkp[], float lp[], float r[]) { mwSize P, N1, i0,i1,i2; double ll = 0.0, mx[MaxChan], vx[MaxChan], p[128]; P = nf[3]; N1 = nf[0]*nf[1]*nf[2]; for(i2=0; i2=MaxChan || K>=128) return NAN; if ((gmm = sub_gmm(P, K, mu, b, W, nu, gam))==NULL) return NAN; ll = responsibilities(nf, skip, mf, vf, label, K, gmm, lnP, K1, lkp, lp, r); (void)free((void *)gmm); return ll; } /* Gradient and Hessian for INU updates * * The computations (two channels only) can be checked with % Some MATLAB Symbolic Toolbox working... syms w_11 w_12 w_22 mu_1 mu_2 x_1 x_2 b_1 b_2 mx_1 mx_2 real syms vx_1 vx_2 positive W = [w_11 w_12; w_12 w_22]; % Precision of Gaussian mu = [mu_1; mu_2]; % Mean of Gaussian x = [x_1; x_2]; mx = [mx_1; mx_2]; % E[x] B = diag([b_1; 0]); % INU as a funciton of b_1 % Objective function for a single Gaussian. Extending to more is trivial. E0 = (x-expm(-B)*mu)'*(expm(B)'*W*expm(B))*(x-expm(-B)*mu)/2 - log(det(expm(B)'*W*expm(B)))/2; % The above objective function is equivalent to: E = (expm(B)*x-mu)'*W*(expm(B)*x-mu)/2 - log(det(expm(B)'*W*expm(B)))/2; % We're using a VB approach with x ~ N(mx,diag(vx)), so compute the expected E. pdf1 = sym('1/sqrt(2*pi*vx_1)*exp(-(x_1-mx_1)^2/(2*vx_1))'); % x_1 ~ N(mx_1,vx_1) pdf2 = sym('1/sqrt(2*pi*vx_2)*exp(-(x_2-mx_2)^2/(2*vx_2))'); % x_2 ~ N(mx_2,vx_2) E = simplify(int(int(E*pdf1*pdf2,x_1,-Inf,Inf),x_2,-Inf,Inf),1000) % Expectation (takes a while) % We now assume mx is the expectation of the INU corrected image according to the % old parameters and vx is the expected variance. Because % exp(b+b_old)*x = exp(b)*exp(b_old)*x, we can now assume our initial estimates for % b are zero and treat exp(b_old)*x as x. 
% A quadratic approximation (around b=0) is obtained by:
E0     = simplify(subs(E,b_1,0),1000);
G0     = simplify(subs(diff(E,b_1),b_1,0),1000); % Gradient
H0     = simplify(subs(diff(diff(E,b_1),b_1),b_1,0),1000); % Hessian
E_quad = E0 + b_1*G0 + b_1^2*H0/2; % Local quadratic approximation

% Gradients (g) to use:
g = W(1,1)*vx_1 + mx_1*W(1,:)*(mx-mu) - 1
if simplify(G0 - g)~=0, disp('There''s a problem.'); end

% Hessian approximation (H) to use:
fprintf('g>0: '); H = W(1,1)*(mx_1^2+vx_1) + 1 + g
if simplify(H-H0)~=0, disp('There''s a problem.'); end
fprintf('g<0: '); H = W(1,1)*(mx_1^2+vx_1) + 1
Ha = simplify(subs(H0,mx_2,solve(G0==0,mx_2)),1000);
% Check the workings
if simplify(H+g-H0)~=0, disp('There''s a problem.'); end
 *
 * nf   - Vector of dimensions (n_x, n_y, n_z, P).
 * mf   - E[f], dimensions nf.
 * vf   - Var[f], dimensions nf.
 * gmm  - Gaussian mixture model data structure.
 * nm   - Dimensions of log tissue priors (4 elements).
 * skip - Sampling density for log tissue priors (in x, y and z).
 * lkp  - Lookup table relating Gaussians to tissue classes.
 * lp   - Log tissue priors.
 * g1   - Output gradients (n_x, n_y, n_z).
 * g2   - Output Hessian (n_x, n_y, n_z).
* */ static double INUgrads(mwSize nf[], float mf[], float vf[], unsigned char label[], mwSize K, GMMtype gmm[], double lnP[], mwSize nm[], mwSize skip[], mwSize lkp[], float lp[], mwSize index[], float g1[], float g2[]) { mwSize P, Nf, Nm, i0,i1,i2, n0,n1,n2; double ll=0.0, mx[MaxChan], vx[MaxChan], p[128]; P = nf[3]; Nf = nf[0]*nf[1]*nf[2]; Nm = nm[0]*nm[1]*nm[2]; n2 = nm[2]/skip[2]; if (n2>nf[2]) n2 = nf[2]; n1 = nm[1]/skip[1]; if (n1>nf[1]) n1 = nf[1]; n0 = nm[0]/skip[0]; if (n0>nf[0]) n0 = nf[0]; if (P>=MaxChan || K>=128) return NAN; for(i2=0; i2=MaxChan || K>=128) return NAN; if ((gmm = sub_gmm(P, K, mu, b, W, nu, gam))==NULL) return NAN; index = (mwSize *)calloc((size_t)1<<(size_t)P, sizeof(mwSize)); if (index == NULL) { (void)free((void *)gmm); return NAN; } make_index(P, ic, index); ll = INUgrads(nf, mf, vf, label, K, gmm, lnP, nm, skip, lkp, lp, index, g1, g2); (void)free((void *)gmm); (void)free((void *)index); return ll; } /* Fill in missing data, replacing NaNs with expectations * * nf - Vector of dimensions (n_x, n_y, n_z, P). * mf - E[f], dimensions nf. * vf - Var[f], dimensions nf. * gmm - Gaussian mixture model data structure. * nm - Dimensions of log tissue priors (4 elements). * skip - Sampling density for log tissue priors (in x, y and z). * lkp - Lookup table relating Gaussians to tissue classes. * lp - Log tissue priors * mx1 - Output data (n_x, n_y, n_z, P). 
*/ static int fill_missing(mwSize nf[], float mf[], float vf[], unsigned char label[], mwSize K, GMMtype gmm[], MissInfType missinf[],double lnP[], mwSize nm[], mwSize skip[], mwSize lkp[], float lp[], float mx1[]) { mwSize P, Nf, Nm, i0,i1,i2, n0,n1,n2; double mx[MaxChan], vx[MaxChan], p[128]; P = nf[3]; Nf = nf[0]*nf[1]*nf[2]; Nm = nm[0]*nm[1]*nm[2]; n2 = nm[2]/skip[2]; if (n2>nf[2]) n2 = nf[2]; n1 = nm[1]/skip[1]; if (n1>nf[1]) n1 = nf[1]; n0 = nm[0]/skip[0]; if (n0>nf[0]) n0 = nf[0]; if (P>=MaxChan || K>=128) return -1; for(i2=0; i2=MaxChan || K>=128) return -1; if ((gmm = sub_gmm(P, K, mu, b, W, nu, gam))==NULL) return -1; if ((missinf = prepare_missinf(P, K, W, mu))==NULL) { (void)free((void *)gmm); return -1; } sts = fill_missing(nf, mf, vf, label, K, gmm, missinf, lnP, nm, skip, lkp, lp, mx1); (void)free((void *)gmm); (void)free((void *)missinf); return sts; }