function [out] = prt_nested_cv(PRT, in)
% PRT_NESTED_CV Select machine hyper-parameters with an inner (nested) CV.
%
% FORMAT out = prt_nested_cv(PRT, in)
%
% Restricts the data of the current outer fold to its training entries
% (in.CV == 1), builds an inner cross-validation scheme on them, and runs
% that inner CV once per candidate hyper-parameter value (or per (C, mu)
% pair for prt_machine_ENMKL). Predictions from all inner folds are pooled
% and scored with prt_stats; the best-scoring candidate is returned.
%
% INPUT
%   PRT - PRT structure. Reads PRT.fs and PRT.model(in.mid).input
%         (use_nested_cv, cv_type[_nested], cv_k[_nested], nested_param,
%         machine.function, type). The local copy's machine.args is
%         overwritten for each candidate before calling prt_cv_fold.
%   in  - structure for the current outer fold:
%         .mid     - index of the model in PRT.model
%         .CV      - outer-fold CV vector; entries equal to 1 are used
%         .ID      - ID matrix, one row per entry
%         .t       - targets, one per entry
%         .Phi_all - cell array of kernels, indexed by entry on both axes
%         .nc      - passed through to prt_stats
%
% OUTPUT
%   out.param      - candidate grid, one column per candidate
%   out.opt_param  - selected value; [C, mu] for prt_machine_ENMKL
%   out.vary_param - pooled inner-CV score per candidate: balanced
%                    accuracy (classification) or MSE (regression);
%                    a numel(C)-by-numel(mu) matrix for prt_machine_ENMKL
%
% Errors if nested CV is disabled for the model, if the machine or model
% type is not supported, or if an inner test fold contains a class that
% is absent from its training fold.

minput = PRT.model(in.mid).input;
use_nested_cv = minput.use_nested_cv;
if use_nested_cv == false
    error('prt_nested_cv function called with use_nested_cv = false');
end

% Work on the training entries of the outer fold only.
train_entries = find(in.CV == 1);
in.ID = in.ID(train_entries, :);
in.t  = in.t(train_entries);
in.fs = PRT.fs;
for k = 1:numel(in.Phi_all)
    in.Phi_all{k} = in.Phi_all{k}(train_entries, train_entries);
end

% Inner CV scheme: dedicated nested settings if present, else the outer ones.
if isfield(minput, 'cv_type_nested')
    in.cv.type = minput.cv_type_nested;
    in.cv.k    = minput.cv_k_nested;
else
    in.cv.type = minput.cv_type;
    in.cv.k    = minput.cv_k;
end

% Candidate grid (errors for unsupported machines) and prediction type.
machine_fn   = minput.machine.function;
[par, c, mu] = get_param_grid(minput);
m.type       = get_prediction_type(machine_fn);

% Model type decides the score (b_acc vs mse) and its direction (max vs min).
model_type = minput.type;
is_classification = strcmp(model_type, 'classification');
if ~is_classification && ~strcmp(model_type, 'regression')
    error('Type of model not recognised');
end

out.param = par;
n_cand    = size(par, 2);
stats_vec = zeros(1, n_cand);

in.CV   = prt_compute_cv_mat(PRT, in, in.mid, use_nested_cv);
n_folds = size(in.CV, 2);

% Fold data that do not change across candidates or folds.
fold.ID      = in.ID;
fold.Phi_all = in.Phi_all;
fold.t       = in.t;
fold.mid     = in.mid;

for i = 1:n_cand
    % Column i holds one candidate; the machine receives it as a row vector
    % (a scalar for single-parameter machines, [C, mu] for ENMKL).
    PRT.model(in.mid).input.machine.args = par(:, i)';

    fold_targets = cell(n_folds, 1);
    fold_preds   = cell(n_folds, 1);
    for f = 1:n_folds
        fold.CV = in.CV(:, f);
        [model, targets] = prt_cv_fold(PRT, fold);

        % A class unseen in training cannot be predicted; returning here
        % would leave out.opt_param unset, so fail explicitly instead.
        if is_classification && ...
                ~all(ismember(unique(targets.test), unique(targets.train)))
            error(['At least one class is in the test set but not in the ' ...
                'training set. Abandoning modelling, please correct class ' ...
                'selection/cross-validation']);
        end

        % Columns so that targets and predictions stay aligned when pooled.
        fold_targets{f} = targets.test(:);
        fold_preds{f}   = model.predictions(:);
    end

    % Score the pooled predictions of all inner folds.
    m.predictions = vertcat(fold_preds{:});
    stats = prt_stats(m, vertcat(fold_targets{:}), in.nc);
    if is_classification
        stats_vec(i) = stats.b_acc;
    else
        stats_vec(i) = stats.mse;
    end
