Home > . > prt_cv_fold.m

prt_cv_fold

PURPOSE ^

Function to run a single cross-validation fold

SYNOPSIS ^

function [model, targets] = prt_cv_fold(PRT, in)

DESCRIPTION ^

 Function to run a single cross-validation fold 

 Inputs:
 -------
 PRT:           data structure
 in.mid:        index to the model we are working on
 in.ID:         ID matrix
 in.CV:         Cross-validation matrix (current fold only)
 in.Phi_all:    Cell array of data matri(ces) (training and test)
 in.t:          prediction targets

 Outputs:
 --------
 model:         the model returned by the machine
 targets.train: training targets
 targets.test:  test targets

 Notes: 
 ------
 The training and test targets output by this function are not
 necessarily equivalent to the targets that are supplied to the function.
 e.g. some data operations can modify the number of samples (e.g. sample
 averaging). In such cases size(targets.train) ~= size(in.t)

__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function [model, targets] = prt_cv_fold(PRT, in)
% Function to run a single cross-validation fold
%
% Inputs:
% -------
% PRT:           data structure
% in.mid:        index to the model we are working on
% in.ID:         ID matrix
% in.CV:         Cross-validation matrix (current fold only). Entries equal
%                to 1 mark training samples, entries equal to 2 mark test
%                samples (other values are left out of this fold)
% in.Phi_all:    Cell array of data matri(ces) (training and test)
% in.t:          prediction targets
%
% Outputs:
% --------
% model:         the model returned by the machine. If the machine throws
%                an error, a warning is issued instead and
%                model.predictions is set to zeros(size(test targets))
% targets.train: training targets (returned as a column vector)
% targets.test:  test targets (returned as a column vector)
%
% Errors:
% -------
% 'prt_cv_fold:machineDoesNotGivePredictions' if the model returned by the
% machine has no 'predictions' field.
%
% Notes:
% ------
% The training and test targets output by this function are not
% necessarily equivalent to the targets that are supplied to the function.
% e.g. some data operations can modify the number of samples (e.g. sample
% averaging). In such cases size(targets.train) ~= size(in.t)
%
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by A Marquand
% $Id: prt_cv_fold.m 522 2012-05-08 22:06:15Z amarquan $

% model settings used throughout this fold (read once)
mdl_in = PRT.model(in.mid).input;

tr_idx = in.CV == 1;
te_idx = in.CV == 2;

[Phi_tr, Phi_te, Phi_tt] = ...
    split_data(in.Phi_all, tr_idx, te_idx, mdl_in.use_kernel);

% Assemble data structure to supply to machine
cvdata.train      = Phi_tr;
cvdata.test       = Phi_te;
if mdl_in.use_kernel
    % test-by-test kernel block is only meaningful for kernel input
    cvdata.testcov    = Phi_tt;
end

% configure basic CV parameters
cvdata.tr_targets = in.t(tr_idx,:);
cvdata.te_targets = in.t(te_idx,:);
cvdata.tr_id      = in.ID(tr_idx,:);
cvdata.te_id      = in.ID(te_idx,:);
cvdata.use_kernel = mdl_in.use_kernel;
cvdata.pred_type  = mdl_in.type;

% configure additional CV parameters (e.g. needed to compute a GLM)
cvdata.tr_param = prt_cv_opt_param(PRT, in.ID(tr_idx,:), in.mid);
cvdata.te_param = prt_cv_opt_param(PRT, in.ID(te_idx,:), in.mid);

% Apply any operations specified (zero entries mean "no operation")
ops = mdl_in.operations(mdl_in.operations ~= 0);
for o = 1:length(ops)
    cvdata = prt_apply_operation(PRT, cvdata, ops(o));
end

% train the prediction model; a failing machine is downgraded to a warning
% so that the remaining folds can still run
try
    model = prt_machine(cvdata, mdl_in.machine);
catch err
    warning('prt_cv_fold:modelDidNotReturn',...
        'Prediction method did not return [%s]',err.message);
    model.predictions = zeros(size(cvdata.te_targets));
end

% check that it produced a predictions field (case-sensitive, because
% downstream code accesses model.predictions exactly)
if ~isfield(model,'predictions')
    error('prt_cv_fold:machineDoesNotGivePredictions',...
        'Machine did not produce a predictions field');
end

% does the model alter the target vector (e.g. change its dimension) ?
if isfield(model,'te_targets')
    targets.test = model.te_targets(:);
else
    targets.test = cvdata.te_targets(:);
end
if isfield(model,'tr_targets')
    targets.train = model.tr_targets(:);
else
    targets.train = cvdata.tr_targets(:);
end

end
0091 
0092 % -------------------------------------------------------------------------
0093 % Private functions
0094 % -------------------------------------------------------------------------
0095         
function [Phi_tr, Phi_te, Phi_tt] = split_data(Phi_all, tr_idx, te_idx, usebf)
% Split each data matrix in Phi_all into training and test parts.
%
% Inputs:
%   Phi_all: cell array of data matrices whose rows index samples
%   tr_idx:  row selector (logical or index vector) for training samples
%   te_idx:  row selector (logical or index vector) for test samples
%   usebf:   true when Phi_all holds kernel matrices. In that case the
%            same selectors are applied to the columns as well
%            (presumably square sample-by-sample kernels - confirm)
%
% Outputs (each a 1 x numel(Phi_all) cell array):
%   Phi_tr:  kernel: Phi(tr,tr); otherwise: training rows, all columns
%   Phi_te:  kernel: Phi(te,tr); otherwise: test rows, all columns
%   Phi_tt:  kernel: Phi(te,te); otherwise: []

n_mat  = numel(Phi_all);
Phi_tr = cell(1, n_mat);
Phi_te = cell(1, n_mat);
Phi_tt = cell(1, n_mat);

for i = 1:n_mat
    if usebf
        % kernel input: columns correspond to samples as well
        Phi_tr{i} = Phi_all{i}(tr_idx, tr_idx);
        Phi_te{i} = Phi_all{i}(te_idx, tr_idx);
        Phi_tt{i} = Phi_all{i}(te_idx, te_idx);
    else
        % feature input: keep every column (feature) of the selected rows
        Phi_tr{i} = Phi_all{i}(tr_idx, :);
        Phi_te{i} = Phi_all{i}(te_idx, :);
        Phi_tt{i} = [];
    end
end
end

Generated on Mon 03-Sep-2012 18:07:18 by m2html © 2005