Home > . > prt_permutation.m

prt_permutation

PURPOSE ^

Function to compute permutation test

SYNOPSIS ^

function [] = prt_permutation(PRT, n_perm, modelid, path)

DESCRIPTION ^

 Function to compute permutation test

 Inputs:
 -------
 PRT: PRT structure including the model
 n_perm: number of permutations
 modelid: model ID
 path: directory containing PRT.mat (the updated PRT.mat is saved there)

 Outputs:
 --------

 for classification
 permutation.c_acc: Permuted accuracy per class
 permutation.b_acc: Permuted balanced accuracy
 permutation.pvalue_b_acc: p-value for b_acc
 permutation.pvalue_c_acc: p-value for c_acc

 for regression
 permutation.corr: Permuted correlation
 permutation.mse: Permuted mean square error
 permutation.pval_corr: p-value for corr
 permutation.pval_mse: p-value for mse
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function [] = prt_permutation(PRT, n_perm, modelid, path)
% Function to compute permutation test
%
% Re-runs the cross-validation of model 'modelid' n_perm times with the
% targets shuffled at the level of chunks (samples sharing the same group,
% subject, modality, condition and block), and compares every permuted run
% with the statistics already stored in PRT.model(modelid).output.stats.
%
% Inputs:
% -------
% PRT:     PRT structure including the estimated model
% n_perm:  number of permutations (positive integer)
% modelid: model ID (index into PRT.model)
% path:    directory containing PRT.mat and the kernel files; the updated
%          PRT.mat is saved back to this directory
%
% Outputs (stored in PRT.model(modelid).output.stats.permutation):
% --------
%
% for classification
% permutation.c_acc:        Permuted accuracy per class (n_class x n_perm)
% permutation.b_acc:        Permuted balanced accuracy (1 x n_perm)
% permutation.pvalue_b_acc: p-value for b_acc
% permutation.pvalue_c_acc: p-value for c_acc (n_class x 1)
%
% for regression
% permutation.corr:      Permuted correlation (1 x n_perm)
% permutation.mse:       Permuted mean square error (1 x n_perm)
% permutation.pval_corr: p-value for corr
% permutation.pval_mse:  p-value for mse
%
% Each p-value is the fraction of permutations that did strictly better
% than the unpermuted model (higher accuracy/correlation, lower MSE); a
% p-value of 0 is replaced by 1/n_perm.
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by J. Mourao-Miranda
% $Id: prt_permutation.m 518 2012-04-17 10:01:12Z cphillip $

prt_dir = path;

if ~isfield(PRT,'model')
    beep
    disp('No model found in this PRT.mat');
    return
end
if ~isfield(PRT.model,'output')
    beep
    disp('No model output found in this PRT.mat')
    return
end
if ~isnumeric(n_perm) || ~isscalar(n_perm) || n_perm < 1 || n_perm ~= fix(n_perm)
    error('prt_permutation:badNPerm', 'n_perm must be a positive integer');
end

% configure some variables
minput     = PRT.model(modelid).input;
obs        = PRT.model(modelid).output.stats;   % unpermuted model statistics
mtype      = PRT.model(modelid).output.fold(1).type;
CV         = minput.cv_mat;                     % CV matrix
n_folds    = size(CV,2);                        % number of CV folds
n_Phi      = length(minput.fs);                 % number of data matrices
samp_idx   = minput.samp_idx;                   % which samples are in the model
targets    = minput.targets;                    % unpermuted targets
use_kernel = minput.use_kernel;
% operations applied in every fold (entries equal to 0 are skipped)
ops        = minput.operations(minput.operations ~= 0);

% Initialize counts (fail early on an unsupported model type)
% -------------------------------------------------------------------------
switch mtype
    case 'classifier'
        n_class = length(PRT.model(modelid).output.fold(1).stats.c_acc);
        total_greater_c_acc = zeros(n_class,1);
        total_greater_b_acc = 0;
    case 'regression'
        total_greater_corr = 0;
        total_greater_mse  = 0;
    otherwise
        error('prt_permutation:unknownType', ...
            'Unknown model type ''%s''', mtype);
