function [] = prt_permutation(PRT, n_perm, modelid, path)
% PRT_PERMUTATION  Permutation test for a model stored in a PRT structure.
%
%   prt_permutation(PRT, n_perm, modelid, path)
%
% Re-runs the cross-validation of PRT.model(modelid) n_perm times with the
% targets shuffled between "chunks" of samples (rows of id_mat sharing the
% same values in columns 1-5 -- presumably group/subject/modality/
% condition/block; TODO confirm). Each permutation's performance is
% compared with the stored PRT.model(modelid).output.stats. The
% permutation statistics and p-values are stored in
% PRT.model(modelid).output.stats.permutation, and PRT.mat in PATH is
% overwritten with the updated structure.
%
% Inputs:
%   PRT     - PRT structure with a trained model (PRT.model(...).output)
%   n_perm  - number of permutations, a positive integer
%   modelid - index of the model in PRT.model
%   path    - directory holding PRT.mat and the feature-set kernel files
%
% Classifier models get the fields:
%   b_acc (1 x n_perm), c_acc (n_class x n_perm), pvalue_b_acc, pvalue_c_acc.
% Regression models get the fields:
%   corr, mse, pval_corr, pval_mse.
% A p-value is the fraction of permutations that strictly beat the
% original model. A fraction of zero is reported as 1/n_perm.
%
% Only kernel-based models (input.use_kernel) are supported.

prt_dir = path;

if ~isfield(PRT, 'model')
    beep
    disp('No model found in this PRT.mat');
    return
end
if ~isfield(PRT.model, 'output')
    beep
    disp('No model output found in this PRT.mat')
    return
end
if ~(isnumeric(n_perm) && isscalar(n_perm) && n_perm >= 1 && n_perm == round(n_perm))
    error('prt_permutation:badNPerm', ...
        'Number of permutations must be a positive integer.');
end

minput    = PRT.model(modelid).input;
mtype     = PRT.model(modelid).output.fold(1).type;
ref_stats = PRT.model(modelid).output.stats;

CV         = minput.cv_mat;
n_folds    = size(CV, 2);
n_Phi      = length(minput.fs);
samp_idx   = minput.samp_idx;
use_kernel = minput.use_kernel;
targets    = minput.targets;   % original targets, source of every shuffle
t          = targets;          % permuted targets, aligned with samp_idx

% ---- Load the (kernel) data of every feature set used by the model ----
Phi_all = cell(1, n_Phi);
ID = [];
for i = 1:n_Phi
    % The model's i-th feature set is looked up by name; it need not be
    % PRT.fs(i).
    fid = find(strcmp({PRT.fs(:).fs_name}, minput.fs(i).fs_name), 1);
    if isempty(fid)
        error('prt_permutation:fsNotFound', ...
            'Feature set "%s" not found in PRT.fs.', minput.fs(i).fs_name);
    end
    if i == 1
        ID = PRT.fs(fid).id_mat(samp_idx, :);
    end
    if use_kernel
        % Load into a struct so the file cannot overwrite local variables
        % and a missing 'Phi' cannot silently reuse a previous kernel.
        S = load(fullfile(prt_dir, PRT.fs(fid).k_file), 'Phi');
        Phi_all{i} = S.Phi(samp_idx, samp_idx);
    else
        error('training with features not implemented yet');
    end
end

% ---- Chunks: sets of samples that always receive the same target ----
chunks   = perm_chunks(ID);
n_chunks = numel(chunks);

% ---- Counters and preallocated permutation results ----
switch mtype
    case 'classifier'
        n_class = length(PRT.model(modelid).output.fold(1).stats.c_acc);
        total_greater_c_acc = zeros(n_class, 1);
        total_greater_b_acc = 0;
        permutation.b_acc = zeros(1, n_perm);
        permutation.c_acc = zeros(n_class, n_perm);
    case 'regression'
        total_greater_corr = 0;
        total_greater_mse  = 0;
        permutation.corr = zeros(1, n_perm);
        permutation.mse  = zeros(1, n_perm);
    otherwise
        error('prt_permutation:unknownType', ...
            'Unknown model type "%s".', mtype);
end

ops = minput.operations(minput.operations ~= 0);

