function [] = prt_permutation(PRT, n_perm, modelid, path)
% PRT_PERMUTATION  Permutation test for an already cross-validated model.
%
%   prt_permutation(PRT, n_perm, modelid, path)
%
%   Re-runs the cross-validation of model PRT.model(modelid) n_perm times.
%   Each run shuffles the targets between "chunks" of samples and compares
%   the resulting statistics with those already stored in
%   PRT.model(modelid).output.stats. The permutation results and p-values
%   are stored in PRT.model(modelid).output.stats.permutation, and PRT is
%   written back to fullfile(path,'PRT.mat'). Nothing is returned.
%
%   A chunk is the set of samples sharing the same values in columns 1-5
%   of the feature set's id_mat. These are presumably the group, subject,
%   modality, condition and block identifiers; confirm against the code
%   that builds id_mat. Whole chunks are relabelled, so every chunk must
%   carry a single target value; otherwise an error is raised.
%
%   Inputs:
%     PRT     - PRT structure with fields fs and model; the model must
%               already be estimated (PRT.model(modelid).output present).
%     n_perm  - number of permutations (positive integer).
%     modelid - index of the model in PRT.model.
%     path    - directory holding the kernel files (PRT.fs(...).k_file).
%               PRT.mat is (over)written in this directory.
%
%   Stored in PRT.model(modelid).output.stats.permutation:
%     classifier: b_acc (1 x n_perm), c_acc (n_class x n_perm),
%                 pvalue_b_acc, pvalue_c_acc (n_class x 1)
%     regression: corr, mse (1 x n_perm), pval_corr, pval_mse
%
%   P-values are computed as (number of permutations strictly better than
%   the observed statistic) / n_perm, floored at 1/n_perm. "Better" means:
%     - higher balanced or class accuracy (classifier);
%     - higher absolute correlation (regression);
%     - lower MSE (regression).
%
%   Only kernel-based models (input.use_kernel true) are supported.

% --- check that there is an estimated model to test ----------------------
if ~isfield(PRT,'model')
    beep
    disp('No model found in this PRT.mat');
    return
end
if ~isfield(PRT.model,'output')
    beep
    disp('No model output found in this PRT.mat')
    return
end
if ~isscalar(modelid) || modelid < 1 || modelid ~= fix(modelid) || ...
        modelid > numel(PRT.model)
    error('prt_permutation:badModelId', ...
        'modelid must be an integer between 1 and %d.', numel(PRT.model));
end
if isempty(PRT.model(modelid).output)
    beep
    disp('No model output found in this PRT.mat')
    return
end
if ~isscalar(n_perm) || n_perm < 1 || n_perm ~= fix(n_perm)
    error('prt_permutation:badNPerm', 'n_perm must be a positive integer.');
end

minput  = PRT.model(modelid).input;
moutput = PRT.model(modelid).output;
mtype   = moutput.fold(1).type;
obs     = moutput.stats;            % observed (unpermuted) statistics
CV      = minput.cv_mat;
n_folds = size(CV,2);
targets = minput.targets;

% --- initialise counters and result arrays (fails fast on unknown type) --
switch mtype
    case 'classifier'
        n_class             = length(moutput.fold(1).stats.c_acc);
        obs_c_acc           = obs.c_acc(:);
        obs_c_acc           = obs_c_acc(1:n_class);
        total_greater_b_acc = 0;
        total_greater_c_acc = zeros(n_class,1);
        permutation.b_acc   = zeros(1,n_perm);
        permutation.c_acc   = zeros(n_class,n_perm);
    case 'regression'
        total_greater_corr  = 0;
        total_lower_mse     = 0;
        permutation.corr    = zeros(1,n_perm);
        permutation.mse     = zeros(1,n_perm);
    otherwise
        error('prt_permutation:unknownType', ...
            'Unknown model output type "%s".', mtype);
end

% --- data shared by all permutations -------------------------------------
[Phi_all, ID] = load_kernels(PRT, minput, path);
chunks        = build_chunks(ID);
chunk_label   = chunk_targets(targets, chunks);
n_chunks      = numel(chunks);

fdata.ID      = ID;
fdata.mid     = modelid;
fdata.Phi_all = Phi_all;