end

% load data files and configure ID matrix
% -------------------------------------------------------------------------
Phi_all = cell(1,n_Phi);
for i = 1:n_Phi
    % The i-th feature set of the model is not necessarily PRT.fs(i):
    % look it up by name.
    fid = find(strcmp({PRT.fs(:).fs_name}, minput.fs(i).fs_name), 1);
    if isempty(fid)
        error('prt_permutation:fsNotFound', ...
            'Feature set ''%s'' not found in PRT.fs', minput.fs(i).fs_name);
    end
    if i == 1
        ID = PRT.fs(fid).id_mat(samp_idx,:);
    end

    if use_kernel
        % load into a structure so the file cannot overwrite local variables
        kern = load(fullfile(prt_dir, PRT.fs(fid).k_file));
        Phi_all{i} = kern.Phi(samp_idx,samp_idx);
    else
        error('training with features not implemented yet');
    end
end

% Find chunks in the data (e.g. temporally correlated samples)
% -------------------------------------------------------------------------
ids    = PRT.fs(fid).id_mat(samp_idx,:);
chunks = find_chunks(ids);

% Run model with permuted labels
% -------------------------------------------------------------------------
for p = 1:n_perm

    % permute labels, always starting from the unpermuted targets
    t = permute_targets(targets, chunks);

    fold_pred = cell(n_folds,1);
    fold_targ = cell(n_folds,1);
    for f = 1:n_folds
        % configure training and test indices
        tr_idx = CV(:,f) == 1;
        te_idx = CV(:,f) == 2;

        [Phi_tr, Phi_te, Phi_tt] = ...
            split_data(Phi_all, tr_idx, te_idx, use_kernel);

        % Assemble data structure to supply to machine
        cvdata.train      = Phi_tr;
        cvdata.test       = Phi_te;
        if use_kernel
            cvdata.testcov = Phi_tt;
        end
        cvdata.tr_targets = t(tr_idx,:);
        cvdata.te_targets = t(te_idx,:);
        cvdata.tr_id      = ID(tr_idx,:);
        cvdata.te_id      = ID(te_idx,:);
        cvdata.use_kernel = use_kernel;
        cvdata.pred_type  = minput.type;
        % additional parameters (e.g. for MCKR)
        cvdata.tr_param   = prt_cv_opt_param(PRT, ID(tr_idx,:), CV(tr_idx,f));
        cvdata.te_param   = prt_cv_opt_param(PRT, ID(te_idx,:), CV(te_idx,f));

        % Apply any operations specified
        for o = 1:length(ops)
            cvdata = prt_apply_operation(PRT, cvdata, ops(o));
        end

        % train the prediction model; a failing fold contributes zero
        % predictions for this permutation instead of aborting the test
        try
            temp_model   = prt_machine(cvdata, minput.machine);
            fold_pred{f} = temp_model.predictions(:);
        catch err
            warning('prt_permutation:modelDidNotReturn',...
                'Prediction method did not return [%s]',err.message);
            fold_pred{f} = zeros(numel(cvdata.te_targets),1);
        end
        fold_targ{f} = cvdata.te_targets(:);
    end

    % Model level statistics (concatenation handles unequal fold sizes)
    m.type        = mtype;
    m.predictions = cat(1, fold_pred{:});
    perm_stats    = prt_stats(m, cat(1, fold_targ{:}), m.predictions);

    switch mtype
        case 'classifier'
            permutation.b_acc(p) = perm_stats.b_acc;
            if perm_stats.b_acc > obs.b_acc
                total_greater_b_acc = total_greater_b_acc + 1;
            end
            for c = 1:n_class
                permutation.c_acc(c,p) = perm_stats.c_acc(c);
                if perm_stats.c_acc(c) > obs.c_acc(c)
                    total_greater_c_acc(c) = total_greater_c_acc(c) + 1;
                end
            end

        case 'regression'
            permutation.corr(p) = perm_stats.corr;
            if perm_stats.corr > obs.corr
                total_greater_corr = total_greater_corr + 1;
            end
            % for the MSE, "better than observed" means smaller
            permutation.mse(p) = perm_stats.mse;
            if perm_stats.mse < obs.mse
                total_greater_mse = total_greater_mse + 1;
            end
    end
