Home > . > prt_permutation.m

prt_permutation

PURPOSE ^

Function to compute permutation test

SYNOPSIS ^

function [] = prt_permutation(PRT, n_perm, modelid, path)

DESCRIPTION ^

 Function to compute permutation test

 Inputs:
 -------
 PRT: PRT structure including model
 n_perm: number of permutations
 modelid: model ID
 path: directory holding the kernel file(s); the updated PRT.mat is saved there

 Outputs:
 --------

 for classification
 permutation.c_acc: Permuted accuracy per class
 permutation.b_acc: Permuted balanced accuracy
 permutation.pvalue_b_acc: p-value for b_acc
 permutation.pvalue_c_acc: p-value for c_acc

 for regression
 permutation.corr: Permuted correlation
 permutation.mse: Permuted mean square error
 permutation.pval_corr: p-value for corr
 permutation.pval_mse: p-value for mse
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function [] = prt_permutation(PRT, n_perm, modelid, path)
% Function to compute permutation test
%
% The targets of the selected model are shuffled chunk-wise (a chunk is the
% set of samples sharing group, subject, modality, condition and block),
% the cross-validation is re-run for every shuffle, and the resulting
% statistics are compared with the ones stored for the unpermuted model.
%
% Inputs:
% -------
% PRT:     PRT structure including model
% n_perm:  number of permutations
% modelid: model ID (index into PRT.model)
% path:    directory holding the kernel file(s); the updated PRT.mat is
%          saved there
%
% Outputs:
% --------
% Nothing is returned. The results are stored in
% PRT.model(modelid).output.stats.permutation and PRT.mat is saved in path.
%
% for classification
% permutation.b_acc: Permuted balanced accuracy
% permutation.c_acc: Permuted accuracy per class
% permutation.pvalue_b_acc: p-value for b_acc
% permutation.pvalue_c_acc: p-value for c_acc
%
% for regression
% permutation.corr: Permuted correlation
% permutation.mse: Permuted mean square error
% permutation.pval_corr: p-value for corr
% permutation.pval_mse: p-value for mse
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by J. Mourao-Miranda
% $Id: prt_permutation.m 577 2012-09-01 23:47:36Z mjrosa $

prt_dir = path;

% Nothing to permute without a trained model
if ~isfield(PRT,'model')
    beep
    disp('No model found in this PRT.mat');
    return
end
if ~isfield(PRT.model,'output')
    beep
    disp('No model output found in this PRT.mat')
    return
end

% configure some variables
model_in   = PRT.model(modelid).input;
model_type = PRT.model(modelid).output.fold(1).type;
obs_stats  = PRT.model(modelid).output.stats;  % statistics of the unpermuted model
CV         = model_in.cv_mat;                  % CV matrix
n_folds    = size(CV,2);                       % number of CV folds
n_Phi      = length(model_in.fs);              % number of data matrices
samp_idx   = model_in.samp_idx;                % which samples are in the model

% targets (kept untouched; every permutation is drawn from this vector)
targets_orig = model_in.targets;

if n_Phi < 1
    error('prt_permutation:noFeatureSet', ...
        'Model %d does not reference any feature set.', modelid);
end

% load data files and configure ID matrix
% -------------------------------------------------------------------------
Phi_all = cell(1,n_Phi);
for i = 1:n_Phi
    % index of this feature set inside PRT.fs; it is not necessarily equal
    % to its position i in the model's own list
    fid = find(strcmp({PRT.fs(:).fs_name}, model_in.fs(i).fs_name), 1);
    if isempty(fid)
        error('prt_permutation:fsNotFound', ...
            'Feature set ''%s'' not found in PRT.fs.', model_in.fs(i).fs_name);
    end

    if i == 1
        ID = PRT.fs(fid).id_mat(samp_idx,:);
    end

    if model_in.use_kernel
        % load into a struct so that variables stored in the kernel file
        % cannot overwrite local variables of this function
        kern       = load(fullfile(prt_dir, PRT.fs(fid).k_file));
        Phi_all{i} = kern.Phi(samp_idx,samp_idx);
    else
        error('training with features not implemented yet');
    end
end

% Find chunks in the data (e.g. temporal correlated samples)
% -------------------------------------------------------------------------
% Columns 1 to 5 of the ID matrix are used as group, subject, modality,
% condition and block. Every distinct combination forms one chunk; the
% sorted output of unique(...,'rows') gives the chunks in ascending
% lexicographic order of those five columns.
ids = PRT.fs(fid).id_mat(samp_idx,:);
[chunk_keys, dummy, chunk_of_sample] = unique(ids(:,1:5),'rows'); %#ok<ASGLU>
n_chunks = size(chunk_keys,1);
chunks   = cell(1,n_chunks);
for k = 1:n_chunks
    chunks{k} = find(chunk_of_sample == k);