for p = 1:n_perm

    % Shuffle targets between chunks: every sample of chunk k receives the
    % target value of chunk chunkperm(k) (chunks are assumed to carry a
    % single target value each).
    chunkperm = randperm(n_chunks);
    for k = 1:n_chunks
        t(chunks{k}, 1) = unique(targets(chunks{chunkperm(k)}));
    end

    fold_pred = cell(n_folds, 1);
    fold_targ = cell(n_folds, 1);

    for f = 1:n_folds
        tr_idx = CV(:, f) == 1;
        te_idx = CV(:, f) == 2;

        [Phi_tr, Phi_te, Phi_tt] = ...
            split_data(Phi_all, tr_idx, te_idx, use_kernel);

        % Fresh struct per fold so nothing leaks from a previous fold.
        cvdata = struct();
        cvdata.train = Phi_tr;
        cvdata.test  = Phi_te;
        if use_kernel
            cvdata.testcov = Phi_tt;
        end
        cvdata.tr_targets = t(tr_idx, :);
        cvdata.te_targets = t(te_idx, :);
        cvdata.tr_id      = ID(tr_idx, :);
        cvdata.te_id      = ID(te_idx, :);
        cvdata.use_kernel = use_kernel;
        cvdata.pred_type  = minput.type;
        cvdata.tr_param   = prt_cv_opt_param(PRT, ID(tr_idx,:), CV(tr_idx,f));
        cvdata.te_param   = prt_cv_opt_param(PRT, ID(te_idx,:), CV(te_idx,f));

        for o = 1:length(ops)
            cvdata = prt_apply_operation(PRT, cvdata, ops(o));
        end

        try
            temp_model = prt_machine(cvdata, minput.machine);
            fold_pred{f} = temp_model.predictions;
        catch err
            warning('prt_permutation:modelDidNotReturn',...
                'Prediction method did not return [%s]',err.message);
            % Record zero predictions so this fold still has one prediction
            % per test target (and no stale value from another permutation).
            fold_pred{f} = zeros(size(cvdata.te_targets));
        end
        fold_targ{f} = cvdata.te_targets;
    end

    % Stats over all folds. A separate variable is used so the permuted
    % target vector t keeps its sample alignment for the next permutation.
    m.type = mtype;
    m.predictions = cat(1, fold_pred{:});
    m.predictions = m.predictions(:);
    t_te = cat(1, fold_targ{:});
    t_te = t_te(:);
    perm_stats = prt_stats(m, t_te, m.predictions);

    switch mtype
        case 'classifier'
            permutation.b_acc(p) = perm_stats.b_acc;
            if perm_stats.b_acc > ref_stats.b_acc
                total_greater_b_acc = total_greater_b_acc + 1;
            end
            for c = 1:n_class
                permutation.c_acc(c, p) = perm_stats.c_acc(c);
                if perm_stats.c_acc(c) > ref_stats.c_acc(c)
                    total_greater_c_acc(c) = total_greater_c_acc(c) + 1;
                end
            end

        case 'regression'
            permutation.corr(p) = perm_stats.corr;
            if perm_stats.corr > ref_stats.corr
                total_greater_corr = total_greater_corr + 1;
            end
            permutation.mse(p) = perm_stats.mse;
            % Lower MSE is better: count permutations beating the original.
            if perm_stats.mse < ref_stats.mse
                total_greater_mse = total_greater_mse + 1;
            end
    end
end

% ---- p-values ----
switch mtype
    case 'classifier'
        permutation.pvalue_b_acc = perm_pvalue(total_greater_b_acc, n_perm);
        permutation.pvalue_c_acc = perm_pvalue(total_greater_c_acc, n_perm);
    case 'regression'
        permutation.pval_corr = perm_pvalue(total_greater_corr, n_perm);
        permutation.pval_mse  = perm_pvalue(total_greater_mse, n_perm);
end

PRT.model(modelid).output.stats.permutation = permutation;

outfile = fullfile(path,'PRT.mat');
disp('Updating PRT.mat.......>>')
if spm_matlab_version_chk('7') >= 0
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
disp('Permutation test done.')
end


function chunks = perm_chunks(ids)
% PERM_CHUNKS  Group rows of ids by their values in columns 1-5.
% Returns a 1 x K cell array. Each cell holds the ascending (column) row
% indices of one distinct combination of those values, and the cells are
% ordered lexicographically by that combination.
if isempty(ids)
    chunks = {};
    return
end
[u_rows, u_idx, grp] = unique(ids(:, 1:5), 'rows'); %#ok<ASGLU>
n_chunks = size(u_rows, 1);
chunks = cell(1, n_chunks);
for k = 1:n_chunks
    chunks{k} = find(grp == k);
end
end


function pval = perm_pvalue(n_greater, n_perm)
% PERM_PVALUE  Fraction of the n_perm permutations that beat the original
% model. A fraction of zero is reported as 1/n_perm rather than 0.
% Works element-wise on a vector of counts.
pval = n_greater ./ n_perm;
pval(pval == 0) = 1 ./ n_perm;
end
0296
0297
0298
0299
function [Phi_tr, Phi_te, Phi_tt] = split_data(Phi_all, tr_idx, te_idx, usebf)
% SPLIT_DATA  Split each data matrix into training and test parts.
%
% Inputs:
%   Phi_all - cell array of data matrices with one row per sample. When
%             usebf is true they are kernels whose columns also index
%             samples; otherwise they are feature matrices.
%   tr_idx  - logical (or index) vector selecting the training samples
%   te_idx  - logical (or index) vector selecting the test samples
%   usebf   - true when Phi_all holds kernel matrices
%
% Outputs (cell arrays, same length as Phi_all):
%   Phi_tr - kernels: train x train block; features: training rows
%   Phi_te - kernels: test x train block; features: test rows
%   Phi_tt - kernels: test x test block; features: [] for every matrix
%
% Feature mode keeps ALL columns of each matrix.

n_mat  = numel(Phi_all);
Phi_tr = cell(1, n_mat);
Phi_te = cell(1, n_mat);
Phi_tt = cell(1, n_mat);

for i = 1:n_mat
    if usebf
        % Kernel: select samples along both rows and columns.
        Phi_tr{i} = Phi_all{i}(tr_idx, tr_idx);
        Phi_te{i} = Phi_all{i}(te_idx, tr_idx);
        Phi_tt{i} = Phi_all{i}(te_idx, te_idx);
    else
        % Features: select samples along rows only, keep every column.
        Phi_tr{i} = Phi_all{i}(tr_idx, :);
        Phi_te{i} = Phi_all{i}(te_idx, :);
        Phi_tt{i} = [];
    end
end
end