function [] = prt_permutation(PRT, n_perm, modelid, path, flag)
% PRT_PERMUTATION Permutation test for a model that has already been estimated.
%
% FORMAT prt_permutation(PRT, n_perm, modelid, path, flag)
%
% The cross-validation of model MODELID is repeated N_PERM times. Samples
% are grouped into chunks (samples sharing the same values in the first
% five columns of PRT.fs(fid).id_mat) and whole chunks are shuffled: in
% every permutation each chunk receives the (unique) target value and the
% cross-validation rows of another, randomly chosen chunk. The performance
% of each permuted run is compared with PRT.model(modelid).output.stats and
% p-values are derived from how often the permuted run is at least as good.
%
% Inputs:
%   PRT     - PRT structure containing PRT.model(modelid).output
%   n_perm  - number of permutations (positive integer)
%   modelid - index of the model in PRT.model
%   path    - directory holding PRT.mat; the updated PRT is saved there
%   flag    - optional, default 0. If true, the alpha and predictions of
%             every permutation p and fold f are kept in
%             PRT.model(modelid).output.permutation(p).fold(f) (plus the
%             nested-CV output, field param_effect, when nested CV is used)
%             and any permutation results already stored there are
%             discarded first.
%
% Results are stored in PRT.model(modelid).output.stats.permutation:
%   classifier: b_acc (1 x n_perm), c_acc (n_class x n_perm),
%               pvalue_b_acc, pvalue_c_acc (n_class x 1)
%   regression: corr, mse, nmse, r2 (1 x n_perm),
%               pval_corr, pval_mse, pval_nmse, pval_r2
% "At least as good" means >= for b_acc, c_acc, corr and r2, and <= for
% mse and nmse. A p-value of 0 is replaced by 1/n_perm.
%
% The function has no outputs: PRT is written to fullfile(path,'PRT.mat').
% The estimated model itself (machine arguments, per-fold outputs) is not
% modified by the permutation runs.

prt_dir = path;
def_par = prt_get_defaults('paral');
if nargin < 5
    flag = 0;
end

% ---- Input checks --------------------------------------------------------
if ~isfield(PRT,'model')
    beep
    disp('No model found in this PRT.mat');
    return
end
if ~isnumeric(modelid) || ~isscalar(modelid) || modelid ~= fix(modelid) || ...
        modelid < 1 || modelid > numel(PRT.model)
    error('prt_permutation:badModelId', ...
        'modelid must be an integer between 1 and %d.', numel(PRT.model));
end
% The selected model must have been estimated (folds and overall stats).
if ~isfield(PRT.model,'output') || ...
        ~isfield(PRT.model(modelid).output,'fold') || ...
        isempty(PRT.model(modelid).output.fold) || ...
        ~isfield(PRT.model(modelid).output,'stats')
    beep
    disp('No model output found in this PRT.mat')
    return
end
if ~isnumeric(n_perm) || ~isscalar(n_perm) || ~isfinite(n_perm) || ...
        n_perm < 1 || n_perm ~= fix(n_perm)
    error('prt_permutation:badNPerm', 'n_perm must be a positive integer.');
end
mtype = PRT.model(modelid).output.fold(1).type;
if ~any(strcmp(mtype, {'classifier','regression'}))
    error('prt_permutation:badType', ...
        'Unsupported model output type for a permutation test.');
end

CV      = PRT.model(modelid).input.cv_mat;
n_folds = size(CV,2);

% ---- Optional pool of MATLAB workers (best effort) -----------------------
pool_opened = false;
if def_par.allow
    try
        matlabpool(def_par.ncore)
        pool_opened = true;
    catch
        warning('Could not use pool of Matlab processes!')
    end
end
if pool_opened
    % Close only the pool opened here, when this function exits
    % (normally or through an error).
    pool_cleanup = onCleanup(@() matlabpool('close')); %#ok<NASGU>
end

% Original (unpermuted) targets; never reassigned below.
t = PRT.model(modelid).input.targets;

[Phi_all, ID, fid] = prt_getKernelModel(PRT, prt_dir, modelid);

if strcmpi(PRT.model(modelid).input.type,'classification')
    nc = max(unique(t));
else
    nc = [];
end

% Fields of fdata that do not change across permutations/folds.
fdata.nc      = nc;
fdata.ID      = ID;
fdata.mid     = modelid;
fdata.Phi_all = Phi_all;

% ---- Chunks of samples that are permuted as a unit ------------------------
ids    = PRT.fs(fid).id_mat(PRT.model(modelid).input.samp_idx,:);
chunks = local_get_chunks(ids);
if isempty(chunks)
    error('prt_permutation:noSamples', 'No samples found for this model.');
end
n_chunks = numel(chunks);

% ---- Counters of permutations at least as good as the real model ---------
switch mtype
    case 'classifier'
        n_class = length(PRT.model(modelid).output.fold(1).stats.c_acc);
        total_greater_c_acc = zeros(n_class,1);
        total_greater_b_acc = 0;
    case 'regression'
        total_greater_corr = 0;
        total_greater_mse  = 0;
        total_greater_nmse = 0;
        total_greater_r2   = 0;
end

% Reset stored per-permutation results when none exist or when flag asks
% for them to be (re)computed.
if ~isfield(PRT.model(modelid).output,'permutation') || flag
    PRT.model(modelid).output.permutation = struct('fold',[]);