end

% Initialize counts and storage
% -------------------------------------------------------------------------
switch model_type
    case 'classifier'
        n_class             = length(PRT.model(modelid).output.fold(1).stats.c_acc);
        total_greater_c_acc = zeros(n_class,1);
        total_greater_b_acc = 0;
        permutation.b_acc   = zeros(1,n_perm);
        permutation.c_acc   = zeros(n_class,n_perm);

    case 'regression'
        total_greater_corr = 0;
        total_greater_mse  = 0;
        permutation.corr   = zeros(1,n_perm);
        permutation.mse    = zeros(1,n_perm);

    otherwise
        error('prt_permutation:unknownType', ...
            'Unknown model type ''%s''.', model_type);
end

% Run model with permuted labels
% -------------------------------------------------------------------------
t = targets_orig;
for p = 1:n_perm

    fprintf('Permutation %d out of %d >>>>>>\n',p,n_perm);

    % permute labels: each chunk receives the label of another chunk
    chunkperm = randperm(n_chunks);
    for i = 1:n_chunks
        t(chunks{i},1) = unique(targets_orig(chunks{chunkperm(i)}));
    end

    for f = 1:n_folds
        % configure data structure for prt_cv_fold
        fdata.ID      = ID;
        fdata.mid     = modelid;
        fdata.CV      = CV(:,f);
        fdata.Phi_all = Phi_all;
        fdata.t       = t;

        [fold_model, fold_targets] = prt_cv_fold(PRT,fdata);

        perm_model.output.fold(f).predictions = fold_model.predictions;
        perm_model.output.fold(f).targets     = fold_targets.test;
    end

    % Model level statistics (across folds).
    % The concatenated test targets get their own variable: reusing t here
    % would corrupt the label vector used by the next permutation.
    t_test        = vertcat(perm_model.output.fold(:).targets);
    m.type        = model_type;
    m.predictions = vertcat(perm_model.output.fold(:).predictions);
    perm_stats    = prt_stats(m,t_test,t_test);

    switch model_type

        case 'classifier'
            permutation.b_acc(p) = perm_stats.b_acc;
            if perm_stats.b_acc > obs_stats.b_acc
                total_greater_b_acc = total_greater_b_acc + 1;
            end

            for c = 1:n_class
                permutation.c_acc(c,p) = perm_stats.c_acc(c);
                if perm_stats.c_acc(c) > obs_stats.c_acc(c)
                    total_greater_c_acc(c) = total_greater_c_acc(c) + 1;
                end
            end

        case 'regression'
            permutation.corr(p) = perm_stats.corr;
            if abs(perm_stats.corr) > abs(obs_stats.corr)
                total_greater_corr = total_greater_corr + 1;
            end

            % for the error a permutation "wins" when its value is lower
            permutation.mse(p) = perm_stats.mse;
            if perm_stats.mse < obs_stats.mse
                total_greater_mse = total_greater_mse + 1;
            end
    end
end

% p-values
% -------------------------------------------------------------------------
% A count of zero is reported as 1/n_perm, the smallest p-value that
% n_perm permutations can resolve.
switch model_type
    case 'classifier'
        pval_b_acc = total_greater_b_acc / n_perm;
        if pval_b_acc == 0
            pval_b_acc = 1./n_perm;
        end

        pval_c_acc = total_greater_c_acc / n_perm;
        pval_c_acc(pval_c_acc == 0) = 1./n_perm;

        permutation.pvalue_b_acc = pval_b_acc;
        permutation.pvalue_c_acc = pval_c_acc;

    case 'regression'
        pval_corr = total_greater_corr / n_perm;
        if pval_corr == 0
            pval_corr = 1./n_perm;
        end

        pval_mse = total_greater_mse / n_perm;
        if pval_mse == 0
            pval_mse = 1./n_perm;
        end

        permutation.pval_corr = pval_corr;
        permutation.pval_mse  = pval_mse;
end

%update PRT
PRT.model(modelid).output.stats.permutation = permutation;

% Save PRT containing machine output
% -------------------------------------------------------------------------
outfile = fullfile(prt_dir,'PRT.mat');
disp('Updating PRT.mat.......>>')
if spm_matlab_version_chk('7') >= 0
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
disp('Permutation test done.')

end

Generated on Mon 03-Sep-2012 18:07:18 by m2html © 2005