function [varargout] = prt_rvr(varargin)
% Relevance vector regression (RVR) front end.
%
% FORMAT [w,alpha,beta,ll] = prt_rvr(Phi,t)
%   Phi   - numeric N x N kernel matrix, or N x (N+1) with a bias column
%           already appended as the last column
%   t     - targets (presumably N x 1 -- confirm with callers)
%   w     - (N+1) x 1 weights, last entry is the bias weight
%   alpha - prior precision of each weight
%   beta  - noise precision
%   ll    - objective value at the last iteration
%
% FORMAT [w,alpha,beta,nu,ll] = prt_rvr(K,t,opt)
%   K     - cell array of kernel matrices combined with learnt weights nu
%   opt   - 'Linear' (default) or 'Gaussian RBF'
%   nu    - learnt kernel weights
%
% A numeric first argument selects the single-kernel form and a cell array
% selects the multiple-kernel form; anything else (including no argument)
% raises 'Incorrect usage'.
%__________________________________________________________________________

% Guard first: varargin{1} would otherwise fail with an indexing error.
if nargin < 1
    error('Incorrect usage');
end

if isnumeric(varargin{1})
    [varargout{1:nargout}] = regression0(varargin{:});
elseif iscell(varargin{1})
    [varargout{1:nargout}] = regression1(varargin{:});
else
    error('Incorrect usage');
end
return;
0072
0073
0074
function [w,alpha,beta,ll]=regression0(Phi,t)
% Two-stage RVR with a single, fixed design matrix.
%
%   Phi - N x N kernel (a bias column of ones is appended) or an
%         N x (N+1) matrix whose last column is the bias
%   t   - targets
%
% The kernel columns are divided by their RMS value before fitting.
% Stage 1 (rvr1a) learns two shared precisions; stage 2 (rvr2a) expands
% them to one precision per weight. w and alpha are mapped back to the
% unscaled columns before returning.
[nRows,nCols] = size(Phi);
if nCols == nRows
    Phi = [Phi ones(nRows,1)];
elseif nCols ~= nRows+1
    error('Phi must be N x (N+1)');
end

% Column scaling: RMS of the kernel part for the N kernel columns, 1 for bias.
rmsK  = sqrt(sum(sum(Phi(1:nRows,1:nRows).^2))/nRows^2);
scale = [repmat(rmsK,nRows,1) ; 1];
nS    = numel(scale);
D     = spdiags(scale,0,nS,nS);
Phi   = Phi/D;

% Stage 1: shared precisions (only alpha0(1:2) are used by rvr1a).
alpha0 = ones(size(Phi,2),1)/nRows;
beta0  = 1e6;
[w,alpha,beta,ll] = rvr1a(Phi,t,alpha0,beta0);

% Stage 2: one precision per weight, initialised from the stage-1 values.
alphaFull = [repmat(alpha(1),nRows,1) ; alpha(2)];
[w,alpha,beta,ll] = rvr2a(Phi,t,alphaFull,beta);

% Undo the column scaling.
w     = w./scale;
alpha = alpha.*(scale.^2);
return;
0094
0095
0096
function [w,alpha,beta,ll] = rvr1a(Phi,t,alpha,beta)
% Stage-1 (non-sparse) RVR for a fixed design matrix.
%
% Uses two shared precisions: alpha(1) for the N kernel weights and
% alpha(2) for the bias weight (column N+1 of Phi). Runs at most 10
% fixed-point updates of alpha(1:2) and the noise precision beta. It stops
% early when every log-hyperparameter changes by less than 1e-9.
%
%   Phi   - N x (N+1) design matrix, bias column last
%   t     - targets
%   alpha - initial precisions; only alpha(1) and alpha(2) are read/updated
%   beta  - initial noise precision
% Returns the posterior mean w, the updated alpha and beta, and the
% objective ll evaluated at the hyperparameters used to compute w.

