Home > . > prt_cv_model.m

prt_cv_model

PURPOSE ^

Function to run a cross-validation structure on a given model

SYNOPSIS ^

function [outfile]=prt_cv_model(PRT,in)

DESCRIPTION ^

 Function to run a cross-validation structure on a given model

 Inputs:
 -------
 PRT:             data structure
 in.fname:        filename for PRT.mat (string)
 in.model_name:   name for this model (string)

 Outputs:
 --------
 outfile:         path of the PRT.mat file that was written (string)

 Writes the following fields in the PRT data structure:

 PRT.model(m).output.fold(i).targets:     targets for fold(i)
 PRT.model(m).output.fold(i).predictions: predictions for fold(i)
 PRT.model(m).output.fold(i).stats:       statistics for fold(i)
 PRT.model(m).output.fold(i).{custom}:    optional fields
 PRT.model(m).output.stats:               statistics across all folds

 Notes:
 ------
 The PRT.model(m).input fields are set by prt_init_model, not by
 this function

__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function [outfile]=prt_cv_model(PRT,in)
% Function to run a cross-validation structure on a given model
%
% Inputs:
% -------
% PRT:             data structure
% in.fname:        filename for PRT.mat (string)
% in.model_name:   name for this model (string)
%
% Outputs:
% --------
% outfile:         path of the PRT.mat file that was written (string).
%                  Empty ('') when modelling is abandoned because a fold
%                  has a test class that is absent from its training set;
%                  nothing is saved in that case.
%
% Writes the following fields in the PRT data structure:
%
% PRT.model(m).output.fold(i).targets:     targets for fold(i)
% PRT.model(m).output.fold(i).predictions: predictions for fold(i)
% PRT.model(m).output.fold(i).stats:       statistics for fold(i)
% PRT.model(m).output.fold(i).{custom}:    optional fields
% PRT.model(m).output.stats:               statistics across all folds
%
% Notes:
% ------
% The PRT.model(m).input fields are set by prt_init_model, not by
% this function
%
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by A Marquand
% $Id$

% Default return value: the early return inside the CV loop would otherwise
% leave the output argument unassigned and raise an error in the caller.
outfile = '';

% Directory part of in.fname (trailing file separator is kept). The dot is
% escaped and the pattern anchored so that only a literal 'PRT.mat' at the
% end of the name is stripped, never a look-alike in the middle of the path.
prt_dir = char(regexprep(in.fname,'PRT\.mat$', ''));

% Get index of specified model
mid = prt_init_model(PRT, in);

% configure some variables
CV       = PRT.model(mid).input.cv_mat;     % CV matrix
n_folds  = size(CV,2);                      % number of CV folds (columns)
if n_folds < 1
    % Fail early with a clear message; the model-level statistics below
    % need at least one completed fold.
    error('prt_cv_model:noFolds', ...
        'The cross-validation matrix of model %d contains no folds.', mid);
end

% targets, plus covariates if GLM required (operation code 5)
use_allscans = isfield(PRT.model(mid).input,'include_allscans') && ...
    PRT.model(mid).input.include_allscans;
need_covars  = any(ismember(PRT.model(mid).input.operations,5));
covars       = [];   % named 'covars' to avoid shadowing the built-in cov()
if use_allscans
    t = PRT.model(mid).input.targ_allscans;
    if need_covars
        covars = PRT.model(mid).input.cov_allscans;
    end
else
    t = PRT.model(mid).input.targets;
    if need_covars
        covars = PRT.model(mid).input.covar;
    end
end

%get number of classes (largest target label; empty unless classification)
is_classification = strcmpi(PRT.model(mid).input.type,'classification');
if is_classification
    nc = max(unique(t));
else
    nc = [];
end

%load kernels and get the used sample in this model
[Phi_all,ID] = prt_getKernelModel(PRT,prt_dir,mid);

% Nested CV for hyper-parameter optimisation or feature selection?
use_nested_cv = false;
if isfield(PRT.model(mid).input,'use_nested_cv')
    if PRT.model(mid).input.use_nested_cv
        use_nested_cv = true;
    end
end

% configure data structure for prt_cv_fold; only the CV column changes
% between folds, so everything else is set once here
fdata.nc      = nc;
fdata.ID      = ID;
fdata.mid     = mid;       %index of model
fdata.CV      = CV(:,1);   %overwritten for every fold in the loop
fdata.Phi_all = Phi_all;   %kernel(s)
fdata.t       = t;         %targets
if ~isempty(covars)
    fdata.cov = covars;
end

% Begin cross-validation loop
% -------------------------------------------------------------------------
PRT.model(mid).output=struct();
PRT.model(mid).output.fold = struct();
for f = 1:n_folds
    disp ([' > running CV fold: ',num2str(f),' of ',num2str(n_folds),' ...'])
    fdata.CV = CV(:,f);

    % Nested CV: store its output for this fold and use the optimal
    % parameter it returns as the machine arguments
    if use_nested_cv
        [out] = prt_nested_cv(PRT, fdata);
        PRT.model(mid).output.fold(f).param_effect = out;
        PRT.model(mid).input.machine.args = out.opt_param;
    end

    % compute the model for this CV fold
    [model, targets] = prt_cv_fold(PRT,fdata);

    %for classification check that for each fold, the test targets have been trained
    if is_classification
        if ~all(ismember(unique(targets.test),unique(targets.train)))
            beep
            disp('At least one class is in the test set but not in the training set')
            disp('Abandoning modelling, please correct class selection/cross-validation')
            return   % outfile stays '' and PRT.mat is not updated
        end
    end

    % compute stats
    stats = prt_stats(model, targets.test, nc);

    % update PRT
    PRT.model(mid).output.fold(f).targets     = targets.test;
    PRT.model(mid).output.fold(f).predictions = model.predictions(:);
    PRT.model(mid).output.fold(f).stats       = stats;
    % copy other fields from the model ('predictions' was stored above as
    % a column vector and must not be overwritten)
    flds = fieldnames(model);
    for fld = 1:length(flds)
        fldnm = char(flds(fld));
        if ~strcmpi(fldnm,'predictions')
            PRT.model(mid).output.fold(f).(fldnm)=model.(fldnm);
        end
    end
end


% Model level statistics (across folds): stack the per-fold targets and
% predictions and evaluate them together
all_targets          = vertcat(PRT.model(mid).output.fold(:).targets);
cv_model.type        = PRT.model(mid).output.fold(1).type;
cv_model.predictions = vertcat(PRT.model(mid).output.fold(:).predictions);
stats                = prt_stats(cv_model,all_targets(:),nc);

PRT.model(mid).output.stats=stats;


% Save PRT containing machine output
% -------------------------------------------------------------------------
% fullfile copes with an empty prt_dir (in.fname given without a folder)
% and with the trailing separator kept in prt_dir, so the file is never
% written to the filesystem root by accident.
outfile = fullfile(prt_dir,'PRT.mat');
disp('Updating PRT.mat.......>>')
if spm_check_version('MATLAB','7') < 0
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
end
0147 
0148

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005