function prt_plot_nested_cv(PRT, model, fold, axes_handle)
% PRT_PLOT_NESTED_CV  Plot how the nested-CV hyper-parameter(s) affect performance.
%
% Inputs:
%   PRT         - structure; this function reads
%                   PRT.model(model).input.machine.function
%                   PRT.model(model).input.type
%                   PRT.model(model).output.fold(k).param_effect
%                     (.param, .vary_param and, for 1-D machines, .opt_param)
%   model       - index into PRT.model
%   fold        - 1 plots the summary over all folds;
%                 k+1 plots the curve/grid of fold k
%   axes_handle - (optional) axes to draw into. When omitted or empty a new
%                 figure with fresh axes is created.
%
% Output: none, the function only draws.
%
% Errors:
%   'Machine not currently supported for nested CV' for an unknown machine,
%   'Type of model not recognised' for an unknown model type (1-D plots),
%   prt_plot_nested_cv:invalidFold for an out-of-range fold selector.
%
% Supported machines:
%   prt_machine_svm_bin, prt_machine_sMKL_cla, prt_machine_sMKL_reg,
%   prt_machine_krr  -> 1-D curve against log10 of the parameter
%   prt_machine_ENMKL -> 2-D image (mu along x, log10(C) along y)

machine    = PRT.model(model).input.machine.function;
model_type = PRT.model(model).input.type;

% --- Labels (and the "unsupported machine" check) come first so that no
% --- figure is created when the call is going to fail anyway.
z_label = '';
switch machine
    case {'prt_machine_svm_bin','prt_machine_sMKL_cla'}
        x_label = 'C';
        y_label = 'Balanced Accuracy (%)';
    case {'prt_machine_sMKL_reg','prt_machine_krr'}
        x_label = 'Args';
        y_label = 'MSE';
    case 'prt_machine_ENMKL'
        x_label = 'mu';
        y_label = 'C';
        z_label = 'Balanced Accuracy (%)';
    otherwise
        error('Machine not currently supported for nested CV');
end
is_grid = strcmp(machine, 'prt_machine_ENMKL');

% --- Validate the fold selector before touching any graphics.
folds = PRT.model(model).output.fold;
nfold = numel(folds);
if nfold < 1 || ~isnumeric(fold) || ~isscalar(fold) || fold ~= round(fold) ...
        || fold < 1 || fold > nfold + 1
    error('prt_plot_nested_cv:invalidFold', ...
        'fold must be an integer between 1 (summary) and %d (last fold).', ...
        nfold + 1);
end

% --- Axes set-up.
if nargin < 4 || isempty(axes_handle)
    figure;
    axes_handle = axes;
else
    % Start from a clean slate on the caller's axes.
    cla(axes_handle, 'reset');
end
switch machine
    case {'prt_machine_svm_bin','prt_machine_sMKL_cla','prt_machine_sMKL_reg'}
        set(axes_handle, 'XMinorTick', 'on');
        box(axes_handle, 'on');
        hold(axes_handle, 'all');
    case 'prt_machine_ENMKL'
        % Applied to new and reused axes alike so both render the same.
        set(axes_handle, 'XScale', 'linear', 'XMinorTick', 'on', 'YMinorTick', 'on');
end

% Several plotting commands below (image, colorbar, legend) act on the
% current axes. Make the target axes current so that everything lands on
% axes_handle even when the caller passed axes that were not current.
fig_handle = ancestor(axes_handle, 'figure');
set(0, 'CurrentFigure', fig_handle);
set(fig_handle, 'CurrentAxes', axes_handle);

cla(axes_handle)
rotate3d(fig_handle, 'off')
set(axes_handle, 'Color', [1,1,1])
% Narrow the axes to 90% of their width.
% NOTE(review): the width is rescaled on every call; if the caller reuses
% the same axes without restoring 'Position' the shrink accumulates -
% confirm the caller resets it.
pos = get(axes_handle, 'Position');
set(axes_handle, 'Position', [pos(1) pos(2) 0.9*pos(3) pos(4)])

if is_grid
    % ------------------------------------------------------------------
    % Two hyper-parameters: show performance as an image.
    % ------------------------------------------------------------------
    if fold == 1
        % Summary: average the per-fold grids.
        c  = unique(folds(1).param_effect.param(1,:));
        mu = unique(folds(1).param_effect.param(2,:));
        first_grid = folds(1).param_effect.vary_param;
        f = zeros([size(first_grid), nfold]);
        for i = 1:nfold
            f(:,:,i) = folds(i).param_effect.vary_param;
        end
        grid_perf  = 100 .* mean(f, 3);
        plot_title = 'Mean';
    else
        effect = folds(fold-1).param_effect;
        c  = unique(effect.param(1,:));
        mu = unique(effect.param(2,:));
        grid_perf  = 100 .* effect.vary_param;
        plot_title = '';
    end

    % The y axis is drawn on log10(C). The image handle is deliberately
    % not stored in axes_handle.
    image(grid_perf, 'CDataMapping', 'scaled', 'XData', mu, 'YData', log10(c));
    axes_color = colorbar;
    if ~isempty(plot_title)
        title(axes_handle, plot_title)
    end
    xlabel(axes_handle, x_label, 'FontWeight', 'bold');
    ylabel(axes_handle, y_label, 'FontWeight', 'bold');
    ylabel(axes_color, z_label, 'FontWeight', 'bold');

