
prt_machine_gpclap

PURPOSE

Run multiclass Gaussian process classification (Laplace approximation)

SYNOPSIS

function output = prt_machine_gpclap(d,args)

DESCRIPTION

 Run multiclass Gaussian process classification (Laplace approximation)
 FORMAT output = prt_machine_gpclap(d,args)
 Inputs:
   d         - structure with data information, with mandatory fields:
     .train      - training data (cell array of matrices of row vectors,
                   each [Ntr x D]). Each matrix contains one representation
                   of the data. This is useful for approaches such as
                   multiple kernel learning.
     .test       - testing data (cell array of matrices of row vectors,
                   each [Nte x D])
     .testcov    - testing covariance (cell array of matrices of row
                   vectors, each [Nte x Nte])
     .tr_targets - training labels (for classification) or values (for
                   regression) (column vector, [Ntr x 1])
     .use_kernel - flag, is data in form of kernel matrices (true) or in 
                   form of features (false)
    args     - argument string, where
       -h         - optimise hyperparameters (otherwise don't)
       -c covfun  - covariance function:
                       'covLINkcell' - simple dot product
                       'covLINglm'   - construct a GLM
    experimental args (use at your own risk):
       -p         - use priors for the hyperparameters. If specified, this
                    indicates that a maximum a posteriori (MAP) approach
                    will be used to set covariance function
                    hyperparameters. The priors are obtained 
                    by calling prt_gp_priors('covFuncName')

       N.B.: for the arguments specifying functions, pass in a string, not
       a function handle. This script will generate a function handle
 
 Output:
    output  - output of machine (struct).
     * Mandatory fields:
      .predictions - predictions of classification or regression [Nte x 1]
     * Optional fields:
      .type     - which type of machine this is (here, 'classifier')
      .func_val - predictive probabilities
      .loghyper - log hyperparameters
      .nlml     - negative log marginal likelihood
      .mu       - test latent means
      .sigma    - test latent variances
      .alpha    - GP weighting coefficients
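
 Example: a minimal usage sketch. The kernel matrices Ktr, Kte, Ktete and
 the label vectors tr_labels, te_labels are hypothetical placeholders, not
 part of this file; note that d.te_targets is also read by this machine
 even though it is not listed among the mandatory fields:

    d.train      = {Ktr};        % [Ntr x Ntr] training kernel
    d.test       = {Kte};        % [Nte x Ntr] test kernel (rows = test points)
    d.testcov    = {Ktete};      % [Nte x Nte] test covariance
    d.tr_targets = tr_labels;    % [Ntr x 1] integer class labels in 1..k
    d.te_targets = te_labels;    % [Nte x 1] test labels (copied to output)
    d.use_kernel = true;
    output = prt_machine_gpclap(d, '-h -c covLINkcell');
    % output.predictions: predicted class per test point [Nte x 1]
    % output.func_val:    predictive class probabilities [Nte x k]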
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION

This function calls: prt_gp_priors, prt_glm_design, spm_powell
This function is called by:

SUBFUNCTIONS

function E = gp_objfun(logtheta,t,X,covfunc)
function E = gp_objfun_map(logtheta,t,X,covfunc)
function [f,F,a] = gp_lap_multiclass(logtheta,covfunc,X,t,f)
function [p, Mu, SS] = gp_pred_lap_multiclass(logtheta,X,t,covfunc,Xs,Xss,f)

SOURCE CODE

function output = prt_machine_gpclap(d,args)
% Run multiclass Gaussian process classification (Laplace approximation)
% FORMAT output = prt_machine_gpclap(d,args)
% Inputs:
%   d         - structure with data information, with mandatory fields:
%     .train      - training data (cell array of matrices of row vectors,
%                   each [Ntr x D]). Each matrix contains one representation
%                   of the data. This is useful for approaches such as
%                   multiple kernel learning.
%     .test       - testing data (cell array of matrices of row vectors,
%                   each [Nte x D])
%     .testcov    - testing covariance (cell array of matrices of row
%                   vectors, each [Nte x Nte])
%     .tr_targets - training labels (for classification) or values (for
%                   regression) (column vector, [Ntr x 1])
%     .use_kernel - flag, is data in form of kernel matrices (true) or in
%                   form of features (false)
%    args     - argument string, where
%       -h         - optimise hyperparameters (otherwise don't)
%       -c covfun  - covariance function:
%                       'covLINkcell' - simple dot product
%                       'covLINglm'   - construct a GLM
%    experimental args (use at your own risk):
%       -p         - use priors for the hyperparameters. If specified, this
%                    indicates that a maximum a posteriori (MAP) approach
%                    will be used to set covariance function
%                    hyperparameters. The priors are obtained
%                    by calling prt_gp_priors('covFuncName')
%
%       N.B.: for the arguments specifying functions, pass in a string, not
%       a function handle. This script will generate a function handle
%
% Output:
%    output  - output of machine (struct).
%     * Mandatory fields:
%      .predictions - predictions of classification or regression [Nte x 1]
%     * Optional fields:
%      .type     - which type of machine this is (here, 'classifier')
%      .func_val - predictive probabilities
%      .loghyper - log hyperparameters
%      .nlml     - negative log marginal likelihood
%      .mu       - test latent means
%      .sigma    - test latent variances
%      .alpha    - GP weighting coefficients
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by J Ashburner and A Marquand
% $Id: prt_machine_gpclap.m 498 2012-04-05 13:26:23Z amarquan $

% Error checks
% -------------------------------------------------------------------------
SANITYCHECK=true; % can turn off for "speed". Expert only.

if SANITYCHECK==true
    % args should be a string (empty or otherwise)
    if ~ischar(args)
        error('prt_machine_gpclap:argsNotString',['Error: gpml'...
            ' args should be a string. ' ...
            ' SOLUTION: Please pass the machine arguments as a string']);
    end
    % are we using a kernel ?
    if ~d.use_kernel
        error('prt_machine_gpclap:useKernelIsFalse',['Error:'...
            ' This machine is currently only implemented for kernel data. ' ...
            'SOLUTION: Please set use_kernel to true']);
    end
end

% configure default parameters for GP optimisation
covfunc   = @covLINkcell;
mode      = 'classifier';

% parse input arguments
% -------------------------------------------------------------------------
% hyperparameters
if ~isempty(regexp(args,'-h','once'))
    optimise_theta = true;
    % an optional '-f <n>' flag is parsed here, but maxeval is not used
    % anywhere below
    eargs = regexp(args,'-f\s+[0-9]*','match');
    if ~isempty(eargs)
        eargs = regexp(cell2mat(eargs),'-f\s+','split');
        maxeval  = str2num(['-',cell2mat(eargs(2))]);
    end
else
    optimise_theta = false;
end
% covariance function
cargs = regexp(args,'-c\s+[a-zA-Z0-9_]*','match');
if ~isempty(cargs)
    cargs = regexp(cell2mat(cargs),'-c\s+','split');
    covfunc = str2func(cell2mat(cargs(2)));
end
% priors
if ~isempty(regexp(args,'-p','once'))
    disp('Empirical priors specified. Using MAP for hyperparameters')
    priors = prt_gp_priors(func2str(covfunc));
    map = true;
else
    map = false;
end

% Set default hyperparameters and objective function
% -------------------------------------------------------------------------
% (gpml-style convention: calling the covariance function with no
% arguments returns the number of hyperparameters it expects, as a string)
nhyp = str2num(feval(covfunc));
if nhyp(1) > 0
    hyp = zeros(nhyp(1),1);
else
    hyp = [];   % no hyperparameters to optimise
end
if map
    objfunc = @gp_objfun_map;
else
    objfunc = @gp_objfun;
end

% Assemble data matrices
% -------------------------------------------------------------------------
% handle the glm as a special case (for now)
if strcmpi(func2str(covfunc),'covLINglm') || strcmpi(func2str(covfunc),'covLINglm_2class')
    % configure covariances
    K   = [d.train(:)'   {d.tr_param}];
    Ks  = [d.test(:)'    {d.te_param}];
    Kss = [d.testcov(:)' {d.te_param}];

    % get default hyperparameter values
    hyp = log(prt_glm_design);

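    % only the fourth output of prt_glm_design (the labels) is needed
    % here; the first three outputs are discarded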
    [tmp1, tmp2, tmp3, tr_lbs] = prt_glm_design(hyp, d.tr_param);
    [tmp1, tmp2, tmp3, te_lbs] = prt_glm_design(hyp, d.te_param);
else
    % configure covariances
    K   = d.train;
    Ks  = d.test;
    Kss = d.testcov;

    tr_lbs = d.tr_targets;
    te_lbs = d.te_targets;
end

% create one-of-k labels
k = max(unique(tr_lbs));
n = length(tr_lbs);
Y = zeros(n,k);
for j = 1:n
    Y(j,tr_lbs(j)) = 1;
end
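% (an equivalent vectorised construction would be:
%  Y = full(sparse((1:n)', tr_lbs(:), 1, n, k));)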

% Train and test GP model
% -------------------------------------------------------------------------
% train
if optimise_theta
    nh = numel(hyp);
    if map
        % objfunc was already set above; this reassignment is redundant
        % but harmless
        objfunc = @gp_objfun_map;
    end
    hyp = spm_powell(hyp,eye(nh),ones(nh,1)*0.05,objfunc,Y,K,covfunc);
end
% compute marginal likelihood and posterior parameters
[f, lml, alpha] = gp_lap_multiclass(hyp,covfunc,K,Y);

% make predictions
[p, mu, sigma]  = gp_pred_lap_multiclass(hyp,K,Y,covfunc,Ks,Kss);
[maxp, pred]    = max(p,[],2);

% Outputs
% -------------------------------------------------------------------------
output.predictions = pred;
output.type        = mode;
output.func_val    = p;
output.tr_targets  = tr_lbs;
output.te_targets  = te_lbs;
output.mu          = mu;
output.sigma       = sigma;
output.loghyper    = hyp;
output.nlml        = -lml;
output.alpha       = alpha;
end

% -------------------------------------------------------------------------
% Private functions
% -------------------------------------------------------------------------

function E = gp_objfun(logtheta,t,X,covfunc)
% Objective function to minimise

[f,F]   = gp_lap_multiclass(logtheta,covfunc,X,t);
E = -F; %+ 1e-6*(logtheta'*logtheta);
end
0187 
0188 % -------------------------------------------------------------------------
0189 function E = gp_objfun_map(logtheta,t,X,covfunc)
0190 % Objective function to minimise in a MAP setting
0191 
0192 E = gp_objfun(logtheta,t,X,covfunc);
0193 
0194 % priors
0195 priors = prt_gp_priors(func2str(covfunc));
0196 
0197 if iscell(covfunc)
0198     d = str2double(feval(covfunc{:}));
0199 else
0200     d = str2double(feval(covfunc));
0201 end
0202 % compute priors
0203 theta = exp(logtheta);
0204 lP  = zeros(d,1);
0205 %dlP = zeros(d,1);
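% the priors act on theta = exp(logtheta): 'gauss' priors take
% [mean, variance]; 'gamma' priors take [mode, rate], with shape
% a = mode*rate + 1 so that the Gamma(a, rate) density peaks at the
% specified mode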
for i = 1:d
    switch priors(i).type
        case 'gauss'
            mu = priors(i).param(1);
            s2 = priors(i).param(2);

            lP(i)  = ( -0.5*log(2*pi) - 0.5*log(s2) - 0.5*(theta(i)-mu)^2/s2);
            %dlP(i) = (-(theta(i)-mu) / s2);

        case 'gamma'
            a = priors(i).param(1)*priors(i).param(2) + 1;
            b = priors(i).param(2);

            lP(i) = (a*log(b) - gammaln(a) + (a - 1)*log(theta(i)) - b*theta(i));
            %%lP(i)  = log(gampdf(theta(i), a, 1/b));
            %dlP(i) =  ((a - 1) / theta(i) - b);

        otherwise
            error(['Unknown prior type: ', priors(i).type]);
    end
end

% outputs
nlP = -sum(lP);
%pnlZ  = nlZ + nlP;

E = E + nlP;
end

% -------------------------------------------------------------------------
function [f,F,a] = gp_lap_multiclass(logtheta,covfunc,X,t,f)
% Find mode for Laplace approximation for multi-class classification.
% Derived mostly from Rasmussen & Williams
% Algorithm 3.3 (page 50).
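% Each Newton iteration below: (i) stabilises f and forms the softmax sig,
% (ii) builds, per class, E_c = sD_c*(I + sD_c*K*sD_c)^-1*sD_c via a
% Cholesky factorisation, (iii) solves for the step and sets f = K*a,
% stopping when f converges (at most 32 iterations).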
[N,C] = size(t);
if nargin<5, f = zeros(N,C); end
%if norm(K)>1e8, F=-1e10; return; end

K = covfunc(logtheta,X);

for i=1:32
    f   = f - repmat(max(f,[],2),1,size(f,2));
    sig = exp(f)+eps;
    sig = sig./repmat(sum(sig,2),1,C);
    E   = zeros(N,N,C);
    for c1=1:C
        D         = sig(:,c1);
        sD        = sqrt(D);
        L         = chol(eye(N) + K.*(sD*sD'));
        E(:,:,c1) = diag(sD)*(L\(L'\diag(sD)));
       %z(c1)     = sum(log(diag(L)));
    end
    M = chol(sum(E,3));

    b = t-sig+sig.*f;
    for c1=1:C
        for c2=1:C
            b(:,c1) = b(:,c1) - sig(:,c1).*sig(:,c2).*f(:,c2);
        end
    end

    c   = zeros(size(t));
    for c1=1:C
        c(:,c1) = E(:,:,c1)*K*b(:,c1);
    end
    tmp = M\(M'\sum(c,2));
    a   = b-c;
    for c1=1:C
        a(:,c1) = a(:,c1) + E(:,:,c1)*tmp;
    end
    of = f;
    f  = K*a;

    %fprintf('%d -> %g %g %g\n', i,-0.5*a(:)'*f(:), t(:)'*f(:), -sum(log(sum(exp(f),2)),1));
    if sum((f(:)-of(:)).^2)<(20*eps)^2*numel(f), break; end
end
if nargout>1
    % Really not sure about sum(z) as being the determinant.
    % hlogdet = sum(z);

    R  = null(ones(1,C));
    sW = sparse([],[],[],N*(C-1),N*(C-1));
    for i=1:N
        ind         = (0:(C-2))*N+i;
        P           = sig(i,:)';
        D           = diag(P);
        sW(ind,ind) = sqrtm(R'*(D-P*P')*R);
    end
    hlogdet = sum(log(diag(chol(speye(N*(C-1))+sW*kron(eye(C-1),K)*sW))));
    F       = -0.5*a(:)'*f(:) + t(:)'*f(:) - sum(log(sum(exp(f),2)),1) - hlogdet;
    %fprintf('%g %g %g\n', -0.5*a(:)'*f(:) + t(:)'*f(:) - sum(log(sum(exp(f),2)),1), -hlogdet, F);
end
end

% -------------------------------------------------------------------------
function [p, Mu, SS] = gp_pred_lap_multiclass(logtheta,X,t,covfunc,Xs,Xss,f)
% Predictions for Laplace approximation to multi-class classification.
% Derived mostly from Rasmussen & Williams
% Algorithm 3.4 (page 51).
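% Latent test means and covariances follow Algorithm 3.4; the predictive
% class probabilities are then estimated by Monte Carlo, averaging the
% softmax over 10000 samples drawn from N(mu, S) for each test point.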

[N,C] = size(t);

K   = covfunc(logtheta,X);
Ks  = covfunc(logtheta,X,Xs);
kss = covfunc(logtheta,Xss,'diag');

if nargin<7
    f = gp_lap_multiclass(logtheta,covfunc,X,t);
end

sig = exp(f);
sig = sig./repmat(sum(sig,2)+eps,1,C);
E   = zeros(N,N,C);
for c1=1:C
    D         = sig(:,c1);
    sD        = sqrt(D);
    L         = chol(eye(N) + K.*(sD*sD'));
    E(:,:,c1) = diag(sD)*(L\(L'\diag(sD)));
end
M   = chol(sum(E,3));
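% note: RandStream.getDefaultStream/setDefaultStream (used below) were
% replaced by getGlobalStream/setGlobalStream from MATLAB R2012a onwards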
os  = RandStream.getDefaultStream;
p   = zeros(size(Ks,2),C);
j   = 0;
Mu = zeros(size(Ks,2),C); SS = zeros(C,C,size(Ks,2));
for i=1:size(Ks,2)
    j = j + 1;

    mu = zeros(C,1);
    S  = zeros(C,C);
    for c1=1:C
        mu(c1) = (t(:,c1)-sig(:,c1))'*Ks(:,i);
        b      = E(:,:,c1)*Ks(:,i);
        c      = (M\(M'\b));
        for c2=1:C
            S(c1,c2) = Ks(:,i)'*E(:,:,c2)*c;
        end
        S(c1,c1) = S(c1,c1) - b'*Ks(:,i) + kss(i);
    end

    % collect latent means and variances
    Mu(i,:)   = mu';
    SS(:,:,i) = S;

    % reset the generator to a fixed seed for each test point so that the
    % Monte Carlo estimate of the predictive probabilities is reproducible
    s = RandStream.create('mt19937ar','seed',0);
    RandStream.setDefaultStream(s);
    nsamp  = 10000;
    r      = sqrtm(S)*randn(C,nsamp) + repmat(mu,1,nsamp);
    %r      = chol(S)'*randn(C,nsamp) + repmat(mu,1,nsamp);
    % subtract a constant to avoid numerical overflow
    r      = bsxfun(@minus, r, max(r, [], 1));
    r      = exp(r);
    p(j,:) = mean(r./repmat(sum(r,1),C,1),2)';
end
RandStream.setDefaultStream(os);
end