end

% Reference performance of the real (unpermuted) model.
ref_stats = PRT.model(modelid).output.stats;

for p = 1:n_perm

    fprintf('Permutation %d out of %d >>>>>>\n', p, n_perm);

    % Chunk i receives the target and the CV rows of chunk chunkperm(i).
    chunkperm = randperm(n_chunks);
    CVperm    = zeros(size(CV));
    t_perm    = zeros(length(t),1);
    for i = 1:n_chunks
        src = chunks{chunkperm(i)};
        t_perm(chunks{i},1) = unique(t(src));
        CVperm(chunks{i},:) = CV(src,:);
    end

    fold_pred = cell(n_folds,1);
    fold_targ = cell(n_folds,1);
    for f = 1:n_folds

        fdata.CV = CVperm(:,f);
        fdata.t  = t_perm;

        % Work on a per-fold copy so nested-CV hyper-parameters of this
        % permuted fold never leak into the saved model or into other folds.
        PRT_fold   = PRT;
        nested_out = [];
        if isfield(PRT.model(modelid).input,'use_nested_cv')
            if PRT.model(modelid).input.use_nested_cv
                nested_out = prt_nested_cv(PRT, fdata);
                PRT_fold.model(modelid).input.machine.args = nested_out.opt_param;
            end
        end

        [temp_model, targets] = prt_cv_fold(PRT_fold, fdata);

        if flag
            PRT.model(modelid).output.permutation(p).fold(f).alpha = temp_model.alpha;
            PRT.model(modelid).output.permutation(p).fold(f).pred  = temp_model.predictions;
            if ~isempty(nested_out)
                PRT.model(modelid).output.permutation(p).fold(f).param_effect = nested_out;
            end
        end

        fold_pred{f} = temp_model.predictions;
        fold_targ{f} = targets.test;
    end

    % Performance of this permutation over all folds.
    t_test        = vertcat(fold_targ{:});
    m.type        = mtype;
    m.predictions = vertcat(fold_pred{:});
    perm_stats    = prt_stats(m, t_test, t_test);

    switch mtype
        case 'classifier'
            permutation.b_acc(p) = perm_stats.b_acc;
            if perm_stats.b_acc >= ref_stats.b_acc
                total_greater_b_acc = total_greater_b_acc + 1;
            end
            for c = 1:n_class
                permutation.c_acc(c,p) = perm_stats.c_acc(c);
                if perm_stats.c_acc(c) >= ref_stats.c_acc(c)
                    total_greater_c_acc(c) = total_greater_c_acc(c) + 1;
                end
            end

        case 'regression'
            permutation.corr(p) = perm_stats.corr;
            if perm_stats.corr >= ref_stats.corr
                total_greater_corr = total_greater_corr + 1;
            end
            permutation.mse(p) = perm_stats.mse;
            if perm_stats.mse <= ref_stats.mse
                total_greater_mse = total_greater_mse + 1;
            end
            permutation.nmse(p) = perm_stats.nmse;
            if perm_stats.nmse <= ref_stats.nmse
                total_greater_nmse = total_greater_nmse + 1;
            end
            permutation.r2(p) = perm_stats.r2;
            if perm_stats.r2 >= ref_stats.r2
                total_greater_r2 = total_greater_r2 + 1;
            end
    end
end

% ---- p-values -------------------------------------------------------------
switch mtype
    case 'classifier'
        permutation.pvalue_b_acc = local_pvalue(total_greater_b_acc, n_perm);
        permutation.pvalue_c_acc = local_pvalue(total_greater_c_acc, n_perm);

    case 'regression'
        permutation.pval_corr = local_pvalue(total_greater_corr, n_perm);
        permutation.pval_mse  = local_pvalue(total_greater_mse,  n_perm);
        permutation.pval_nmse = local_pvalue(total_greater_nmse, n_perm);
        permutation.pval_r2   = local_pvalue(total_greater_r2,   n_perm);
end

PRT.model(modelid).output.stats.permutation = permutation;

% ---- Save -----------------------------------------------------------------
outfile = fullfile(prt_dir,'PRT.mat');
disp('Updating PRT.mat.......>>')
if spm_check_version('MATLAB','7') < 0
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
disp('Permutation test done.')

end

function chunks = local_get_chunks(ids)
% Group sample indices by identical values in columns 1-5 of ids.
% (In the original loops these columns were named g, s, m, c, b -
% presumably group/subject/modality/condition/block; confirm against
% prt id_mat.) Chunks follow the lexicographic order of those rows and
% each chunk is a column vector of ascending row indices into ids.
if isempty(ids)
    chunks = {};
    return
end
[keys, first_idx, key_of_row] = unique(ids(:,1:5), 'rows'); %#ok<ASGLU>
chunks = cell(1, size(keys,1));
for k = 1:size(keys,1)
    chunks{k} = find(key_of_row == k);
end
end

function pval = local_pvalue(n_greater, n_perm)
% Fraction of permutations at least as good as the real model; a zero
% count is floored to 1/n_perm (element-wise for vector counts).
% NOTE(review): the more common estimator is (n_greater+1)/(n_perm+1);
% the floor is kept for compatibility with earlier results.
pval = n_greater / n_perm;
pval(pval == 0) = 1 ./ n_perm;
end
0299