else
    % ------------------------------------------------------------------
    % One hyper-parameter: curve against log10 of the parameter.
    % ------------------------------------------------------------------
    is_classification = strcmp(model_type, 'classification');

    if fold == 1
        % Summary over folds: mean +/- std of the performance and how often
        % each parameter value was selected as optimal.
        x = folds(1).param_effect.param;
        f = zeros(nfold, length(x));
        chosen = zeros(1, nfold);
        for i = 1:nfold
            f(i,:)    = folds(i).param_effect.vary_param;
            chosen(i) = folds(i).param_effect.opt_param;
        end
        if is_classification
            f = 100 .* f;
        end
        % Reduce along the fold dimension explicitly: with a single fold
        % mean(f)/std(f) would collapse the row into one scalar.
        f_mean = mean(f, 1);
        f_std  = std(f, 0, 1);

        % Fraction of folds in which each parameter value was selected.
        sel_freq = hist(chosen, x) ./ size(f,1);

        [f_min, f_max] = local_nested_cv_ylim(model_type, f);
        [f_min, f_max] = local_nested_cv_limits(f_min, f_max);

        x = log10(x);
        [x_lo, x_hi] = local_nested_cv_limits(min(x) - 0.2*abs(min(x)), ...
                                              max(x) + 0.2*abs(max(x)));

        % Left axis (hax(1) == axes_handle): selection-frequency bars.
        % Right axis (hax(2)): mean performance, error bars and markers.
        hold(axes_handle, 'on');
        [hax, hbar, hline] = plotyy(axes_handle, x, sel_freq*100, x, f_mean, 'bar', 'plot');
        hold(hax(2), 'on');
        errorbar(hax(2), x, f_mean, f_std, '.k', 'linewidth', 2);
        set(hbar, 'BarWidth', 0.5, 'FaceColor', [0.5 0.5 0.5])
        set(hline, 'Color', 'k', 'Linewidth', 1)
        set(hax(2), 'YColor', [0.2,0.2,0.2])
        for i = 1:length(sel_freq)
            % Marker colour goes from blue (never selected) to red
            % (selected in every fold).
            R = sel_freq(i);
            B = 1 - R;
            plot(hax(2), x(i), f_mean(i), 'o', 'markersize', 4, ...
                'linewidth', 0.01, 'MarkerFaceColor', [R 0 B]);
        end
        hold(hax(2), 'off');
        hold(axes_handle, 'off');

        % Each axis is labelled and scaled for the data it actually shows.
        ylabel(hax(1), 'Frequency of selection (%)', 'FontWeight', 'bold');
        ylabel(hax(2), y_label, 'FontWeight', 'bold');
        xlabel(axes_handle, [x_label, ' (log 10)'], 'FontWeight', 'bold');
        axis(hax(1), [x_lo x_hi 0 108]);
        axis(hax(2), [x_lo x_hi f_min f_max]);
        set(hax(2), 'XTickLabel', {})

    else
        % Single fold: performance curve with the optimum highlighted.
        effect = folds(fold-1).param_effect;
        x = effect.param;
        f = effect.vary_param;
        if is_classification
            f = 100 .* f;
        end

        % Best value(s): highest accuracy or lowest error.
        switch model_type
            case 'classification'
                best = find(f == max(f));
            case 'regression'
                best = find(f == min(f));
            otherwise
                error('Type of model not recognised');
        end

        [f_min, f_max] = local_nested_cv_ylim(model_type, f);
        [f_min, f_max] = local_nested_cv_limits(f_min, f_max);

        x = log10(x);
        [x_lo, x_hi] = local_nested_cv_limits(min(x), max(x));

        markersize = 10;
        hold(axes_handle, 'on');
        plot(axes_handle, x, f, '-xk', 'markersize', markersize, 'linewidth', 1);
        opt_handle = plot(axes_handle, x(best), f(best), 'xr', ...
            'markersize', markersize, 'linewidth', 3);
        hold(axes_handle, 'off');

        xlabel(axes_handle, [x_label, ' (log 10)'], 'FontWeight', 'bold');
        ylabel(axes_handle, y_label, 'FontWeight', 'bold');
        if ~isempty(opt_handle)
            legend(opt_handle, 'Optimal value(s)');
        end
        axis(axes_handle, [x_lo x_hi f_min f_max]);
    end
end

end


function [f_min, f_max] = local_nested_cv_ylim(model_type, f)
% Performance-axis limits: fixed 0..108 for classification (percent),
% data range for regression.
switch model_type
    case 'classification'
        f_min = 0;
        f_max = 108;
    case 'regression'
        f_min = min(f(:));
        f_max = max(f(:));
    otherwise
        error('Type of model not recognised');
end
end


function [lo, hi] = local_nested_cv_limits(lo, hi)
% axis() needs strictly increasing limits; widen a degenerate range
% (single parameter value, constant curve) symmetrically.
if ~(hi > lo)
    pad = max(0.5, 0.1*abs(lo));
    lo = lo - pad;
    hi = hi + pad;
end
end