for p = 1:n_perm
    fprintf('Permutation %d out of %d >>>>>>\n', p, n_perm);

    % Relabel whole chunks: chunk k receives the label of chunk order(k).
    % Always start from the original targets.
    order  = randperm(n_chunks);
    t_perm = targets;
    for k = 1:n_chunks
        t_perm(chunks{k},1) = chunk_label(order(k));
    end
    fdata.t = t_perm;

    % Re-run every cross-validation fold with the permuted targets.
    fold_pred = cell(n_folds,1);
    fold_targ = cell(n_folds,1);
    for f = 1:n_folds
        fdata.CV = CV(:,f);
        [fold_model, fold_t] = prt_cv_fold(PRT, fdata);
        fold_pred{f} = fold_model.predictions;
        fold_targ{f} = fold_t.test;
    end

    % Model-level statistics pooled over folds. The test targets go in
    % their own variable, so they never leak into the next permutation.
    t_test        = vertcat(fold_targ{:});
    m.type        = mtype;
    m.predictions = vertcat(fold_pred{:});
    perm_stats    = prt_stats(m, t_test, t_test);

    switch mtype
        case 'classifier'
            permutation.b_acc(p) = perm_stats.b_acc;
            if perm_stats.b_acc > obs.b_acc
                total_greater_b_acc = total_greater_b_acc + 1;
            end
            perm_c_acc = perm_stats.c_acc(:);
            perm_c_acc = perm_c_acc(1:n_class);
            permutation.c_acc(:,p) = perm_c_acc;
            total_greater_c_acc = total_greater_c_acc + ...
                (perm_c_acc > obs_c_acc);

        case 'regression'
            permutation.corr(p) = perm_stats.corr;
            if abs(perm_stats.corr) > abs(obs.corr)
                total_greater_corr = total_greater_corr + 1;
            end
            permutation.mse(p) = perm_stats.mse;
            % a lower MSE than observed counts as "better"
            if perm_stats.mse < obs.mse
                total_lower_mse = total_lower_mse + 1;
            end
    end
end

% --- p-values ------------------------------------------------------------
switch mtype
    case 'classifier'
        permutation.pvalue_b_acc = perm_pvalue(total_greater_b_acc, n_perm);
        permutation.pvalue_c_acc = perm_pvalue(total_greater_c_acc, n_perm);
    case 'regression'
        permutation.pval_corr = perm_pvalue(total_greater_corr, n_perm);
        permutation.pval_mse  = perm_pvalue(total_lower_mse, n_perm);
end

PRT.model(modelid).output.stats.permutation = permutation;

% --- save ----------------------------------------------------------------
outfile = fullfile(path,'PRT.mat');
disp('Updating PRT.mat.......>>')
if spm_matlab_version_chk('7') >= 0
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
disp('Permutation test done.')
end


function [Phi_all, ID] = load_kernels(PRT, minput, prt_dir)
% Load the kernel of every feature set used by the model, restricted to
% the model's samples. Also return the id_mat rows of the first one.
if ~minput.use_kernel
    error('training with features not implemented yet');
end
n_Phi = length(minput.fs);
if n_Phi == 0
    error('prt_permutation:noFeatureSet', ...
        'The model does not use any feature set.');
end
samp_idx = minput.samp_idx;
all_fs   = {PRT.fs(:).fs_name};
Phi_all  = cell(1,n_Phi);
ID       = [];
for i = 1:n_Phi
    fs_name = minput.fs(i).fs_name;
    % index into PRT.fs of the feature set the model refers to
    fid = find(strcmp(all_fs, fs_name));
    if numel(fid) ~= 1
        error('prt_permutation:fsNotFound', ...
            'Feature set "%s" not found (or not unique) in PRT.fs.', fs_name);
    end
    if i == 1
        ID = PRT.fs(fid).id_mat(samp_idx,:);
    end
    % Load into a struct rather than the workspace, so the kernel variable
    % that is read is explicit.
    K = load(fullfile(prt_dir, PRT.fs(fid).k_file));
    if ~isfield(K,'Phi')
        error('prt_permutation:noKernel', ...
            'No variable "Phi" in kernel file %s.', PRT.fs(fid).k_file);
    end
    Phi_all{i} = K.Phi(samp_idx,samp_idx);
end
end


function chunks = build_chunks(ids)
% Group row indices of ids into chunks of rows with identical values in
% columns 1-5. Chunks are ordered lexicographically by those values, and
% each chunk lists its row indices in ascending order.
[keys, first_row, chunk_of] = unique(ids(:,1:5), 'rows'); %#ok<ASGLU>
n_chunks = size(keys,1);
chunks   = cell(1,n_chunks);
for k = 1:n_chunks
    chunks{k} = find(chunk_of == k);
end
end


function chunk_label = chunk_targets(targets, chunks)
% Return the single target value carried by each chunk. Error if a chunk
% mixes several values, since whole-chunk relabelling is then undefined.
n_chunks    = numel(chunks);
chunk_label = zeros(n_chunks,1);
for k = 1:n_chunks
    lbl = unique(targets(chunks{k}));
    if numel(lbl) ~= 1
        error('prt_permutation:mixedTargets', ...
            'Samples in chunk %d do not share a single target value.', k);
    end
    chunk_label(k) = lbl;
end
end


function pval = perm_pvalue(n_better, n_perm)
% Fraction of permutations beating the observed statistic, floored at
% 1/n_perm (a zero count is reported as 1/n_perm). Works element-wise.
pval = max(n_better, 1) ./ n_perm;
end