Home > . > prt_permutation.m

prt_permutation

PURPOSE ^

Function to compute permutation test

SYNOPSIS ^

function [] = prt_permutation(PRT, n_perm, modelid, path, flag)

DESCRIPTION ^

 Function to compute permutation test

 Inputs:
 -------
 PRT:     PRT structure including the model
 n_perm:  number of permutations
 modelid: model ID (index into PRT.model)
 path:    path of the directory containing PRT.mat (the updated PRT.mat
          is saved there)
 flag:    boolean variable. Set to 1 to save the weights for each
          permutation. Default: 0

 Outputs:
 --------

 for classification
 permutation.c_acc:        Permuted accuracy per class
 permutation.b_acc:        Permuted balanced accuracy
 permutation.pvalue_b_acc: p-value for b_acc
 permutation.pvalue_c_acc: p-value for c_acc

 for regression
 permutation.corr: Permuted correlation
 permutation.mse:  Permuted mean square error
 permutation.nmse: Permuted nmse
 permutation.r2:   Permuted r2
 permutation.pval_corr: p-value for corr
 permutation.pval_r2:   p-value for r2
 permutation.pval_mse:  p-value for mse
 permutation.pval_nmse: p-value for nmse
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function [] = prt_permutation(PRT, n_perm, modelid, path, flag)
% Function to compute permutation test
%
% For each permutation the targets (and the matching rows of the
% cross-validation matrix) are shuffled chunk-wise, every fold is
% re-estimated with prt_cv_fold, and the resulting statistics are compared
% with those stored in PRT.model(modelid).output.stats.
% Nothing is returned: the results are written to
% PRT.model(modelid).output.stats.permutation and PRT.mat is saved in
% 'path'.
%
% Inputs:
% -------
% PRT:     PRT structure including the model
% n_perm:  number of permutations
% modelid: model ID (index into PRT.model)
% path:    directory containing the PRT.mat (the updated PRT.mat is saved
%          there)
% flag:    boolean variable. set to 1 to save the weights for each
%          permutation. default: 0
%
% Outputs (stored in PRT.model(modelid).output.stats.permutation):
% --------
%
% for classification
% permutation.c_acc:        Permuted accuracy per class
% permutation.b_acc:        Permuted balanced accuracy
% permutation.pvalue_b_acc: p-value for b_acc
% permutation.pvalue_c_acc: p-value for c_acc
%
% for regression
% permutation.corr: Permuted correlation
% permutation.mse:  Permuted mean square error
% permutation.nmse: Permuted nmse
% permutation.r2:   Permuted r2
% permutation.pval_corr: p-value for corr
% permutation.pval_r2:   p-value for r2
% permutation.pval_mse:  p-value for mse
% permutation.pval_nmse: p-value for nmse
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by J. Mourao-Miranda
% $Id$

prt_dir = path;
def_par = prt_get_defaults('paral');
if nargin<5
    flag=0;
end

% Nothing to permute without an estimated model
if ~isfield(PRT,'model')
    beep
    disp('No model found in this PRT.mat');
    return
end
if ~isfield(PRT.model,'output')
    beep
    disp('No model output found in this PRT.mat')
    return
end

% configure some variables
CV      = PRT.model(modelid).input.cv_mat;          % CV matrix
n_folds = size(CV,2);                               % number of CV folds
mtype   = PRT.model(modelid).output.fold(1).type;   % 'classifier' or 'regression'

% parallel code?
% NOTE(review): the pool opened here is never closed by this function.
if def_par.allow
    try
        matlabpool(def_par.ncore)
    catch
        warning('Could not use pool of Matlab processes!')
    end
end

% targets (kept untouched for the whole function; the per-permutation test
% targets are stored in a separate variable below)
t      = PRT.model(modelid).input.targets;
n_samp = length(t);

% load data files and configure ID matrix
[Phi_all,ID,fid] = prt_getKernelModel(PRT,prt_dir,modelid);

% get number of classes
if strcmpi(PRT.model(modelid).input.type,'classification')
    nc=max(unique(t));
else
    nc=[];
end
fdata.nc = nc;

% Find chunks in the data (e.g. temporal correlated samples)
% -------------------------------------------------------------------------
% A chunk is the set of samples sharing the same values in the first five
% columns of the id matrix (commented in the original code as group,
% subject, modality, condition, block). unique(...,'rows') returns the
% distinct rows in sorted order, i.e. the order in which nested loops over
% the sorted unique values of each column would visit them.
ids = PRT.fs(fid).id_mat(PRT.model(modelid).input.samp_idx,:);
[chunk_keys,dummy,chunk_idx] = unique(ids(:,1:5),'rows'); %#ok<ASGLU>
n_chunks = size(chunk_keys,1);
chunks   = cell(1,n_chunks);
for i = 1:n_chunks
    chunks{i} = find(chunk_idx == i);
end

% Initialize counts
% -------------------------------------------------------------------------
switch mtype
    case 'classifier'
        n_class = length(PRT.model(modelid).output.fold(1).stats.c_acc);
        total_greater_c_acc = zeros(n_class,1);
        total_greater_b_acc = 0;

    case 'regression'
        total_greater_corr = 0;
        total_greater_mse = 0;
        total_greater_nmse = 0;
        total_greater_r2 = 0;
end

% Run model with permuted labels
% -------------------------------------------------------------------------
if ~isfield(PRT.model(modelid).output,'permutation') || flag
    % Back to empty to save other perm param
    PRT.model(modelid).output.permutation=struct('fold',[]);
