Home > . > prt_plot_nested_cv.m

prt_plot_nested_cv

PURPOSE ^

FORMAT prt_plot_nested_cv(PRT, model, fold, axes_handle)

SYNOPSIS ^

function prt_plot_nested_cv(PRT, model, fold, axes_handle)

DESCRIPTION ^

 FORMAT prt_plot_nested_cv(PRT, model, fold, axes_handle)

 Plots the results of the nested cv that appear on prt_ui_results.


 Inputs:
       PRT             - data/design/model structure (it needs to contain
                         at least one estimated model).
       model           - the number of the model that will be plotted
       fold            - 1 plots the summary over all folds; a value k > 1
                         plots the results of fold number k-1
       axes_handle     - (Optional) axes where the plot will be displayed

 Output:
       None
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function prt_plot_nested_cv(PRT, model, fold, axes_handle)
% FORMAT prt_plot_nested_cv(PRT, model, fold, axes_handle)
%
% Plots the results of the nested cv that appear on prt_ui_results.
%
% For one-parameter machines (SVM, sMKL, KRR) the performance measure is
% drawn against the hyperparameter. For the two-parameter machine (ENMKL)
% the performance over the (mu, C) grid is drawn as a colour-coded image.
%
% Inputs:
%       PRT             - data/design/model structure (it needs to contain
%                         at least one estimated model).
%       model           - the number of the model that will be plotted
%       fold            - 1 plots the summary over all folds (mean and
%                         standard deviation, plus how often each value was
%                         selected); a value k > 1 plots the results of
%                         fold number k-1
%       axes_handle     - (Optional) axes where the plot will be displayed.
%                         If omitted or empty, a new figure is created.
%
% Output:
%       None
%
% Errors:
%       'Machine not currently supported for nested CV' if the machine of
%       the model is not one of the machines handled below.
%       'Type of model not recognised' if a one-parameter model is neither
%       'classification' nor 'regression'.
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by J. Matos Monteiro
% $Id$


machine_name = PRT.model(model).input.machine.function;

% A missing or empty handle means "draw in a new figure window"
create_axes = nargin < 4 || isempty(axes_handle);
if create_axes
    axes_handle = [];
end

% Check machine and set the labels and axes
logscale = 0;
switch machine_name
    case {'prt_machine_svm_bin','prt_machine_sMKL_cla'}
        x_label = 'C';
        y_label = 'Balanced Accuracy (%)';
        axes_handle = local_prepare_axes(axes_handle, create_axes, ...
            {'XMinorTick','on'}, {'XMinorTick','on'});
        logscale = 1;
        box(axes_handle,'on');
        hold(axes_handle,'all');

    case 'prt_machine_sMKL_reg'
        x_label = 'Args';
        y_label = 'MSE';
        axes_handle = local_prepare_axes(axes_handle, create_axes, ...
            {'XMinorTick','on'}, {'XMinorTick','on'});
        logscale = 1;
        box(axes_handle,'on');
        hold(axes_handle,'all');

    case 'prt_machine_krr'
        x_label = 'Args';
        y_label = 'MSE';
        axes_handle = local_prepare_axes(axes_handle, create_axes, {}, {});
        logscale = 1;

    case 'prt_machine_ENMKL'
        x_label = 'mu';
        y_label = 'C';
        z_label = 'Balanced Accuracy (%)';
        % logscale is only read by the one-parameter plots, so it is not
        % set here
        axes_handle = local_prepare_axes(axes_handle, create_axes, {}, ...
            {'XScale','linear', 'XMinorTick','on', 'YMinorTick','on'});

    otherwise
        error('Machine not currently supported for nested CV');
end

% Several calls below (hold, plotyy, plot, image, colorbar, title,
% rotate3d) act on the current figure/axes. Make the target axes current
% so that everything is drawn in the same place, even when the caller
% supplied an axes that was not the current one.
fig_handle = ancestor(axes_handle, 'figure');
set(0, 'CurrentFigure', fig_handle);
set(fig_handle, 'CurrentAxes', axes_handle);

cla(axes_handle)
rotate3d off
set(axes_handle,'Color',[1,1,1])
% Shrink the axes width to 90%.
% NOTE(review): on a reused axes this is applied again on every call -
% confirm that the caller restores the axes position between calls.
pos = get(axes_handle,'Position');
set(axes_handle,'Position',[pos(1) pos(2) 0.9*pos(3) pos(4)])

folds = PRT.model(model).output.fold;

% Check if it's a 2 parameter optimisation problem
if strcmp(machine_name, 'prt_machine_ENMKL')

    if fold == 1
        % Summary: average the performance grid over all folds
        src   = 1;
        nfold = length(folds);
        first = folds(1).param_effect.vary_param;
        f     = zeros([size(first), nfold]);
        for i = 1:nfold
            f(:,:,i) = folds(i).param_effect.vary_param;
        end
        f = 100.*mean(f, 3);
    else
        % Single fold: the selector is offset by one
        src = fold-1;
        f   = 100.*folds(src).param_effect.vary_param;
    end

    % Grid values of the two hyperparameters
    c  = unique(folds(src).param_effect.param(1,:));
    mu = unique(folds(src).param_effect.param(2,:));

    % Plot the grid as an image.
    % NOTE(review): the Y data is log10(c) while the axis label is 'C'.
    % TODO: Put Logscale on the Y axis
    % TODO: consider a 3D plot with error bars (mean +/- std over folds), see
    % http://code.izzid.com/2007/08/19/How-to-make-a-3D-plot-with-errorbars-in-matlab.html
    image(f, 'CDataMapping', 'scaled', 'XData', mu, 'YData', log10(c));
    axes_color = colorbar;
    if fold == 1
        title('Mean')
    end

    % Properties
    xlabel(x_label,'FontWeight','bold');
    ylabel(y_label,'FontWeight','bold');
    ylabel(axes_color, z_label,'FontWeight','bold');