[N,M] = size(Phi);
ll = Inf;
PP = Phi'*Phi;   % Phi is fixed, so these products are hoisted
Pt = Phi'*t;
for subit=1:10
    alpha_old = alpha;
    beta_old  = beta;

    % Posterior covariance and mean of the N+1 weights.
    S = inv(PP*beta + spdiags([ones(N,1)*alpha(1) ; alpha(2)],0,N+1,N+1));
    w = S*(Pt*beta);

    r  = t-Phi*w;
    ll = ...
        -0.5*log(alpha(1))*N-0.5*log(alpha(2))-0.5*N*log(beta)-0.5*logdet(S)...
        +0.5*r'*r*beta + 0.5*sum(w.^2.*[repmat(alpha(1),N,1) ; alpha(2)])...
        +0.5*(M-N)*log(2*pi);

    % Fixed-point updates. gamma = (N-dfa1)+(1-dfa2), so the residual
    % degrees of freedom N-gamma equal dfa1+dfa2-1 in the beta update.
    ds   = diag(S);
    dfa1 = sum(ds(1:N))*alpha(1);
    dfa2 = ds(N+1)*alpha(2);
    alpha(1) = max(N-dfa1,eps)/(sum(w(1:N).^2) +eps);
    alpha(2) = max(1-dfa2,eps)/(sum(w(N+1).^2) +eps);
    beta     = max(dfa1+dfa2-1,eps)/(sum((Phi*w-t).^2)+eps);

    % abs() on BOTH terms: a large decrease in beta must not be mistaken
    % for convergence.
    if max(max(abs(log((alpha+eps)./(alpha_old+eps)))),abs(log(beta/beta_old))) < 1e-9
        break;
    end
end
return;
0137
0138
0139
function [w,alpha,beta,ll]=rvr2a(Phi,t,alpha,beta)
% Stage-2 (sparse) RVR for a fixed design matrix.
%
% Uses one precision per column of Phi. At the start of each iteration,
% columns whose precision is at least 1e9 times the smallest are pruned:
% their alpha is pinned to a very large value and they are excluded from
% the posterior. Iterates up to 200 times, stopping when the active
% log-hyperparameters change by less than 1e-6*N.
%
%   Phi   - N x M design matrix
%   t     - targets
%   alpha - M x 1 initial precisions
%   beta  - initial noise precision
% Returns the M x 1 weights w (zero for pruned columns), the final alpha
% and beta, and the objective ll at the last iteration.

[N,M] = size(Phi);
PP = Phi'*Phi;
Pt = Phi'*t;

for subit=1:200
    % Prune columns whose precision has effectively diverged.
    th = min(alpha)*1e9;
    nz = alpha<th;
    alpha(~nz) = th*1e9;
    alpha_old = alpha;
    beta_old  = beta;

    % Posterior over the active weights only.
    S   = inv(PP(nz,nz)*beta + diag(alpha(nz)));
    wnz = S*Pt(nz)*beta;

    r  = t-Phi(:,nz)*wnz;
    ll = ...
        -0.5*sum(log(alpha(nz)+1e-32))-0.5*N*log(beta+1e-32)-0.5*logdet(S)...
        +0.5*r'*r*beta + 0.5*sum(wnz.^2.*alpha(nz))...
        +0.5*(sum(nz)-N)*log(2*pi);

    % Per-weight fixed-point updates of alpha and the noise precision.
    gam = 1 - alpha(nz).*diag(S);
    alpha(nz) = max(gam,eps)./(wnz.^2+1e-32);
    beta = max(N-sum(gam),eps)./(sum((Phi(:,nz)*wnz-t).^2)+1e-32);

    % abs() on both terms so a drop in beta is not read as convergence.
    if max(max(abs(log((alpha(nz)+eps)./(alpha_old(nz)+eps)))),abs(log(beta/beta_old))) < 1e-6*N
        break;
    end
end

% Scatter the active weights into a full-length column vector.
w = zeros(M,1);
w(nz) = wnz;
return;
0182
0183
0184
0185
function [w,alpha,beta,nu,ll]=regression1(K,t,opt)
% Multiple-kernel RVR: learns kernel weights nu together with the RVR
% hyperparameters.
%
%   K   - cell array of kernel matrices (the RBF conversion below assumes
%         each is square -- confirm with callers)
%   t   - targets
%   opt - 'Linear' (default; also 'linear', 'lin') or
%         'Gaussian RBF' (also 'nonlinear', 'nonlin')
%
% Linear kernels are rescaled so that sum(K{i}(:).^2) equals their row
% count. The returned nu is multiplied back by that rescaling factor.

