function [varargout] = prt_rvr(varargin)
% PRT_RVR  Sparse Bayesian kernel regression with per-weight precisions
% (relevance-vector-style type-II maximum likelihood with pruning).
%
% Numeric form (fixed kernel / design matrix):
%   [w,alpha,beta,ll] = prt_rvr(Phi,t)
%     Phi   - N x N kernel matrix (a column of ones is appended as a bias
%             basis function) or an N x (N+1) matrix whose last column is
%             taken to be the bias
%     t     - targets, presumably an N x 1 column vector
%     w     - (N+1) x 1 posterior mean weights; pruned weights are zero
%     alpha - (N+1) x 1 prior precisions of the weights
%     beta  - noise precision
%     ll    - negative log evidence from the final iteration
%
% Cell form (the kernel weights nu are estimated as well):
%   [w,alpha,beta,nu,ll] = prt_rvr(K,t,opt)
%     K     - cell array of square kernel matrices (assumed N x N)
%     opt   - 'Linear' (default; also 'linear','lin'): Phi is a weighted
%             sum of the kernels; or 'Gaussian RBF' ('nonlinear','nonlin'):
%             Phi is exp of a weighted sum of distance-like matrices
%             derived from the kernels
%     nu    - estimated weight of each kernel
%   Progress of this form is displayed with spm_chi2_plot.

if nargin<1,
    % Without this guard varargin{1} fails with an unhelpful indexing error.
    error('Incorrect usage');
end;
if isnumeric(varargin{1}),
    [varargout{1:nargout}]=regression0(varargin{:});
elseif iscell(varargin{1}),
    [varargout{1:nargout}]=regression1(varargin{:});
else
    error('Incorrect usage');
end;
return;
0071
0072
0073
function [w,alpha,beta,ll]=regression0(Phi,t)
% REGRESSION0  Sparse Bayesian regression on a fixed kernel matrix.
%   [w,alpha,beta,ll] = regression0(Phi,t)
%   Phi - N x N kernel (a ones column is appended as bias) or N x (N+1)
%         with the bias already in the last column
%   t   - targets (presumably N x 1)
%   Runs a two-stage fit: rvr1a with two shared precisions (kernel weights,
%   bias), then rvr2a with one precision per weight and pruning. The
%   returned w and alpha refer to the unscaled Phi.

[N,M] = size(Phi);
if N==M,
    Phi = [Phi ones(N,1)];
elseif M~=N+1,
    error('Phi must be N x (N+1)');
end;

% Divide the kernel columns by their RMS value to improve conditioning.
% An all-zero kernel would give scale 0 and a division by zero.
scale = sqrt(sum(sum(Phi(1:N,1:N).^2))/N^2);
if scale==0,
    scale = 1;
end;
scale = [ones(N,1)*scale ; 1];
Phi = Phi/spdiags(scale,0,numel(scale),numel(scale));

% Stage 1: alpha(1) shared by the N kernel weights, alpha(2) for the bias
alpha = ones(2,1)/N;
beta = 1e6;
[w,alpha,beta] = rvr1a(Phi,t,alpha,beta);

% Stage 2: one precision per weight, initialised from stage 1
alpha = [alpha(1)*ones(N,1) ; alpha(2)];
[w,alpha,beta,ll] = rvr2a(Phi,t,alpha,beta);

% Map back to the unscaled columns: w_orig = w_s/scale and, since
% w_s ~ N(0,1/alpha_s), the precision of w_orig is alpha_s*scale^2.
w = w./scale;
alpha = alpha.*scale.^2;
return;
0093
0094
0095
function [w,alpha,beta,ll] = rvr1a(Phi,t,alpha,beta)
% RVR1A  Non-sparse type-II ML with two shared weight precisions.
%   [w,alpha,beta,ll] = rvr1a(Phi,t,alpha,beta)
%   Phi   - N x (N+1) design matrix, last column the bias basis
%   t     - targets (presumably N x 1)
%   alpha - initial precisions: alpha(1) is shared by the N kernel
%           weights, alpha(2) belongs to the bias weight. Only these two
%           elements are read and updated.
%   beta  - initial noise precision
%   Returns the posterior mean w, the updated alpha and beta, and the
%   negative log evidence ll. w and ll are evaluated at the start of the
%   last iteration, before that iteration's alpha/beta updates.