else % It's a 1 parameter optimisation problem

    model_type        = PRT.model(model).input.type;
    is_classification = strcmp(model_type, 'classification');

    if fold == 1

        nfold = length(folds);

        % Get function values (one row per fold) and the chosen optima
        x     = folds(1).param_effect.param;
        f     = zeros(nfold, length(x));
        x_opt = zeros(1, nfold);
        for i = 1:nfold
            f(i,:)   = folds(i).param_effect.vary_param;
            x_opt(i) = folds(i).param_effect.opt_param;
        end

        if is_classification
            f = 100.*f; % Convert to percentage
        end

        % Statistics across folds. The dimension is given explicitly so
        % that a single fold (f is a row vector) is not collapsed to a
        % scalar.
        f_mean = mean(f, 1);
        f_std  = std(f, 0, 1);

        % get frequencies of optimal values
        x_opt = hist(x_opt, x)./nfold;

        [f_min, f_max] = local_value_limits(model_type, f);

        % Plot
        if logscale
            x = log10(x);
        end

        hold on
        [hax,hbar,hline] = plotyy(x,x_opt*100,x,f_mean,'bar','plot');
        errorbar(axes_handle, x, f_mean, f_std, '.k', 'linewidth', 2);
        set(hbar,'BarWidth',0.5,'FaceColor',[0.5 0.5 0.5])
        set(hline,'Color','k','Linewidth',1)
        set(hax(2),'YColor',[0.2,0.2,0.2])
        % Marker colour goes from blue (never selected) to red (always
        % selected)
        for i = 1:length(x_opt)
            R = x_opt(i);
            B = 1-R;
            plot(x(i), f_mean(i), 'o', 'markersize', 4, ...
                'linewidth', 0.01,'MarkerFaceColor', [R 0 B]);
        end
        hold off

        % Properties
        ylabel(hax(1), y_label,'FontWeight','bold');
        ylabel(hax(2),'Frequency of selection (%)','FontWeight','bold');
        if logscale
            xlabel(axes_handle, [x_label, ' (log 10)'],'FontWeight','bold');
        else
            xlabel(axes_handle, x_label,'FontWeight','bold');
        end
        % Both axes share the same limits, with a 20% margin on the x range
        x_lim = [min(x)-0.2*abs(min(x)), max(x)+0.2*abs(max(x))];
        axis(hax(1), [x_lim f_min f_max]);
        axis(hax(2), [x_lim f_min f_max]);
        set(hax(2),'XTickLabel',{})

    else

        % Get all function values of the requested fold (offset by one)
        x = folds(fold-1).param_effect.param;
        f = folds(fold-1).param_effect.vary_param;
        if is_classification
            f = 100.*f; % Convert to percentage
        end

        % Get optimal function values: highest accuracy for
        % classification, lowest error for regression
        switch model_type
            case 'classification'
                x_opt = find(f==max(f));
            case 'regression'
                x_opt = find(f==min(f));
            otherwise
                error('Type of model not recognised');
        end

        % general properties of the plots
        markersize = 10;
        [f_min, f_max] = local_value_limits(model_type, f);

        if logscale
            x = log10(x);
        end

        % Plot all points
        hold on
        plot(axes_handle, x, f, '-xk', 'markersize', markersize, 'linewidth', 1);
        % Plot the optimal on top of the original
        opt_handle = plot(axes_handle, x(x_opt), f(x_opt), 'xr', 'markersize', markersize, 'linewidth', 3);
        hold off

        % Properties
        if logscale
            x_label = [x_label,' (log 10)'];
        end
        xlabel(axes_handle, x_label,'FontWeight','bold');
        ylabel(axes_handle, y_label,'FontWeight','bold');
        legend(opt_handle, 'Optimal value(s)');
        axis(axes_handle, [min(x) max(x) f_min f_max]);

    end

end

end


function ax = local_prepare_axes(ax, create_new, props_new, props_reuse)
% Return the axes to draw in: either a fresh axes in a new figure (created
% with the property/value pairs in props_new) or the supplied axes, fully
% reset and then given the property/value pairs in props_reuse.
if create_new
    figure;
    ax = axes(props_new{:});
else
    % Clear EVERYTHING in the UI before defining the axes
    cla(ax, 'reset');
    if ~isempty(props_reuse) % set(ax) without pairs would list properties
        set(ax, props_reuse{:});
    end
end
end


function [f_min, f_max] = local_value_limits(model_type, f)
% Y-axis limits for the one-parameter plots: a fixed 0-108 range for
% classification (values are percentages) and the data range for
% regression.
switch model_type
    case 'classification'
        f_min = 0;
        f_max = 108;
    case 'regression'
        f_min = min(f(:));
        f_max = max(f(:));
    otherwise
        error('Type of model not recognised');
end
end

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005