Home > . > prt_nested_cv.m

prt_nested_cv

PURPOSE ^

Function to perform the nested CV

SYNOPSIS ^

function [out] = prt_nested_cv(PRT, in)

DESCRIPTION ^

 Function to perform the nested CV

 Inputs:
 -------
   in.nc:          number of classes
   in.ID:          ID matrix
   in.mid:         model id
   in.CV:          cross-validation matrix
   in.Phi_all:     Kernel

 Outputs:
 --------
   out.opt_param:  optimal hyper-parameter chosen using the stats from
                   the inner CVs
   out.vary_param: stats values associated with all the hyper-parameters
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 function [out] = prt_nested_cv(PRT, in)
0002 % Function to perform the nested CV: the hyper-parameter(s) of a model are
0003 % optimised with an inner cross-validation restricted to part of the data
0004 %
0005 % Inputs:
0006 % -------
0007 %   PRT:            PRT structure (PRT.model(in.mid).input and PRT.fs are read)
0008 %   in.nc:          number of classes
0009 %   in.ID:          ID matrix
0010 %   in.mid:         model id
0011 %   in.CV:          cross-validation matrix (only the entries equal to 1
0012 %                   are kept for the nested CV)
0013 %   in.Phi_all:     Kernel (cell array of matrices)
0014 %   in.t:           values indexed like the rows of in.ID, passed on to
0015 %                   prt_cv_fold
0016 %
0017 % Outputs:
0018 % --------
0019 %   out.param:      hyper-parameter values that were evaluated
0020 %   out.opt_param:  optimal hyper-parameter chosen using the stats from
0021 %                   the inner CVs
0022 %   out.vary_param: stats values associated with all the hyper-parameters
0023 %
0024 % For classification models, if a fold has a test class that is absent from
0025 % its training set, the function returns early with only out.param set.
0026 %__________________________________________________________________________
0027 % Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory
0028 
0029 % Written by J.M. Monteiro
0030 % $Id$
0031 
0032 % Set flag
0033 use_nested_cv = PRT.model(in.mid).input.use_nested_cv;
0034 if use_nested_cv == false
0035     error('prt_nested_cv function called with use_nested_cv = false');
0036 end
0037 
0038 % Entries of in.CV flagged with 1: the only ones used from here on
0039 train_entries = find(in.CV == 1);
0040 
0041 % Change fdata
0042 in.ID      = in.ID(train_entries, :);
0043 in.t       = in.t(train_entries);
0044 in.fs      = PRT.fs;
0045 if isfield(PRT.model(in.mid).input,'cv_type_nested')
0046     in.cv.type = PRT.model(in.mid).input.cv_type_nested;
0047     in.cv.k = PRT.model(in.mid).input.cv_k_nested;
0048 else
0049     % No nested-specific settings: fall back on cv_type and cv_k
0050     in.cv.type = PRT.model(in.mid).input.cv_type;
0051     in.cv.k = PRT.model(in.mid).input.cv_k;
0052 end
0053 
0054 for i=1:length(in.Phi_all)
0055     in.Phi_all{i} = in.Phi_all{i}(train_entries, train_entries);
0056 end
0057 
0058 % Set range of the hyper parameters (one column of par per candidate)
0059 switch PRT.model(in.mid).input.machine.function
0060     case {'prt_machine_svm_bin','prt_machine_sMKL_cla','prt_machine_krr', 'prt_machine_sMKL_reg'}
0061         if ~isempty(PRT.model(in.mid).input.nested_param)
0062             par = PRT.model(in.mid).input.nested_param;
0063             % The search loops over the columns of par: force a row vector
0064             % so that a range given as a column is not cut to its first value
0065             par = par(:).';
0066         else
0067             d1 = -2 : 3;
0068             par = 10 .^(d1);
0069             beep
0070             warning('No parameter range specified for optimization, using 10^-2 to 10^3')
0071         end
0072     case 'prt_machine_ENMKL'
0073         if ~isempty(PRT.model(in.mid).input.nested_param)
0074             % Get parameter ranges from PRT
0075             c = PRT.model(in.mid).input.nested_param{1};
0076             mu = PRT.model(in.mid).input.nested_param{2};
0077         else
0078             d1 = -2 : 3;
0079             c = 10 .^(d1);
0080             mu = 0:0.1:1;
0081             beep
0082             warning('No parameter range specified for C and mu, using 10^-2 to 10^3 and 0 to 1')
0083         end
0084         % Convert them to a matrix with all the combinations (mu varies fastest)
0085         [c_mesh,mu_mesh] = meshgrid(c, mu);
0086         par = [c_mesh(:), mu_mesh(:)]';
0087     otherwise
0088         error('Machine not currently supported for nested CV');
0089 end
0090 
0091 out.param = par;
0092 stats_vec = zeros(1, size(par, 2));
0093 
0094 % generate new CV matrix
0095 in.CV = prt_compute_cv_mat(PRT, in, in.mid, use_nested_cv);
0096 
0097 % compute model performance based on hyper-parameter range
0098 for i = 1:size(par, 2)
0099     % Pass the i-th candidate to the machine and set the type used for the stats
0100     switch PRT.model(in.mid).input.machine.function
0101         case {'prt_machine_svm_bin','prt_machine_sMKL_cla'}
0102             PRT.model(in.mid).input.machine.args = par(i);
0103             m.type = 'classifier';
0104         case {'prt_machine_krr', 'prt_machine_sMKL_reg'}
0105             PRT.model(in.mid).input.machine.args = par(i);
0106             m.type = 'regression';
0107         case 'prt_machine_ENMKL'
0108             PRT.model(in.mid).input.machine.args = par(:,i)';
0109             m.type = 'classifier';
0110         otherwise
0111             error('Machine not currently supported for nested CV');
0112     end
0113 
0114     % compute the model for each fold of the inner CV
0115     for f = 1:size(in.CV, 2)
0116         fold.ID      = in.ID;
0117         fold.CV      = in.CV(:,f);
0118         fold.Phi_all = in.Phi_all;
0119         fold.t       = in.t;
0120         fold.mid     = in.mid;
0121 
0122         [model, targets] = prt_cv_fold(PRT,fold);
0123 
0124         %for classification check that for each fold, the test targets have been trained
0125         if strcmpi(PRT.model(in.mid).input.type,'classification')
0126             if ~all(ismember(unique(targets.test),unique(targets.train)))
0127                 beep
0128                 disp('At least one class is in the test set but not in the training set')
0129                 disp('Abandoning modelling, please correct class selection/cross-validation')
0130                 % out only holds out.param at this point
0131                 return
0132             end
0133         end
0134 
0135         % Compute stats
0136         stats = prt_stats(model, targets.test, in.nc);
0137         f_stats(f).targets     = targets.test;
0138         f_stats(f).predictions = model.predictions(:);
0139         f_stats(f).stats       = stats;
0140     end
0141 
0142     % Model level statistics (across folds)
0143     ttt           = vertcat(f_stats(:).targets);
0144     m.predictions = vertcat(f_stats(:).predictions);
0145     stats         = prt_stats(m, ttt(:), in.nc);
0146 
0147     % Keep the statistic used to compare the hyper-parameter values
0148     switch PRT.model(in.mid).input.type
0149         case 'classification'
0150             stats_vec(i) = stats.b_acc;
0151         case 'regression'
0152             stats_vec(i) = stats.mse;
0153         otherwise
0154             error('Type of model not recognised');
0155     end
0156 end
0157 
0158 % For now, only parameter optimisation. Add flag for feature selection
0159 % Get optimal parameter
0160 if strcmp(PRT.model(in.mid).input.machine.function, 'prt_machine_ENMKL')
0161     % Reshape the stats vector into a matrix (rows: c, columns: mu). The
0162     % grid size comes from the ranges themselves: counting unique values
0163     % makes the reshape fail when a range contains a repeated value
0164     stats_mat = reshape(stats_vec, numel(mu), numel(c))';
0165 
0166     % Find max
0167     opt_stats_ind = get_opt_stats_ind(stats_mat, 2, true);
0168     c_max = c(opt_stats_ind(1));
0169     mu_max = mu(opt_stats_ind(2));
0170 
0171     out.opt_param = [c_max, mu_max];
0172     out.vary_param = stats_mat;
0173 else
0174     % Single parameter: highest stat for classification, lowest for regression
0175     switch PRT.model(in.mid).input.type
0176         case 'classification'
0177             opt_stats_ind = get_opt_stats_ind(stats_vec, 1, true);
0178         case 'regression'
0179             opt_stats_ind = get_opt_stats_ind(stats_vec, 1, false);
0180         otherwise
0181             error('Type of model not recognised');
0182     end
0183 
0184     par_opt = par(opt_stats_ind);
0185 
0186     out.opt_param = par_opt;
0187     out.vary_param = stats_vec;
0188 end
0189 
0190 end
0191 
0192 
0193 
0194 % -------------------------------------------------------------------------
0195 % Private functions
0196 % -------------------------------------------------------------------------
0197 function opt_stats_ind = get_opt_stats_ind(stats, n_par, classification)
0198 % Find the position of the optimal value among the stats values
0199 %
0200 % Inputs:
0201 % -------
0202 %   stats:          vector (n_par = 1) or matrix (n_par = 2) of stats values
0203 %   n_par:          number of hyper-parameters to optimise (1 or 2)
0204 %   classification: true to look for the maximum, false for the minimum
0205 %
0206 % Outputs:
0207 % --------
0208 %   opt_stats_ind:  index (n_par = 1) or [row, column] pair (n_par = 2) of
0209 %                   the selected optimum
0210 %
0211 % When the optimal value occurs more than once, the occurrence closest to
0212 % the rounded median of the optimal positions is selected (the first one
0213 % found if several are equally close). The rounded median is not returned
0214 % as such because it can fall on a non-optimal entry, e.g. position 3 when
0215 % the optima are at positions 1 and 5.
0216 
0217 switch n_par
0218 
0219     case 1
0220         if classification
0221             opt_stats = max(stats);
0222         else
0223             opt_stats = min(stats);
0224         end
0225 
0226         ind = find(stats == opt_stats);
0227         if isempty(ind)
0228             error('No optimal value found in the stats (empty or all NaN)')
0229         end
0230 
0231         % Snap the rounded median onto the nearest position that is optimal
0232         centre = round(median(ind));
0233         [min_dist, nearest] = min(abs(ind - centre));
0234         opt_stats_ind = ind(nearest);
0235 
0236     case 2
0237         if classification
0238             opt_stats = max(max(stats));
0239         else
0240             opt_stats = min(min(stats));
0241         end
0242 
0243         [ind_c, ind_mu] = find(stats==opt_stats);
0244         if isempty(ind_c)
0245             error('No optimal value found in the stats (empty or all NaN)')
0246         end
0247 
0248         % Row and column medians taken separately may not form an optimal
0249         % pair: keep the optimal pair closest to them (squared distance)
0250         centre_c  = round(median(ind_c));
0251         centre_mu = round(median(ind_mu));
0252         [min_dist, nearest] = min((ind_c - centre_c).^2 + (ind_mu - centre_mu).^2);
0253         opt_stats_ind = [ind_c(nearest), ind_mu(nearest)];
0254 
0255     otherwise
0256         error('The number of parameters to optimise must be <=2')
0257 end
0258 
0259 
0260 end

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005