function output = prt_machine_gpclap(d,args)
% PRT_MACHINE_GPCLAP  Multi-class Gaussian process classifier (Laplace approx.)
%
% output = prt_machine_gpclap(d,args)
%
% Inputs
%   d    - struct; the fields read here are
%            .use_kernel            must be true (checked)
%            .train, .test, .testcov
%                                   handed to the covariance function as the
%                                   training, test and test-covariance inputs
%            .tr_targets, .te_targets
%                                   class labels; training labels are used as
%                                   column indices, so they must be positive
%                                   integers
%            .tr_param, .te_param   only read for covLINglm / covLINglm_2class
%   args - option string; recognised flags
%            -h        optimise the hyperparameters with spm_powell
%            -f <n>    accepted for backward compatibility but ignored: the
%                      value was parsed in earlier versions of this function
%                      and never passed on to the optimiser
%            -c <name> name of the covariance function (default covLINkcell)
%            -p        add the priors from prt_gp_priors to the objective (MAP)
%
% Output
%   output - struct with fields predictions, type, func_val, tr_targets,
%            te_targets, mu, sigma, loghyper, nlml and alpha

SANITYCHECK = true;

if SANITYCHECK == true
    % args is searched with regexp below, so it has to be a string
    if ~ischar(args)
        error('prt_machine_gpclap:libSVMargsNotString',['Error: gpml'...
            ' args should be a string. ' ...
            ' SOLUTION: Please do XXX']);
    end

    if ~d.use_kernel
        error('prt_machine_gpclap:useKernelIsFalse',['Error:'...
            ' This machine is currently only implemented for kernel data ' ...
            'SOLUTION: Please set use_kernel to true']);
    end
end

% Defaults
covfunc = @covLINkcell;
mode    = 'classifier';

% ----------------------------------------------------------------------
% Parse the option string
% ----------------------------------------------------------------------
optimise_theta = ~isempty(regexp(args,'-h','once'));

cargs = regexp(args,'-c\s+[a-zA-Z0-9_]*','match');
if ~isempty(cargs)
    cargs   = regexp(cell2mat(cargs),'-c\s+','split');
    covfunc = str2func(cell2mat(cargs(2)));
end

if ~isempty(regexp(args,'-p','once'))
    disp('Empirical priors specified. Using MAP for hyperparameters')
    % The returned priors are not used in this function (gp_objfun_map looks
    % them up itself); the call is kept so that any failure of the lookup
    % still happens here, before fitting starts.
    priors = prt_gp_priors(func2str(covfunc)); %#ok<NASGU>
    map = true;
else
    map = false;
end

% ----------------------------------------------------------------------
% Initial hyperparameters and objective function
% ----------------------------------------------------------------------
% Called without arguments, the covariance function returns a string that
% evaluates to its number of hyperparameters.
nhyp = str2num(feval(covfunc)); %#ok<ST2NM>
% Always define hyp: for a covariance function without hyperparameters this
% is a 0-by-1 array instead of an undefined variable.
hyp = zeros(nhyp(1),1);

if map
    objfunc = @gp_objfun_map;
else
    objfunc = @gp_objfun;
end

