Home > . > prt_cv_fold.m

prt_cv_fold

PURPOSE ^

Function to run a single cross-validation fold

SYNOPSIS ^

function [model, targets] = prt_cv_fold(PRT, in)

DESCRIPTION ^

 Function to run a single cross-validation fold 

 Inputs:
 -------
 PRT:           data structure
 in.mid:        index to the model we are working on
 in.ID:         ID matrix
 in.CV:         Cross-validation matrix (current fold only)
 in.Phi_all:    Cell array of data matri(ces) (training and test)
 in.t:          prediction targets
 in.cov:        covariates (only used when data operation 5 is applied)

 Outputs:
 --------
 model:         the model returned by the machine
 targets.train: training targets
 targets.test:  test targets

 Notes: 
 ------
 The training and test targets output by this function are not
 necessarily equivalent to the targets that are supplied to the function.
 For example, some data operations can modify the number of samples (e.g.
 sample averaging). In such cases size(targets.train) ~= size(in.t)

__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function [model, targets] = prt_cv_fold(PRT, in)
% Function to run a single cross-validation fold
%
% Inputs:
% -------
% PRT:           data structure
% in.mid:        index to the model we are working on
% in.ID:         ID matrix (rows aligned with in.CV)
% in.CV:         Cross-validation matrix (current fold only). Rows equal to
%                1 are used for training, rows equal to 2 for testing; any
%                other value leaves the row out of this fold
% in.Phi_all:    Cell array of data matri(ces) (training and test)
% in.t:          prediction targets (rows aligned with in.CV)
% in.cov:        covariates (rows aligned with in.CV); only read when data
%                operation 5 is among PRT.model(in.mid).input.operations
%
% Outputs:
% --------
% model:         the model returned by the machine. If the machine throws
%                an error, a warning is issued and model.predictions is
%                set to zeros(size(test targets)) instead
% targets.train: training targets, flattened to a column vector
% targets.test:  test targets, flattened to a column vector
%
% Notes:
% ------
% The training and test targets output by this function are not
% necessarily equivalent to the targets that are supplied to the function.
% For example, some data operations can modify the number of samples (e.g.
% sample averaging). In such cases size(targets.train) ~= size(in.t)
%
% Errors:
% -------
% prt_cv_model:machineDoesNotGivePredictions - the machine returned a
%                model without a 'predictions' field
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by A Marquand
% $Id$

tr_idx = in.CV == 1;
te_idx = in.CV == 2;

% Model settings used throughout this fold (read-only copy)
minput = PRT.model(in.mid).input;

[Phi_tr, Phi_te, Phi_tt] = ...
    split_data(in.Phi_all, tr_idx, te_idx, minput.use_kernel);

% Assemble data structure to supply to machine
cvdata.train = Phi_tr;
cvdata.test  = Phi_te;
if minput.use_kernel
    % test-by-test kernel block is only produced for kernel inputs
    cvdata.testcov = Phi_tt;
end

% configure basic CV parameters
cvdata.tr_targets = in.t(tr_idx,:);
cvdata.te_targets = in.t(te_idx,:);
cvdata.tr_id      = in.ID(tr_idx,:);
cvdata.te_id      = in.ID(te_idx,:);
cvdata.use_kernel = minput.use_kernel;
cvdata.pred_type  = minput.type;

% configure additional CV parameters (e.g. needed to compute a GLM)
cvdata.tr_param = prt_cv_opt_param(PRT, in.ID(tr_idx,:), in.mid);
cvdata.te_param = prt_cv_opt_param(PRT, in.ID(te_idx,:), in.mid);

% Apply any operations specified (zero entries mean "no operation")
ops = minput.operations(minput.operations ~= 0);

% Operation 5 needs the covariates; copy them once, before any operation
% runs (this does not depend on the loop index, so it is hoisted)
if any(ops == 5)
    cvdata.tr_cov = in.cov(tr_idx,:);
    cvdata.te_cov = in.cov(te_idx,:);
end
for o = 1:length(ops)
    cvdata = prt_apply_operation(PRT, cvdata, ops(o));
end

% train the prediction model; a failing machine is downgraded to a warning
% and all-zero predictions so the remaining folds can still run
try
    model = prt_machine(cvdata, minput.machine);
catch err
    warning('prt_cv_fold:modelDidNotReturn',...
        'Prediction method did not return [%s]',err.message);
    model.predictions = zeros(size(cvdata.te_targets));
end

% check that it produced a predictions field (identifier and message are
% passed as separate arguments so the error can be caught by its ID)
if ~any(strcmpi(fieldnames(model),'predictions'))
    error('prt_cv_model:machineDoesNotGivePredictions', ...
        'Machine did not produce a predictions field');
end

% does the model alter the target vector (e.g. change its dimension) ?
if isfield(model,'te_targets')
    targets.test = model.te_targets(:);
else
    targets.test = cvdata.te_targets(:);
end
if isfield(model,'tr_targets')
    targets.train = model.tr_targets(:);
else
    targets.train = cvdata.tr_targets(:);
end

end
0095 
0096 % -------------------------------------------------------------------------
0097 % Private functions
0098 % -------------------------------------------------------------------------
0099         
function [Phi_tr, Phi_te, Phi_tt] = split_data(Phi_all, tr_idx, te_idx, usebf)
% Split each data matrix in Phi_all into training and test parts.
%
% Inputs:
%   Phi_all - cell array of data matrices whose rows are selected by
%             tr_idx / te_idx
%   tr_idx  - logical (or index) vector selecting training rows
%   te_idx  - logical (or index) vector selecting test rows
%   usebf   - true when the matrices are kernels (the caller passes the
%             model's use_kernel flag). Columns are then split with the
%             same indices as rows (presumably square samples x samples
%             matrices -- confirm against callers)
%
% Outputs (cell arrays with one entry per matrix in Phi_all):
%   Phi_tr  - Phi(tr_idx, tr_idx) for kernels, Phi(tr_idx, :) otherwise
%   Phi_te  - Phi(te_idx, tr_idx) for kernels, Phi(te_idx, :) otherwise
%   Phi_tt  - Phi(te_idx, te_idx) for kernels, [] otherwise

n_mat  = numel(Phi_all);
Phi_tr = cell(1, n_mat);
Phi_te = cell(1, n_mat);
Phi_tt = cell(1, n_mat);

for i = 1:n_mat
    Phi = Phi_all{i};
    if usebf
        % kernel: columns index samples too, so split them the same way
        Phi_tr{i} = Phi(tr_idx, tr_idx);
        Phi_te{i} = Phi(te_idx, tr_idx);
        Phi_tt{i} = Phi(te_idx, te_idx);
    else
        % feature matrix: keep every column of this particular matrix.
        % Indexing with the scalar size(Phi,2) would keep only the last
        % column, and matrices may differ in their number of columns.
        Phi_tr{i} = Phi(tr_idx, :);
        Phi_te{i} = Phi(te_idx, :);
        Phi_tt{i} = [];
    end
end
end

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005