end

% p-values
% -------------------------------------------------------------------------
switch mtype
    case 'classifier'
        permutation.pvalue_b_acc = perm_pvalue(total_greater_b_acc, n_perm);
        permutation.pvalue_c_acc = perm_pvalue(total_greater_c_acc, n_perm);
    case 'regression'
        permutation.pval_corr = perm_pvalue(total_greater_corr, n_perm);
        permutation.pval_mse  = perm_pvalue(total_greater_mse, n_perm);
end

%update PRT
PRT.model(modelid).output.stats.permutation = permutation;

% Save PRT containing machine output
% -------------------------------------------------------------------------
outfile = fullfile(path,'PRT.mat');
disp('Updating PRT.mat.......>>')
if spm_matlab_version_chk('7') >= 0
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
disp('Permutation test done.')
end

function chunks = find_chunks(ids)
% Group rows of the ID matrix that share the same values in columns 1-5
% (group, subject, modality, condition, block) into chunks.
% Chunks are ordered lexicographically by those columns and each chunk is a
% column vector of ascending row indices.
[uids, first_idx, chunk_id] = unique(ids(:,1:5), 'rows'); %#ok<ASGLU>
chunks = cell(1, size(uids,1));
for k = 1:numel(chunks)
    chunks{k} = find(chunk_id == k);
end
end

function t_perm = permute_targets(targets, chunks)
% Shuffle targets at the chunk level: every sample of chunk k receives the
% target(s) of a randomly chosen chunk, so samples in a chunk stay together.
% NOTE(review): assumes each chunk has a single target value; otherwise the
% assignment fails unless the chunk sizes happen to match.
t_perm = targets;
order  = randperm(numel(chunks));
for k = 1:numel(chunks)
    t_perm(chunks{k},1) = unique(targets(chunks{order(k)}));
end
end

function pval = perm_pvalue(n_better, n_perm)
% Fraction of permutations that beat the observed statistic (elementwise).
% A zero count is replaced by 1/n_perm, the smallest p-value that n_perm
% permutations can resolve.
pval = n_better ./ n_perm;
pval(pval == 0) = 1 ./ n_perm;
end
0296 % -------------------------------------------------------------------------
0297 % Private functions
0298 % -------------------------------------------------------------------------
0299 
function [Phi_tr, Phi_te, Phi_tt] = split_data(Phi_all, tr_idx, te_idx, usebf)
% Split every data matrix into its training and test parts.
%
% Inputs:
%   Phi_all - cell array of data matrices whose rows are samples
%   tr_idx  - row selector (e.g. logical vector) of the training samples
%   te_idx  - row selector of the test samples
%   usebf   - true if the matrices are kernels (columns are samples too),
%             false if they are feature matrices (columns are features)
%
% Outputs (cell arrays with one entry per input matrix):
%   Phi_tr - kernel: K(tr,tr)   features: X(tr,:)
%   Phi_te - kernel: K(te,tr)   features: X(te,:)
%   Phi_tt - kernel: K(te,te)   features: []

n_mat  = numel(Phi_all);
Phi_tr = cell(1,n_mat);
Phi_te = cell(1,n_mat);
Phi_tt = cell(1,n_mat);

for i = 1:n_mat
    if usebf
        % kernel: columns index samples, so select them like the rows
        Phi_tr{i} = Phi_all{i}(tr_idx, tr_idx);
        Phi_te{i} = Phi_all{i}(te_idx, tr_idx);
        Phi_tt{i} = Phi_all{i}(te_idx, te_idx);
    else
        % features: keep every column (previously only the last one)
        Phi_tr{i} = Phi_all{i}(tr_idx, :);
        Phi_te{i} = Phi_all{i}(te_idx, :);
        Phi_tt{i} = [];
    end
end
end

Generated on Sun 20-May-2012 13:24:48 by m2html © 2005