if nargin < 3
    opt = 'Linear';
end

isRbf = false;
switch opt
    case {'Linear','linear','lin'}
        krn_f  = @make_phi;
        dkrn_f = @make_dphi;
    case {'Gaussian RBF','nonlinear','nonlin'}
        isRbf  = true;
        krn_f  = @make_phi_rbf;
        dkrn_f = @make_dphi_rbf;
    otherwise
        error('Unknown option');
end

nKern  = numel(K);
N      = size(K{1},1);
nu     = ones(nKern,1);
rescal = ones(nKern,1);
for k = 1:nKern
    Kk = K{k};
    n  = size(Kk,1);
    if isRbf
        % 0.5*(k_ii + k_jj) - k_ij, clipped at 0 and negated
        % (for a linear Gram matrix this is -0.5*||x_i - x_j||^2).
        half  = 0.5*diag(Kk);
        Kk    = -max(repmat(half,[1 n]) + repmat(half',[n,1]) - Kk, 0);
        % Initial weight: inverse RMS of the off-diagonal entries.
        nu(k) = 1/sqrt(sum(Kk(:).^2)/(n.^2-n));
    else
        rescal(k) = sqrt(n/sum(Kk(:).^2));
        Kk = Kk.*rescal(k);
    end
    K{k} = Kk;
end

alpha = [1;1];
beta  = 1e6;
[w,alpha,beta,nu,ll] = rvr1(K,t,alpha,beta,nu,krn_f,dkrn_f);
alpha = [alpha(1)*ones(N,1) ; alpha(2)];
[w,alpha,beta,nu,ll] = rvr2(K,t,alpha,beta,nu,krn_f,dkrn_f);
nu = nu.*rescal;
return;
0223
0224
0225
function [w,alpha,beta,nu,ll] = rvr1(K,t,alpha,beta,nu,krn_f,dkrn_f)
% Stage-1 (non-sparse) multiple-kernel RVR.
%
% Runs up to 50 outer iterations. Each iteration:
%   1) updates the shared precisions alpha(1) (kernel weights) and
%      alpha(2) (bias) and the noise precision beta, for fixed nu;
%   2) re-estimates the kernel weights nu with re_estimate_nu.
% Stops when the objective from re_estimate_nu changes by less than 1e-6*N.
% Progress is plotted with spm_chi2_plot.
%
%   K             - cell array of kernel matrices
%   alpha         - [alpha_kernel; alpha_bias]
%   krn_f, dkrn_f - design-matrix builder and its log(nu) derivatives

spm_chi2_plot('Init','ML-II (non-sparse)','-Log-likelihood','Iteration');
N  = size(K{1},1);
ll = Inf;
for iter=1:50
    Phi = feval(krn_f,nu,K);
    for subit=1:1
        alpha_old = alpha;
        beta_old  = beta;

        % Posterior covariance and mean of the N+1 weights.
        S = inv(Phi'*Phi*beta + spdiags([ones(N,1)*alpha(1) ; alpha(2)],0,N+1,N+1));
        w = S*(Phi'*t*beta);

        % Fixed-point updates; N-gamma = dfa1+dfa2-1 (see alpha numerators).
        ds   = diag(S);
        dfa1 = sum(ds(1:N))*alpha(1);
        dfa2 = ds(N+1)*alpha(2);
        alpha(1) = max(N-dfa1,eps)/(sum(w(1:N).^2) +eps);
        alpha(2) = max(1-dfa2,eps)/(sum(w(N+1).^2) +eps);
        beta     = max(dfa1+dfa2-1,eps)/(sum((Phi*w-t).^2)+eps);

        % abs() on both terms so a drop in beta is not read as convergence
        % (matters if the sub-iteration count is ever raised above 1).
        if max(max(abs(log((alpha+eps)./(alpha_old+eps)))),abs(log(beta/beta_old))) < 1e-9
            break;
        end
    end

    % Expand to per-weight precisions and update the kernel weights.
    oll = ll;
    al1 = [ones(N,1)*alpha(1) ; alpha(2)];
    [nu,ll] = re_estimate_nu(K,t,nu,al1,beta,krn_f,dkrn_f);

    spm_chi2_plot('Set',ll);
    if abs(oll-ll) < 1e-6*N, break; end
end
spm_chi2_plot('Clear');
return;
0270
0271
0272
function [w,alpha,beta,nu,ll]=rvr2(K,t,alpha,beta,nu,krn_f,dkrn_f)
% Stage-2 (sparse) multiple-kernel RVR.
%
% Runs up to 100 outer iterations. Each iteration:
%   1) prunes weights whose precision is >= 1e9 times the smallest;
%   2) updates the per-weight precisions alpha and the noise precision beta;
%   3) re-estimates the kernel weights nu on the active columns.
% Stops when the objective from re_estimate_nu changes by less than 1e-9*N.
% Progress is plotted with spm_chi2_plot.
%
%   alpha - (N+1) x 1 per-weight precisions (bias last)
% Returns w with zeros at pruned positions.