% ----------------------------------------------------------------------
% Assemble covariance inputs and labels
% ----------------------------------------------------------------------
if strcmpi(func2str(covfunc),'covLINglm') || strcmpi(func2str(covfunc),'covLINglm_2class')
    % GLM covariance functions get the kernels plus the design parameters
    K   = [d.train(:)'   {d.tr_param}];
    Ks  = [d.test(:)'    {d.te_param}];
    Kss = [d.testcov(:)' {d.te_param}];

    % prt_glm_design without arguments supplies the starting values
    % (NOTE(review): presumably its defaults - confirm in prt_glm_design)
    hyp = log(prt_glm_design);

    % Labels are the fourth output of prt_glm_design
    [tmp1, tmp2, tmp3, tr_lbs] = prt_glm_design(hyp, d.tr_param); %#ok<ASGLU>
    [tmp1, tmp2, tmp3, te_lbs] = prt_glm_design(hyp, d.te_param); %#ok<ASGLU>
else
    K   = d.train;
    Ks  = d.test;
    Kss = d.testcov;

    tr_lbs = d.tr_targets;
    te_lbs = d.te_targets;
end

% One-of-k coding of the training labels (n samples x k classes)
k = max(tr_lbs(:));
n = length(tr_lbs);
Y = zeros(n,k);
for j = 1:n
    Y(j,tr_lbs(j)) = 1;
end

% ----------------------------------------------------------------------
% Hyperparameter optimisation (only if requested and there is something
% to optimise)
% ----------------------------------------------------------------------
if optimise_theta && ~isempty(hyp)
    nh  = numel(hyp);
    hyp = spm_powell(hyp,eye(nh),ones(nh,1)*0.05,objfunc,Y,K,covfunc);
end

% ----------------------------------------------------------------------
% Fit and predict
% ----------------------------------------------------------------------
[f, lml, alpha] = gp_lap_multiclass(hyp,covfunc,K,Y);

% Pass f on so that the prediction step does not repeat the fit above
[p, mu, sigma] = gp_pred_lap_multiclass(hyp,K,Y,covfunc,Ks,Kss,f);
[maxp, pred]   = max(p,[],2); %#ok<ASGLU>

% ----------------------------------------------------------------------
% Assemble output
% ----------------------------------------------------------------------
output.predictions = pred;
output.type        = mode;
output.func_val    = p;
output.tr_targets  = tr_lbs;
output.te_targets  = te_lbs;
output.mu          = mu;
output.sigma       = sigma;
output.loghyper    = hyp;
output.nlml        = -lml;
output.alpha       = alpha;
end
0176
0177
0178
0179
0180
function E = gp_objfun(logtheta,t,X,covfunc)
% GP_OBJFUN  Objective for hyperparameter optimisation.
%
% Returns the negated second output of gp_lap_multiclass for the given
% log-hyperparameters, targets t, covariance inputs X and covariance
% function covfunc. The argument order (logtheta,t,X,covfunc) is the one
% the optimiser call in the main function uses.

[fitted, evidence] = gp_lap_multiclass(logtheta,covfunc,X,t); %#ok<ASGLU>
E = -evidence;
end
0187
0188
function E = gp_objfun_map(logtheta,t,X,covfunc)
% GP_OBJFUN_MAP  gp_objfun plus the negative log prior of the hyperparameters.
%
% The priors come from prt_gp_priors, looked up by the name of covfunc. Each
% entry has a .type ('gauss' or 'gamma') and a two-element .param. The
% priors are evaluated on theta = exp(logtheta).

% Data term
E = gp_objfun(logtheta,t,X,covfunc);

% Prior specification for this covariance function
priors = prt_gp_priors(func2str(covfunc));

% Number of hyperparameters, as reported by the covariance function
if iscell(covfunc)
    countStr = feval(covfunc{:});
else
    countStr = feval(covfunc);
end
d = str2double(countStr);

theta    = exp(logtheta);
logPrior = zeros(d,1);

for i = 1:d
    prior = priors(i);
    switch prior.type
        case 'gauss'
            % param = [mean, variance]
            m  = prior.param(1);
            v  = prior.param(2);
            logPrior(i) = ( -0.5*log(2*pi) - 0.5*log(v) - 0.5*(theta(i)-m)^2/v);

        case 'gamma'
            % shape a is derived as param(1)*param(2) + 1, rate b = param(2)
            b = prior.param(2);
            a = prior.param(1)*b + 1;
            logPrior(i) = (a*log(b) - gammaln(a) + (a - 1)*log(theta(i)) - b*theta(i));

        otherwise
            error(['Unknown prior type: ', priors(i).type]);
    end
end

% Add the negative log prior to the data term
E = E - sum(logPrior);
end
0234
0235
function [f,F,a] = gp_lap_multiclass(logtheta,covfunc,X,t,f)
% GP_LAP_MULTICLASS  Iterative fit of the latent function values for the
% multi-class (softmax) GP classifier.
%
% [f,F,a] = gp_lap_multiclass(logtheta,covfunc,X,t,f)
%
% Inputs
%   logtheta - hyperparameters, passed unchanged to covfunc
%   covfunc  - function handle; covfunc(logtheta,X) must return the N-by-N
%              training covariance matrix
%   X        - covariance-function input for the training data
%   t        - N-by-C one-of-C target matrix
%   f        - optional N-by-C starting point (default: zeros)
%
% Outputs
%   f - N-by-C latent values, f = K*a
%   F - scalar objective (only computed when requested); callers in this
%       file treat it as the log marginal likelihood
%   a - N-by-C coefficient matrix
%
% NOTE(review): the update appears to follow Algorithm 3.3 of Rasmussen &
% Williams (2006) - confirm against the reference.

[N,C] = size(t);
if nargin < 5
    f = zeros(N,C);
end

K = covfunc(logtheta,X);

for iter = 1:32
    % Remember the previous iterate BEFORE the row-wise shift below. The new
    % iterate K*a is not shifted, so comparing it with the shifted copy
    % would differ by that per-row offset and the convergence test would
    % not trigger.
    of = f;

    % Softmax of f; the row maximum is subtracted first to avoid overflow
    f   = f - repmat(max(f,[],2),1,size(f,2));
    sig = exp(f) + eps;
    sig = sig./repmat(sum(sig,2),1,C);

    % Per-class matrices E_c = D_c^(1/2) (I + D_c^(1/2) K D_c^(1/2))^(-1) D_c^(1/2)
    E = zeros(N,N,C);
    for c1 = 1:C
        D  = sig(:,c1);
        sD = sqrt(D);
        L  = chol(eye(N) + K.*(sD*sD'));
        E(:,:,c1) = diag(sD)*(L\(L'\diag(sD)));
    end
    M = chol(sum(E,3));

    % b = W*f + t - sig, with W = diag(sig) - sig*sig' applied per sample
    b = t - sig + sig.*f;
    for c1 = 1:C
        for c2 = 1:C
            b(:,c1) = b(:,c1) - sig(:,c1).*sig(:,c2).*f(:,c2);
        end
    end

    c = zeros(size(t));
    for c1 = 1:C
        c(:,c1) = E(:,:,c1)*K*b(:,c1);
    end
    tmp = M\(M'\sum(c,2));
    a   = b - c;
    for c1 = 1:C
        a(:,c1) = a(:,c1) + E(:,:,c1)*tmp;
    end

    f = K*a;

    % Stop once the update no longer changes f (relative to machine precision)
    if sum((f(:)-of(:)).^2) < (20*eps)^2*numel(f)
        break;
    end
end

if nargout > 1
    % Log-determinant term, computed in a (C-1)-dimensional basis R that is
    % orthogonal to the vector of ones (null space of ones(1,C)).
    % sig holds the class probabilities from the last pass of the loop.
    R  = null(ones(1,C));
    sW = sparse([],[],[],N*(C-1),N*(C-1));
    for i = 1:N
        ind = (0:(C-2))*N + i;
        P   = sig(i,:)';
        D   = diag(P);
        sW(ind,ind) = sqrtm(R'*(D-P*P')*R);
    end
    hlogdet = sum(log(diag(chol(speye(N*(C-1))+sW*kron(eye(C-1),K)*sW))));
    F = -0.5*a(:)'*f(:) + t(:)'*f(:) - sum(log(sum(exp(f),2)),1) - hlogdet;
end
end
0299
0300
function [p, Mu, SS] = gp_pred_lap_multiclass(logtheta,X,t,covfunc,Xs,Xss,f)
% GP_PRED_LAP_MULTICLASS  Predictive class probabilities for the multi-class
% GP classifier.
%
% [p,Mu,SS] = gp_pred_lap_multiclass(logtheta,X,t,covfunc,Xs,Xss,f)
%
% Inputs
%   logtheta - hyperparameters, passed unchanged to covfunc
%   X        - covariance-function input for the training data
%   t        - N-by-C one-of-C training targets
%   covfunc  - function handle supporting the three call forms
%                covfunc(logtheta,X)          -> K   (N-by-N)
%                covfunc(logtheta,X,Xs)       -> Ks  (N-by-nTest)
%                covfunc(logtheta,Xss,'diag') -> kss (one value per test case)
%   Xs, Xss  - covariance-function inputs for the test data
%   f        - optional N-by-C latent values at the training points; when
%              omitted they are computed with gp_lap_multiclass
%
% Outputs
%   p  - nTest-by-C class probabilities (Monte Carlo average of the softmax
%        over 10000 Gaussian samples per test case)
%   Mu - nTest-by-C latent predictive means
%   SS - C-by-C-by-nTest latent predictive covariances
%
% The Monte Carlo samples come from a private, fixed-seed random stream, so
% results are reproducible and the global random stream is left untouched.
%
% NOTE(review): the mean/covariance formulas appear to follow Algorithm 3.4
% of Rasmussen & Williams (2006) - confirm against the reference.

[N,C] = size(t);

K   = covfunc(logtheta,X);
Ks  = covfunc(logtheta,X,Xs);
kss = covfunc(logtheta,Xss,'diag');

if nargin < 7
    f = gp_lap_multiclass(logtheta,covfunc,X,t);
end

% Class probabilities at the training points
sig = exp(f);
sig = sig./repmat(sum(sig,2)+eps,1,C);

% Per-class matrices E_c and the Cholesky factor of their sum
E = zeros(N,N,C);
for c1 = 1:C
    D  = sig(:,c1);
    sD = sqrt(D);
    L  = chol(eye(N) + K.*(sD*sD'));
    E(:,:,c1) = diag(sD)*(L\(L'\diag(sD)));
end
M = chol(sum(E,3));

nTest = size(Ks,2);
p  = zeros(nTest,C);
Mu = zeros(nTest,C);
SS = zeros(C,C,nTest);

% Standard-normal samples from a private stream seeded with 0. Earlier
% versions re-seeded the default stream for every test case, which gave
% the same draw each time; drawing once from a local stream yields the
% same numbers without modifying (or having to restore) the global stream.
nsamp = 10000;
s = RandStream.create('mt19937ar','seed',0);
z = randn(s,C,nsamp);

for i = 1:nTest
    mu = zeros(C,1);
    S  = zeros(C,C);
    for c1 = 1:C
        mu(c1) = (t(:,c1)-sig(:,c1))'*Ks(:,i);
        b = E(:,:,c1)*Ks(:,i);
        c = (M\(M'\b));
        for c2 = 1:C
            S(c1,c2) = Ks(:,i)'*E(:,:,c2)*c;
        end
        S(c1,c1) = S(c1,c1) - b'*Ks(:,i) + kss(i);
    end

    Mu(i,:)   = mu';
    SS(:,:,i) = S;

    % Samples from N(mu,S), pushed through a softmax (column maximum
    % subtracted first to avoid overflow) and averaged
    r = sqrtm(S)*z + repmat(mu,1,nsamp);
    r = bsxfun(@minus, r, max(r, [], 1));
    r = exp(r);
    p(i,:) = mean(r./repmat(sum(r,1),C,1),2)';
end
end
0361