end

% The nested-CV branch below temporarily overwrites the model's machine
% arguments and per-fold param_effect with values obtained from permuted
% labels. Keep the originals so they can be restored before PRT is saved.
orig_input = PRT.model(modelid).input;
orig_fold  = PRT.model(modelid).output.fold;

for p=1:n_perm

    fprintf('Permutation %d out of %d >>>>>>\n',p,n_perm);

    % permute: chunk i receives the label and the CV rows of a randomly
    % drawn chunk
    chunkperm = randperm(n_chunks);
    CVperm    = zeros(size(CV));
    t_perm    = zeros(n_samp,1);
    for i=1:n_chunks
        src = chunks{chunkperm(i)};
        % NOTE(review): assumes all samples of a chunk share one target
        % value, so that unique() returns a scalar - confirm for regression
        t_perm(chunks{i},1) = unique(t(src));
        CVperm(chunks{i},:) = CV(src,:);
    end

    for f = 1:n_folds
        % configure data structure for prt_cv_fold
        fdata.ID      = ID;
        fdata.mid     = modelid;
        fdata.CV      = CVperm(:,f);
        fdata.Phi_all = Phi_all;
        fdata.t       = t_perm;

        % Nested CV for hyper-parameter optimisation or feature selection
        if isfield(PRT.model(modelid).input,'use_nested_cv')
            if PRT.model(modelid).input.use_nested_cv
                [out] = prt_nested_cv(PRT, fdata);
                % transient: restored from orig_fold/orig_input before saving
                PRT.model(modelid).output.fold(f).param_effect = out;
                PRT.model(modelid).input.machine.args = out.opt_param;
            end
        end

        [temp_model, targets] = prt_cv_fold(PRT,fdata);

        % save the weights per fold to further compute ranking distance
        if flag
            PRT.model(modelid).output.permutation(p).fold(f).alpha=temp_model.alpha;
            PRT.model(modelid).output.permutation(p).fold(f).pred=temp_model.predictions;
        end

        fold_out(f).predictions = temp_model.predictions; %#ok<AGROW>
        fold_out(f).targets     = targets.test;           %#ok<AGROW>

    end

    % Model level statistics (across folds)
    t_test        = vertcat(fold_out(:).targets);
    m.type        = mtype;
    m.predictions = vertcat(fold_out(:).predictions);
    perm_stats    = prt_stats(m,t_test,t_test);

    % Count permutations performing at least as well as the true model
    % (>= for accuracy/corr/r2, <= for the error measures)
    switch mtype

        case 'classifier'

            permutation.b_acc(p)=perm_stats.b_acc;
            if (perm_stats.b_acc >= PRT.model(modelid).output.stats.b_acc)
                total_greater_b_acc=total_greater_b_acc+1;
            end

            for c=1:n_class
                permutation.c_acc(c,p)=perm_stats.c_acc(c);
                if (perm_stats.c_acc(c) >= PRT.model(modelid).output.stats.c_acc(c))
                    total_greater_c_acc(c)=total_greater_c_acc(c)+1;
                end
            end

        case 'regression'
            permutation.corr(p)=perm_stats.corr;
            if (perm_stats.corr >= PRT.model(modelid).output.stats.corr)
                total_greater_corr=total_greater_corr+1;
            end
            permutation.mse(p)=perm_stats.mse;
            if (perm_stats.mse <= PRT.model(modelid).output.stats.mse)
                total_greater_mse=total_greater_mse+1;
            end
            permutation.nmse(p)=perm_stats.nmse;
            if (perm_stats.nmse <= PRT.model(modelid).output.stats.nmse)
                total_greater_nmse=total_greater_nmse+1;
            end
            permutation.r2(p)=perm_stats.r2;
            if (perm_stats.r2 >= PRT.model(modelid).output.stats.r2)
                total_greater_r2=total_greater_r2+1;
            end

    end
end

% Convert counts into p-values
switch mtype
    case 'classifier'
        permutation.pvalue_b_acc = local_perm_pvalue(total_greater_b_acc, n_perm);
        permutation.pvalue_c_acc = local_perm_pvalue(total_greater_c_acc, n_perm);

    case 'regression'
        permutation.pval_corr = local_perm_pvalue(total_greater_corr, n_perm);
        permutation.pval_mse  = local_perm_pvalue(total_greater_mse,  n_perm);
        permutation.pval_nmse = local_perm_pvalue(total_greater_nmse, n_perm);
        permutation.pval_r2   = local_perm_pvalue(total_greater_r2,   n_perm);
end

% undo the transient nested-CV changes so that the saved model keeps the
% parameters and param_effect estimated with the true labels
PRT.model(modelid).input       = orig_input;
PRT.model(modelid).output.fold = orig_fold;

%update PRT
PRT.model(modelid).output.stats.permutation = permutation;

% Save PRT containing machine output
% -------------------------------------------------------------------------
outfile = fullfile(path,'PRT.mat');
disp('Updating PRT.mat.......>>')
% NOTE(review): '-V6' is used when spm_check_version(...) < 0; confirm that
% this is the intended sign convention of spm_check_version.
if spm_check_version('MATLAB','7') < 0
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
disp('Permutation test done.')

end

function pval = local_perm_pvalue(n_better, n_perm)
% Fraction of permutations counted in n_better (scalar or vector); a
% fraction of exactly 0 is reported as 1/n_perm.
pval = n_better / n_perm;
pval(pval == 0) = 1./n_perm;
end
0299

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005