spm_chi2_plot('Init','ML-II (sparse)','-Log-likelihood','Iteration');
N  = size(K{1},1);
w  = zeros(N+1,1);
ll = Inf;
for iter=1:100
    for subits=1:1
        % Prune weights whose precision has effectively diverged.
        th = min(alpha)*1e9;
        nz = alpha<th;
        alpha(~nz) = th*1e9;

        alpha_old = alpha;
        beta_old  = beta;
        Phi = feval(krn_f,nu,K,nz);

        % Posterior over the active weights.
        S = inv(beta*Phi'*Phi + diag(alpha(nz)));
        w(nz)  = S*Phi'*t*beta;
        w(~nz) = 0;

        % Per-weight fixed-point updates.
        gam = 1 - alpha(nz).*diag(S);
        alpha(nz) = max(gam,eps)./(w(nz).^2+1e-32);
        beta = max(N-sum(gam),eps)./(sum((Phi*w(nz)-t).^2)+1e-32);

        % abs() on both terms so a drop in beta is not read as convergence
        % (matters if the sub-iteration count is ever raised above 1).
        if max(max(abs(log((alpha+eps)./(alpha_old+eps)))),abs(log(beta/beta_old))) < 1e-6
            break;
        end
    end

    oll = ll;
    [nu,ll] = re_estimate_nu(K,t,nu,alpha,beta,krn_f,dkrn_f,nz);

    spm_chi2_plot('Set',ll);
    if abs(oll-ll) < 1e-9*N
        break;
    end
end
spm_chi2_plot('Clear');
return;
0325
0326
0327
function Phi = make_phi(nu,K,nz)
% Linear multiple-kernel design matrix.
%
% Returns sum_i nu(i)*K{i} with a column of ones appended for the bias.
% If nz is given and non-empty, nz(1:N) selects the kernel columns kept,
% and nz(N+1) decides whether the bias column is included.
nRows = size(K{1},1);
if nargin > 2 && ~isempty(nz)
    keep   = nz(1:nRows);
    nBias  = sum(nz(nRows+1));
    select = @(A) A(:,keep);
else
    nBias  = 1;
    select = @(A) A;
end
acc = select(K{1})*nu(1);
for k = 2:numel(K)
    acc = acc + select(K{k})*nu(k);
end
Phi = [acc ones(size(acc,1),nBias)];
return;
0347
0348
0349
function [dPhi,d2Phi] = make_dphi(nu,K,nz)
% Derivatives of the linear design matrix with respect to log(nu).
%
% With Phi = [sum_i nu(i)*K{i}, 1], and d nu(i)/d log(nu(i)) = nu(i):
%   dPhi{i}    = nu(i)*[K{i}, 0]
%   d2Phi{i,i} = nu(i)*[K{i}, 0]
%   d2Phi{i,j} = 0                  for i ~= j
% This is the log(nu) parameterisation used by re_estimate_nu, whose step
% is nu = exp(log(nu) - step). The bias column has zero derivative.
% nz (optional) selects the active columns, as in make_phi.
nK    = numel(K);
dPhi  = cell(size(K));
d2Phi = cell(nK);
if nargin>2 && ~isempty(nz)
    nz1 = nz(1:size(K{1},1));
    nz2 = nz(size(K{1},1)+1);
    for i=1:nK
        dPhi{i} = [K{i}(:,nz1),zeros(size(K{i},1),sum(nz2))];
    end
else
    for i=1:nK
        dPhi{i} = [K{i},zeros(size(K{i},1),1)];
    end
end

z = zeros(size(dPhi{1}));
for i=1:nK
    % Scale by nu(i) exactly once (chain rule for log(nu)).
    dPhi{i}    = nu(i)*dPhi{i};
    d2Phi{i,i} = dPhi{i};
    for j=(i+1):nK
        d2Phi{i,j} = z;
        d2Phi{j,i} = z;
    end
end
return;
0379
0380
0381
function Phi = make_phi_rbf(nu,K,nz)
% Gaussian-RBF multiple-kernel design matrix.
%
% Returns exp(sum_i nu(i)*K{i}) (element-wise exp) with a column of ones
% appended for the bias. If nz is given and non-empty, nz(1:N) selects the
% kernel columns kept, and nz(N+1) decides whether the bias column is included.
nRows = size(K{1},1);
if nargin > 2 && ~isempty(nz)
    keep   = nz(1:nRows);
    nBias  = sum(nz(nRows+1));
    select = @(A) A(:,keep);
else
    nBias  = 1;
    select = @(A) A;
end
arg = select(K{1})*nu(1);
for k = 2:numel(K)
    arg = arg + select(K{k})*nu(k);
end
Phi = [exp(arg) ones(size(arg,1),nBias)];
return;
0400
0401
0402
function [dPhi,d2Phi] = make_dphi_rbf(nu,K,nz)
% Derivatives of the Gaussian-RBF design matrix with respect to log(nu).
%
% With Phi = [exp(sum_i nu(i)*K{i}), 1]:
%   dPhi{i}    = nu(i)*K{i}.*Phi
%   d2Phi{i,i} = nu(i)*K{i}.*Phi.*(1+nu(i)*K{i})
%   d2Phi{i,j} = nu(i)*nu(j)*K{i}.*K{j}.*Phi          for i ~= j
% The bias column has zero derivative. nz is optional and selects the
% active columns as in make_phi_rbf; when it is omitted or empty, all
% columns are used.
if nargin < 3
    nz = [];   % without this, the call below fails on an undefined nz
end
Phi   = make_phi_rbf(nu,K,nz);
dPhi  = cell(size(K));
d2Phi = cell(numel(K));
if ~isempty(nz)
    nz1 = nz(1:size(K{1},1));
    nz2 = nz(size(K{1},1)+1);
    for i=1:numel(K)
        dPhi{i} = [K{i}(:,nz1),zeros(size(K{i},1),sum(nz2))];
    end
else
    for i=1:numel(K)
        dPhi{i} = [K{i},zeros(size(K{i},1),1)];
    end
end

% dPhi{j} for j > i is still the raw [K{j},0] when the cross terms are formed.
for i=1:numel(K)
    d2Phi{i,i} = nu(i)*dPhi{i}.*Phi.*(1+nu(i)*dPhi{i});
    for j=(i+1):numel(K)
        d2Phi{i,j} = (nu(i)*nu(j))*dPhi{i}.*dPhi{j}.*Phi;
        d2Phi{j,i} = d2Phi{i,j};
    end
    dPhi{i} = nu(i)*dPhi{i}.*Phi;
end
return;
0430
0431
0432
function [nu,ll] = re_estimate_nu(K,t,nu,alpha,beta,krn_f,dkrn_f,nz)
% Re-estimate the kernel weights nu with alpha and beta held fixed.
%
% Takes damped Newton (Levenberg-Marquardt style) steps on log(nu) to
% decrease the objective computed by neg_log_evidence. The derivatives
% returned by dkrn_f are used with respect to log(nu).
%
%   K      - cell array of kernel matrices
%   t      - targets
%   nu     - current (positive) kernel weights
%   alpha  - precisions of all weights; only alpha(nz) is used
%   beta   - noise precision
%   krn_f  - design-matrix builder, called as krn_f(nu,K,nz)
%   dkrn_f - derivative builder, called as [dPhi,d2Phi] = dkrn_f(nu,K,nz)
%   nz     - logical mask of active columns (default: all size(K{1},2)+1)
% Returns the updated nu and its objective value ll.

if nargin<8, nz = true(size(K{1},2)+1,1); end

a   = alpha(nz);
nK  = numel(K);
lam = 1e-6;
Phi = feval(krn_f,nu,K,nz);
S   = inv(Phi'*Phi*beta+diag(a));
w   = beta*S*Phi'*t;
N   = size(Phi,1);
% The same objective is used for the initial and trial points, so the
% ll > oll comparisons below are like-for-like.
ll  = neg_log_evidence(Phi,t,w,S,a,beta);

for subit=1:30
    % Gradient g and Hessian H of ll with respect to log(nu).
    g = zeros(nK,1);
    H = zeros(nK);
    [dPhi,d2Phi] = feval(dkrn_f,nu,K,nz);
    for i=1:nK
        tmp1 = Phi'*dPhi{i};
        tmp1 = tmp1+tmp1';
        g(i) = 0.5*beta*(sum(sum(S.*tmp1)) + w'*tmp1*w - 2*t'*dPhi{i}*w);
        for j=i:nK
            tmp  = dPhi{j}'*dPhi{i} + Phi'*d2Phi{i,j};
            tmp  = tmp+tmp';
            tmp2 = Phi'*dPhi{j};
            tmp2 = tmp2+tmp2';
            H(i,j) = sum(sum(S.*(tmp - tmp1*S*tmp2))) + w'*tmp*w - 2*w'*d2Phi{i,j}'*t;
            H(i,j) = 0.5*beta*H(i,j);
            H(j,i) = H(i,j);
        end
    end

    oll = ll;
    onu = nu;

    % Raise the damping so that H + lam*I is positive definite.
    lam = max(lam,-real(min(eig(H)))*1.5);

    for subsubit=1:30
        drawnow;

        % Damped Newton step in log(nu). Near-singular solves are
        % tolerated here, so warnings are silenced. The caller's warning
        % state is restored instead of blindly turning all warnings on.
        ws = warning('off','all');
        nu = exp(log(nu) - (H+lam*speye(size(H)))\g);
        warning(ws);

        % Keep nu within [1e-12, 1e12] and within a 1e9 dynamic range.
        nu = max(max(nu,1e-12),max(nu)*1e-9);
        nu = min(min(nu,1e12) ,min(nu)*1e9);

        % Evaluate the trial point.
        Phi1 = feval(krn_f,nu,K,nz);
        ws = warning('off','all');
        S1 = inv(Phi1'*Phi1*beta+diag(a));
        warning(ws);
        w1 = beta*S1*Phi1'*t;
        ll = neg_log_evidence(Phi1,t,w1,S1,a,beta);

        % Negligible change: keep the previous point and stop.
        if abs(ll-oll)<1e-9
            nu = onu;
            ll = oll;
            break;
        end

        if ll>oll
            % Worse: reject the step and damp more strongly.
            lam = lam*10;
            nu  = onu;
            ll  = oll;
        else
            % Better: accept the step and relax the damping.
            lam = max(lam/10,1e-12);
            Phi = Phi1;
            S   = S1;
            w   = w1;
            break;
        end
    end
    if abs(ll-oll)<1e-9*N, break; end
end
return;



function ll = neg_log_evidence(Phi,t,w,S,a,beta)
% Objective minimised by re_estimate_nu. The callers' progress plots label
% it '-Log-likelihood'.
%   Phi  - N x K active design matrix
%   w, S - posterior mean and covariance of the active weights
%   a    - precisions of the active weights
%   beta - noise precision
% eps guards the logarithms against zero precisions.
N  = size(Phi,1);
r  = t-Phi*w;
ll = -0.5*sum(log(a+eps))-0.5*N*log(beta+eps)-0.5*logdet(S)...
     +0.5*(r'*r)*beta + 0.5*sum(w.^2.*a)...
     +0.5*(numel(a)-N)*log(2*pi);
return;
0528
0529
0530
function [ld,C] = logdet(A)
% Log-determinant of a symmetric positive-definite matrix.
%
% A is symmetrised first, then factorised with chol. The log-determinant
% is twice the sum of the log diagonal of the factor, with each diagonal
% entry floored at eps. C is the upper-triangular Cholesky factor; chol
% raises an error if the symmetrised A is not positive definite.
C  = chol((A+A')/2);
ld = 2*sum(log(max(diag(C),eps)));
0536