end

if strcmp(machine_fn, 'prt_machine_ENMKL')
    % par came from meshgrid(c, mu), so mu varies fastest along stats_vec:
    % reshape to numel(mu)-by-numel(c), then transpose to rows = C, cols = mu.
    stats_mat = reshape(stats_vec, numel(mu), numel(c))';
    opt_ind   = get_opt_stats_ind(stats_mat, 2, is_classification);
    out.opt_param  = [c(opt_ind(1)), mu(opt_ind(2))];
    out.vary_param = stats_mat;
else
    opt_ind = get_opt_stats_ind(stats_vec, 1, is_classification);
    out.opt_param  = par(opt_ind);
    out.vary_param = stats_vec;
end

end

%--------------------------------------------------------------------------
function [par, c, mu] = get_param_grid(minput)
% Candidate hyper-parameter grid for minput.machine.function, one column
% per candidate. c and mu are the C and mu axes for prt_machine_ENMKL and
% are empty for the other machines. Defaults are used (with a warning)
% when no nested_param range is given.
c  = [];
mu = [];
has_range = isfield(minput, 'nested_param') && ~isempty(minput.nested_param);
switch minput.machine.function
    case {'prt_machine_svm_bin','prt_machine_sMKL_cla','prt_machine_krr', 'prt_machine_sMKL_reg'}
        if has_range
            par = minput.nested_param;
        else
            par = 10 .^ (-2:3);
            beep
            warning('No parameter range specified for optimization, using 10^-2 to 10^3')
        end
        % Force a row so a column-vector range still gives one candidate per value.
        par = reshape(par, 1, []);
    case 'prt_machine_ENMKL'
        if has_range
            c  = minput.nested_param{1};
            mu = minput.nested_param{2};
        else
            c  = 10 .^ (-2:3);
            mu = 0:0.1:1;
            beep
            warning('No parameter range specified for C and mu, using 10^-2 to 10^3 and 0 to 1')
        end
        % Every (C, mu) combination; row 1 = C, row 2 = mu, mu varying fastest.
        [c_mesh, mu_mesh] = meshgrid(c, mu);
        par = [c_mesh(:), mu_mesh(:)]';
    otherwise
        error('Machine not currently supported for nested CV');
end
end

%--------------------------------------------------------------------------
function pred_type = get_prediction_type(machine_fn)
% Type string stored in m.type for prt_stats when scoring pooled predictions.
switch machine_fn
    case {'prt_machine_svm_bin','prt_machine_sMKL_cla','prt_machine_ENMKL'}
        pred_type = 'classifier';
    case {'prt_machine_krr', 'prt_machine_sMKL_reg'}
        pred_type = 'regression';
    otherwise
        error('Machine not currently supported for nested CV');
end
end
0191
0192
0193
0194
0195
0196
function opt_stats_ind = get_opt_stats_ind(stats, n_par, classification)
% GET_OPT_STATS_IND Position of the best performance value.
%
% INPUT
%   stats          - vector of scores (n_par = 1) or matrix of scores
%                    (n_par = 2, rows = first parameter, cols = second)
%   n_par          - number of optimised parameters (1 or 2)
%   classification - true: the maximum is best; false: the minimum is best
%
% OUTPUT
%   opt_stats_ind  - linear index (n_par = 1) or [row, col] (n_par = 2)
%
% When several entries tie for the optimum, the middle tied entry (in
% column-major order, the upper middle for an even count) is returned, so
% the result always points at an optimal entry. NaN values are ignored by
% max/min; an error is raised if no entry can be selected.

if n_par ~= 1 && n_par ~= 2
    error('The number of parameters to optimise must be <=2')
end

if classification
    opt_val = max(stats(:));
else
    opt_val = min(stats(:));
end

% Pick among the tied entries themselves: taking medians of their indices
% (or of rows and columns separately) can select a non-optimal entry.
tied = find(stats(:) == opt_val);
if isempty(tied)
    error('prt_nested_cv:noValidStats', ...
        'No valid performance value to optimise (all values are NaN or empty)');
end
best = tied(ceil((numel(tied) + 1) / 2));

if n_par == 1
    opt_stats_ind = best;
else
    [row, col] = ind2sub(size(stats), best);
    opt_stats_ind = [row, col];
end

end