N = size(Phi,1);
ll = Inf;
PP = Phi'*Phi;
Pt = Phi'*t;
for subit=1:10,
    alpha_old = alpha;
    beta_old = beta;

    % Posterior covariance and mean of the weights
    S = inv(PP*beta + spdiags([ones(N,1)*alpha(1) ; alpha(2)],0,N+1,N+1));
    w = S*(Pt*beta);

    % Negative log evidence; the 2*pi constant is 0.5*N*log(2*pi)
    % (independent of the number of weights), cf. PRML eq. 7.85.
    res = t-Phi*w;
    ll = ...
        -0.5*log(alpha(1))*N-0.5*log(alpha(2))-0.5*N*log(beta)-0.5*logdet(S)...
        +0.5*(res'*res)*beta + 0.5*sum(w.^2.*[repmat(alpha(1),N,1) ; alpha(2)])...
        +0.5*N*log(2*pi);

    % Re-estimation: group "well-determined" counts are N-dfa1 and 1-dfa2,
    % so N minus their total is dfa1+dfa2-1.
    ds = diag(S);
    dfa1 = sum(ds(1:N))*alpha(1);
    dfa2 = ds(N+1)*alpha(2);
    alpha(1) = max(N-dfa1,eps)/(sum(w(1:N).^2) +eps);
    alpha(2) = max(1-dfa2,eps)/(w(N+1).^2 +eps);
    beta = max(dfa1+dfa2-1,eps)/(sum((Phi*w-t).^2)+eps);

    % Stop when the log-ratio of every hyperparameter is tiny; abs() on the
    % beta term too, otherwise a large drop in beta would be ignored.
    if max(max(abs(log((alpha+eps)./(alpha_old+eps)))),abs(log(beta/beta_old))) < 1e-9,
        break;
    end;
end;

return;
0136
0137
0138
function [w,alpha,beta,ll]=rvr2a(Phi,t,alpha,beta)
% RVR2A  Sparse type-II ML with one precision per weight and pruning.
%   [w,alpha,beta,ll] = rvr2a(Phi,t,alpha,beta)
%   Phi   - N x M design matrix
%   t     - targets (presumably N x 1)
%   alpha - M x 1 initial weight precisions
%   beta  - initial noise precision
%   Returns w as an M x 1 column with zeros for pruned basis functions,
%   the updated alpha and beta, and the negative log evidence ll from the
%   last iteration.

[N,M] = size(Phi);
PP = Phi'*Phi;
Pt = Phi'*t;

for subit=1:200,
    % Prune basis functions whose precision exceeds 1e9 times the smallest
    % one, pinning their precision to a very large value.
    th = min(alpha)*1e9;
    nz = alpha<th;
    alpha(~nz) = th*1e9;
    alpha_old = alpha;
    beta_old = beta;

    % Posterior over the retained weights only
    S = inv(PP(nz,nz)*beta + diag(alpha(nz)));
    w = S*Pt(nz)*beta;

    % Negative log evidence; 2*pi constant is 0.5*N*log(2*pi), which does
    % not depend on how many weights are retained (PRML eq. 7.85).
    res = t-Phi(:,nz)*w;
    ll = ...
        -0.5*sum(log(alpha(nz)+1e-32))-0.5*N*log(beta+1e-32)-0.5*logdet(S)...
        +0.5*(res'*res)*beta + 0.5*sum(w.^2.*alpha(nz))...
        +0.5*N*log(2*pi);

    % MacKay updates
    gam = 1 - alpha(nz).*diag(S);
    alpha(nz) = max(gam,eps)./(w.^2+1e-32);
    beta = max(N-sum(gam),eps)./(sum((Phi(:,nz)*w-t).^2)+1e-32);

    % abs() on both terms so that a large decrease in beta also counts
    if max(max(abs(log((alpha(nz)+eps)./(alpha_old(nz)+eps)))),abs(log(beta/beta_old))) < 1e-6*N,
        break;
    end;
end;

% Scatter the retained weights into a full-length column vector
wfull = zeros(M,1);
wfull(nz) = w;
w = wfull;
return;
0181
0182
0183
0184
function [w,alpha,beta,nu,ll]=regression1(K,t,opt)
% REGRESSION1  Sparse Bayesian regression that also learns kernel weights.
%   [w,alpha,beta,nu,ll] = regression1(K,t,opt)
%   K   - cell array of square kernel matrices (assumed N x N)
%   t   - targets (presumably N x 1)
%   opt - 'Linear' (default) / 'linear' / 'lin', or
%         'Gaussian RBF' / 'nonlinear' / 'nonlin'
%   Linear: each kernel is rescaled, and the returned nu includes that
%   rescaling so it applies to the kernels as passed in.
%   RBF: each kernel is replaced by -max(0.5*k_ii + 0.5*k_jj - k_ij, 0)
%   (minus half the squared distance if K{i} is a Gram matrix), and nu is
%   initialised from the RMS of its off-diagonal entries.

if nargin<3, opt = 'Linear'; end;

use_rbf = false;
switch opt,
    case {'Linear','linear','lin'},
        krn_f  = @make_phi;
        dkrn_f = @make_dphi;
    case {'Gaussian RBF','nonlinear','nonlin'},
        use_rbf = true;
        krn_f  = @make_phi_rbf;
        dkrn_f = @make_dphi_rbf;
    otherwise
        error('Unknown option');
end;

nkern  = numel(K);
N      = size(K{1},1);
nu     = ones(nkern,1);
rescal = ones(nkern,1);
for k=1:nkern,
    Kk = K{k};
    n  = size(Kk,1);
    if use_rbf,
        half_diag = 0.5*diag(Kk);
        Kk = -max(repmat(half_diag,[1 n]) + repmat(half_diag',[n,1]) - Kk, 0);
        nu(k) = 1/sqrt(sum(Kk(:).^2)/(n.^2-n));
    else
        rescal(k) = sqrt(n/sum(Kk(:).^2));
        Kk = Kk.*rescal(k);
    end;
    K{k} = Kk;
end;

% Stage 1: two shared precisions; stage 2: one precision per weight
beta  = 1e6;
alpha = [1 1]';
[w,alpha,beta,nu,ll] = rvr1(K,t,alpha,beta,nu,krn_f,dkrn_f);
alpha = [alpha(1)*ones(N,1) ; alpha(2)];
[w,alpha,beta,nu,ll] = rvr2(K,t,alpha,beta,nu,krn_f,dkrn_f);
nu = nu.*rescal;
return;
0222
0223
0224
function [w,alpha,beta,nu,ll] = rvr1(K,t,alpha,beta,nu,krn_f,dkrn_f)
% RVR1  Non-sparse stage: alternate shared-precision and kernel-weight updates.
%   Each outer iteration performs one re-estimation sweep of the two shared
%   precisions alpha(1) (kernel weights) and alpha(2) (bias) and of beta,
%   then re-optimises nu with re_estimate_nu. Stops after 50 iterations or
%   when the objective returned by re_estimate_nu changes by < 1e-6*N.
spm_chi2_plot('Init','ML-II (non-sparse)','-Log-likelihood','Iteration');
N  = size(K{1},1);
ll = Inf;
for iter=1:50,
    Phi = feval(krn_f,nu,K);

    % Posterior over all N+1 weights for the current nu
    S = inv(Phi'*Phi*beta + spdiags([ones(N,1)*alpha(1) ; alpha(2)],0,N+1,N+1));
    w = S*(Phi'*t*beta);

    % One MacKay-style update of alpha and beta
    sdiag   = diag(S);
    shrink1 = sum(sdiag(1:N))*alpha(1);
    shrink2 = sum(sdiag(N+1))*alpha(2);
    alpha(1) = max(N-shrink1,eps)/(sum(w(1:N).^2) +eps);
    alpha(2) = max(1-shrink2,eps)/(sum(w(N+1).^2) +eps);
    beta     = max(shrink1+shrink2-1,eps)/(sum((Phi*w-t).^2)+eps);

    % Kernel weights with alpha/beta held fixed
    prev_ll = ll;
    full_alpha = [ones(N,1)*alpha(1) ; alpha(2)];
    [nu,ll] = re_estimate_nu(K,t,nu,full_alpha,beta,krn_f,dkrn_f);

    spm_chi2_plot('Set',ll);
    if abs(prev_ll-ll) < 1e-6*N, break; end;
end;
spm_chi2_plot('Clear');
return;
0269
0270
0271
function [w,alpha,beta,nu,ll]=rvr2(K,t,alpha,beta,nu,krn_f,dkrn_f)
% RVR2  Sparse stage: per-weight precisions with pruning, plus nu updates.
%   Each outer iteration prunes weights whose precision exceeds 1e9 times
%   the smallest, performs one update of alpha(nz) and beta on the retained
%   basis functions, then re-optimises nu with re_estimate_nu. Stops after
%   100 iterations or when re_estimate_nu's objective changes by < 1e-9*N.
spm_chi2_plot('Init','ML-II (sparse)','-Log-likelihood','Iteration');
N  = size(K{1},1);
w  = zeros(N+1,1);
ll = Inf;
for iter=1:100,
    % Prune and pin the precision of pruned weights to a huge value
    th = min(alpha)*1e9;
    nz = alpha<th;
    alpha(~nz) = th*1e9;

    Phi = feval(krn_f,nu,K,nz);

    % Posterior over retained weights; pruned weights are zero
    S = inv(beta*Phi'*Phi + diag(alpha(nz)));
    w(nz)  = S*Phi'*t*beta;
    w(~nz) = 0;

    % One MacKay update of the retained precisions and beta
    gam = 1 - alpha(nz).*diag(S);
    alpha(nz) = max(gam,eps)./(w(nz).^2+1e-32);
    beta      = max(N-sum(gam),eps)./(sum((Phi*w(nz)-t).^2)+1e-32);

    % Kernel weights with alpha/beta held fixed
    prev_ll = ll;
    [nu,ll] = re_estimate_nu(K,t,nu,alpha,beta,krn_f,dkrn_f,nz);

    spm_chi2_plot('Set',ll);
    if abs(prev_ll-ll) < 1e-9*N,
        break;
    end;
end;
spm_chi2_plot('Clear');
return;
0324
0325
0326
function Phi = make_phi(nu,K,nz)
% MAKE_PHI  Design matrix of the linear multi-kernel model.
%   Phi = make_phi(nu,K)     -> [nu(1)*K{1} + ... + nu(end)*K{end}, ones]
%   Phi = make_phi(nu,K,nz)  -> kernel columns restricted to nz(1:n), with
%   n = size(K{1},1); the ones column is included only when nz(n+1) is
%   true. An empty nz behaves like the two-argument call.
if nargin>2 && ~isempty(nz),
    n      = size(K{1},1);
    cols   = nz(1:n);
    n_bias = sum(nz(n+1));
else
    cols   = ':';
    n_bias = 1;
end;

acc = K{1}(:,cols)*nu(1);
for k=2:numel(K),
    acc = acc + K{k}(:,cols)*nu(k);
end;
Phi = [acc ones(size(acc,1),n_bias)];
return;
0346
0347
0348
function [dPhi,d2Phi] = make_dphi(nu,K,nz)
% MAKE_DPHI  Derivatives of the linear design matrix w.r.t. log(nu).
%   [dPhi,d2Phi] = make_dphi(nu,K,nz) for Phi = [sum_i nu(i)*K{i}, 1]
%   (see make_phi), differentiated with respect to theta = log(nu) as
%   re_estimate_nu updates nu through exp(log(nu) - step):
%     dPhi{i}    = dPhi/dtheta(i)    = nu(i)*[K{i}, 0]
%     d2Phi{i,i} = d2Phi/dtheta(i)^2 = nu(i)*[K{i}, 0]
%     d2Phi{i,j} = 0 for i ~= j
%   The zero column stands in for the (constant) bias column. nz selects
%   columns exactly as in make_phi.
dPhi  = cell(size(K));
d2Phi = cell(numel(K));
if nargin>2 && ~isempty(nz),
    n      = size(K{1},1);
    cols   = nz(1:n);
    n_bias = sum(nz(n+1));
else
    cols   = ':';
    n_bias = 1;
end;

% The chain rule through nu = exp(theta) contributes exactly ONE factor of
% nu(i); scaling a second time would give nu(i)^2*K{i} and distort the
% gradient/Hessian used in re_estimate_nu.
for i=1:numel(K),
    dPhi{i} = [K{i}(:,cols), zeros(size(K{i},1),n_bias)]*nu(i);
end;

z = zeros(size(dPhi{1}));
for i=1:numel(K),
    d2Phi{i,i} = dPhi{i};
    for j=(i+1):numel(K),
        d2Phi{i,j} = z;
        d2Phi{j,i} = z;
    end;
end;
return;
0378
0379
0380
function Phi = make_phi_rbf(nu,K,nz)
% MAKE_PHI_RBF  Design matrix of the exponentiated multi-kernel model.
%   Phi = make_phi_rbf(nu,K)    -> [exp(nu(1)*K{1} + ... + nu(end)*K{end}), ones]
%   Phi = make_phi_rbf(nu,K,nz) -> kernel columns restricted to nz(1:n),
%   n = size(K{1},1); the ones column is kept only when nz(n+1) is true.
%   An empty nz behaves like the two-argument call.
if nargin>2 && ~isempty(nz),
    n      = size(K{1},1);
    cols   = nz(1:n);
    n_bias = sum(nz(n+1));
else
    cols   = ':';
    n_bias = 1;
end;

expo = K{1}(:,cols)*nu(1);
for k=2:numel(K),
    expo = expo + K{k}(:,cols)*nu(k);
end;
Phi = [exp(expo) ones(size(expo,1),n_bias)];
return;
0399
0400
0401
function [dPhi,d2Phi] = make_dphi_rbf(nu,K,nz)
% MAKE_DPHI_RBF  Derivatives of the RBF design matrix w.r.t. log(nu).
%   With Phi = make_phi_rbf(nu,K,nz) and Kz{i} the selected columns of K{i}
%   padded with a zero bias column:
%     dPhi{i}    = nu(i)*Kz{i}.*Phi
%     d2Phi{i,i} = nu(i)*Kz{i}.*Phi.*(1+nu(i)*Kz{i})
%     d2Phi{i,j} = nu(i)*nu(j)*Kz{i}.*Kz{j}.*Phi   (i ~= j, symmetric)
Phi = make_phi_rbf(nu,K,nz);
nkern = numel(K);
if nargin>2 && ~isempty(nz),
    n      = size(K{1},1);
    cols   = nz(1:n);
    n_bias = sum(nz(n+1));
else
    cols   = ':';
    n_bias = 1;
end;

% Unscaled kernel columns with a zero column in place of the bias
Kz = cell(size(K));
for i=1:nkern,
    Kz{i} = [K{i}(:,cols),zeros(size(K{i},1),n_bias)];
end;

dPhi  = cell(size(K));
d2Phi = cell(nkern);
for i=1:nkern,
    d2Phi{i,i} = nu(i)*Kz{i}.*Phi.*(1+nu(i)*Kz{i});
    for j=(i+1):nkern,
        d2Phi{i,j} = (nu(i)*nu(j))*Kz{i}.*Kz{j}.*Phi;
        d2Phi{j,i} = d2Phi{i,j};
    end;
    dPhi{i} = nu(i)*Kz{i}.*Phi;
end;
return;
0429
0430
0431
function [nu,ll] = re_estimate_nu(K,t,nu,alpha,beta,krn_f,dkrn_f,nz)
% RE_ESTIMATE_NU  Levenberg-Marquardt optimisation of the kernel weights.
%   [nu,ll] = re_estimate_nu(K,t,nu,alpha,beta,krn_f,dkrn_f,nz)
%   Minimises the negative log evidence over log(nu) with the weight
%   precisions alpha and the noise precision beta held fixed.
%   K      - cell array of kernel matrices, passed to krn_f / dkrn_f
%   t      - targets (presumably N x 1)
%   nu     - current (positive) kernel weights
%   alpha  - precisions of all weights; alpha(nz) is used
%   beta   - noise precision
%   krn_f  - handle: Phi = krn_f(nu,K,nz)
%   dkrn_f - handle: [dPhi,d2Phi] = dkrn_f(nu,K,nz), derivatives of Phi
%            with respect to log(nu)
%   nz     - optional logical mask of retained basis functions; defaults to
%            all size(K{1},2)+1 columns
%   Returns the final nu and the negative log evidence ll at that nu.

if nargin<8, nz = true(size(K{1},2)+1,1); end;

lam = 1e-6;
a   = alpha(nz);
Phi = feval(krn_f,nu,K,nz);
N   = size(Phi,1);
% Same objective formula for the starting point and every candidate, so
% the accept/reject comparisons below are consistent.
[ll,S,w] = nu_neg_log_evidence(Phi,t,a,beta,N);

for subit=1:30,
    % Gradient g and Hessian H of the objective w.r.t. log(nu)
    g = zeros(numel(K),1);
    H = zeros(numel(K));
    [dPhi,d2Phi] = feval(dkrn_f,nu,K,nz);
    for i=1:numel(K),
        tmp1 = Phi'*dPhi{i};
        tmp1 = tmp1+tmp1';
        g(i) = 0.5*beta*(sum(sum(S.*tmp1)) + w'*tmp1*w - 2*t'*dPhi{i}*w);
        for j=i:numel(K),
            tmp = dPhi{j}'*dPhi{i} + Phi'*d2Phi{i,j};
            tmp = tmp+tmp';
            tmp2 = Phi'*dPhi{j};
            tmp2 = tmp2+tmp2';
            H(i,j) = sum(sum(S.*(tmp - tmp1*S*tmp2))) + w'*tmp*w - 2*w'*d2Phi{i,j}'*t;
            H(i,j) = 0.5*beta*H(i,j);
            H(j,i) = H(i,j);
        end;
    end;

    oll = ll;
    onu = nu;

    % Make H+lam*I positive definite before the damped Newton step
    lam = max(lam,-real(min(eig(H)))*1.5);

    for subsubit=1:30,
        drawnow;

        % Suppress (and afterwards restore) warnings from near-singular
        % solves without clobbering the caller's warning settings.
        wstate = warning;
        warning('off','all');
        nu = exp(log(nu) - (H+lam*speye(size(H)))\g);
        warning(wstate);

        % Keep nu within a bounded dynamic range
        nu = max(max(nu,1e-12),max(nu)*1e-9);
        nu = min(min(nu,1e12) ,min(nu)*1e9);

        Phi1 = feval(krn_f,nu,K,nz);
        wstate = warning;
        warning('off','all');
        [ll,S1,w1] = nu_neg_log_evidence(Phi1,t,a,beta,N);
        warning(wstate);

        if abs(ll-oll)<1e-9,
            nu = onu;
            ll = oll;
            break;
        end;

        % Reject the step when it is worse OR undefined (NaN compares false).
        if ~(ll<=oll),
            lam = lam*10;
            nu = onu;
            ll = oll;
        else
            lam = max(lam/10,1e-12);
            Phi = Phi1;
            S = S1;
            w = w1;
            break;
        end;
    end;
    if abs(ll-oll)<1e-9*N, break; end;
end;
return;



function [ll,S,w] = nu_neg_log_evidence(Phi,t,a,beta,N)
% NU_NEG_LOG_EVIDENCE  Negative log evidence for fixed Phi, precisions a, beta.
%   Also returns the posterior covariance S and mean w of the weights. The
%   2*pi constant is 0.5*N*log(2*pi) (PRML eq. 7.85), independent of the
%   number of retained weights.
S = inv(Phi'*Phi*beta+diag(a));
w = beta*S*Phi'*t;
res = t-Phi*w;
ll = ...
    -0.5*sum(log(a+eps))-0.5*N*log(beta+eps)-0.5*logdet(S)...
    +0.5*(res'*res)*beta + 0.5*sum(w.^2.*a)...
    +0.5*N*log(2*pi);
return;
0527
0528
0529
function [ld,C] = logdet(A)
% LOGDET  Log-determinant via a Cholesky factorisation.
%   [ld,C] = logdet(A) symmetrises A, returns its upper Cholesky factor C
%   (chol raises an error if the symmetrised A is not positive definite)
%   and ld = 2*sum(log(diag(C))), with each diagonal entry floored at eps.
C  = chol(0.5*(A+A'));
ld = 2*sum(log(max(diag(C),eps)));
0535