Home > . > prt_cv_model.m

prt_cv_model

PURPOSE ^

Function to run a cross-validation structure on a given model

SYNOPSIS ^

function [outfile]=prt_cv_model(PRT,in)

DESCRIPTION ^

 Function to run a cross-validation structure on a given model 

 Inputs:
 -------
 PRT: PRT data structure containing the specified model
 in:  structure with the following fields:
 in.fname:      filename for PRT.mat (string)
 in.model_name: name for this model (string)

 Outputs:
 --------
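 outfile: path of the saved PRT file (directory of in.fname followed by 'PRT')

 Writes the following fields in the PRT data structure and saves it to disk:
 
 PRT.model(m).output.fold(i).targets:     targets for fold(i)
 PRT.model(m).output.fold(i).predictions: predictions for fold(i)
 PRT.model(m).output.fold(i).stats:       statistics for fold(i)
 PRT.model(m).output.fold(i).{custom}:    optional fields returned by the machine
 PRT.model(m).output.stats:               statistics pooled across all folds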

 Notes: - The PRT.model(m).input fields are set by prt_init_model, not by
          this function
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function [outfile]=prt_cv_model(PRT,in)
% Function to run a cross-validation structure on a given model
%
% Inputs:
% -------
% PRT: PRT data structure containing the specified model
% in:  structure with the following fields:
%   in.fname:      filename for PRT.mat (string)
%   in.model_name: name for this model (string)
%
% Outputs:
% --------
% outfile: path of the saved PRT file, i.e. the directory part of in.fname
%          followed by 'PRT' (save() appends the .mat extension)
%
% Writes the following fields in the PRT data structure, which is then
% saved back to disk:
%
% PRT.model(m).output.fold(i).targets:     targets for fold(i)
% PRT.model(m).output.fold(i).predictions: predictions for fold(i)
% PRT.model(m).output.fold(i).stats:       statistics for fold(i)
% PRT.model(m).output.fold(i).{custom}:    other fields returned by the machine
% PRT.model(m).output.stats:               statistics pooled across folds
%
% Notes: - The PRT.model(m).input fields are set by prt_init_model, not by
%          this function
%        - Only kernel-based training (input.use_kernel) is supported;
%          feature-based training raises an error.
%        - For machines whose function name contains 'krr', the training
%          targets are mean-centred in each fold and the training mean is
%          added back to the predictions (and func_val, if present).
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by A Marquand
% $Id: prt_cv_model.m 522 2012-05-08 22:06:15Z schrouff $

% Directory holding PRT.mat (any trailing separator is kept so that
% [prt_dir 'PRT'] below rebuilds the path). Only a trailing, literal
% 'PRT.mat' is stripped.
prt_dir = char(regexprep(in.fname,'PRT\.mat$', ''));

% Get index of specified model
mid = prt_init_model(PRT, in);

% configure some variables
CV         = PRT.model(mid).input.cv_mat;     % CV matrix (1 = train, 2 = test)
n_folds    = size(CV,2);                      % number of CV folds
n_Phi      = length(PRT.model(mid).input.fs); % number of data matrices
samp_idx   = PRT.model(mid).input.samp_idx;   % which samples are in the model
use_kernel = PRT.model(mid).input.use_kernel;
machine    = PRT.model(mid).input.machine;
% KRR machines are trained on mean-centred targets (see fold loop)
is_krr     = ~isempty(strfind(machine.function,'krr'));
% operations to apply in every fold (zero entries are skipped)
ops        = PRT.model(mid).input.operations(PRT.model(mid).input.operations ~= 0);

if n_Phi < 1
    error('prt_cv_model:noFeatureSets', ...
        'Model does not specify any feature set (input.fs is empty)');
end
if n_folds < 1
    error('prt_cv_model:noFolds', ...
        'CV matrix (input.cv_mat) does not define any fold');
end

% targets
if isfield(PRT.model(mid).input,'include_allscans') && ...
   PRT.model(mid).input.include_allscans
    t = PRT.model(mid).input.targ_allscans;
else
    t = PRT.model(mid).input.targets;
end

% load data files and configure ID matrix
disp('Loading data files.....>>');
Phi_all = cell(1,n_Phi);
for i = 1:n_Phi
    fid = prt_init_fs(PRT, PRT.model(mid).input.fs(i));

    % the ID matrix is taken from the first feature set, restricted to
    % the rows of the samples included in the model
    if i == 1
        ID = PRT.fs(fid).id_mat(samp_idx,:);
    end

    if use_kernel
        % load into a struct so the file cannot overwrite local variables;
        % the kernel file is expected to contain a variable named 'Phi'
        kern = load(fullfile(prt_dir, PRT.fs(fid).k_file), 'Phi');
        Phi_all{i} = kern.Phi(samp_idx,samp_idx);
    else
        % TODO: feature-based training (would also need feat_idx)
        error('training with features not implemented yet');
    end
end

% Begin cross-validation loop
% -------------------------------------------------------------------------

PRT.model(mid).output.fold = struct();
for f = 1:n_folds
    disp ([' > running CV fold: ',num2str(f),' of ',num2str(n_folds),' ...'])
    % configure training and test indices (validation is done later)
    tr_idx = CV(:,f) == 1;
    te_idx = CV(:,f) == 2;

    [Phi_tr, Phi_te, Phi_tt] = split_data(Phi_all, tr_idx, te_idx, use_kernel);

    % Assemble a fresh data structure to supply to the machine, so that no
    % field set in a previous fold (e.g. by an operation) leaks into this one
    cvdata        = struct();
    cvdata.train  = Phi_tr;
    cvdata.test   = Phi_te;
    if use_kernel
        cvdata.testcov = Phi_tt;
    end

    % configure basic CV parameters
    cvdata.tr_targets = t(tr_idx,:);
    if is_krr
        % KRR: mean-centre the training targets; the mean is added back to
        % the machine outputs after training
        mm = mean(t(tr_idx,:));
        cvdata.tr_targets = cvdata.tr_targets - mm;
    end
    cvdata.te_targets = t(te_idx,:);
    cvdata.tr_id      = ID(tr_idx,:);
    cvdata.te_id      = ID(te_idx,:);
    cvdata.use_kernel = use_kernel;
    cvdata.pred_type  = PRT.model(mid).input.type;

    % configure additional CV parameters (e.g. needed to compute a GLM)
    cvdata.tr_param = prt_cv_opt_param(PRT, ID(tr_idx,:), mid);
    cvdata.te_param = prt_cv_opt_param(PRT, ID(te_idx,:), mid);

    % Apply any operations specified
    for o = 1:length(ops)
        cvdata = prt_apply_operation(PRT, cvdata, ops(o));
    end

    % train the prediction model
    model = prt_machine(cvdata, machine);

    % check that it produced a predictions field
    if ~isfield(model,'predictions')
        error('prt_cv_model:machineDoesNotGivePredictions', ...
            'Machine did not produce a predictions field');
    end

    % does the model alter the target vector (e.g. change its dimension) ?
    if isfield(model,'te_targets')
        true_te_targets = model.te_targets(:);
    else
        true_te_targets = cvdata.te_targets(:);
    end
    if isfield(model,'tr_targets')
        tr_targets = model.tr_targets(:);
    elseif is_krr
        % undo the centring applied before training
        tr_targets = cvdata.tr_targets(:) + mm;
    else
        tr_targets = cvdata.tr_targets(:);
    end

    if is_krr
        % add the mean of the training set back to the KRR outputs
        model.predictions = model.predictions + mm;
        if isfield(model,'func_val')
            model.func_val = model.func_val + mm;
        end
    end

    % compute stats
    stats = prt_stats(model, true_te_targets, tr_targets);

    % update PRT - ensuring column vectors throughout
    PRT.model(mid).output.fold(f).targets     = true_te_targets;
    PRT.model(mid).output.fold(f).predictions = model.predictions(:);
    PRT.model(mid).output.fold(f).stats       = stats;
    % save func_val for later analysis if available
    if isfield(model,'func_val')
        PRT.model(mid).output.fold(f).func_val = model.func_val;
    end

    % copy every other field returned by the machine (note: a machine field
    % with the same name as one set above overwrites it)
    flds = fieldnames(model);
    for fld = 1:length(flds)
        fldnm = flds{fld};
        if ~strcmpi(fldnm,'predictions')
            PRT.model(mid).output.fold(f).(fldnm) = model.(fldnm);
        end
    end
end

% Model level statistics (across folds)
% NOTE(review): relies on the machine returning a 'type' field; the pooled
% test targets are also passed as the training-target argument of
% prt_stats - confirm this is intended.
t             = vertcat(PRT.model(mid).output.fold(:).targets);
m.type        = PRT.model(mid).output.fold(1).type;
m.predictions = vertcat(PRT.model(mid).output.fold(:).predictions);
stats         = prt_stats(m,t(:),t(:));

PRT.model(mid).output.stats = stats;

% Save PRT containing machine output
% -------------------------------------------------------------------------
outfile = [prt_dir, 'PRT'];
disp('Updating PRT.mat.......>>')
if spm_matlab_version_chk('7') >= 0
    % MATLAB 7+: write a v6 MAT-file (presumably for compatibility with
    % older readers - confirm)
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
end
0184 
0185 % -------------------------------------------------------------------------
0186 % Private functions
0187 % -------------------------------------------------------------------------
0188         
function [Phi_tr, Phi_te, Phi_tt] = split_data(Phi_all, tr_idx, te_idx, usebf)
% Split each data matrix into training and test parts.
%
% Inputs:
%   Phi_all - cell array of data matrices; rows are selected by
%             tr_idx / te_idx
%   tr_idx  - row index (logical or numeric) of the training samples
%   te_idx  - row index (logical or numeric) of the test samples
%   usebf   - true if the matrices are kernels (columns index samples as
%             well, so they are split too); false for feature matrices,
%             whose columns are all kept
%
% Outputs (cell arrays with one entry per element of Phi_all):
%   Phi_tr - kernel: Phi(tr,tr)   features: Phi(tr,:)
%   Phi_te - kernel: Phi(te,tr)   features: Phi(te,:)
%   Phi_tt - kernel: Phi(te,te)   features: []

n_mat  = numel(Phi_all);
Phi_tr = cell(1,n_mat);
Phi_te = cell(1,n_mat);
Phi_tt = cell(1,n_mat);

for i = 1:n_mat
    if usebf
        % kernel matrix: restrict both rows and columns to samples
        Phi_tr{i} = Phi_all{i}(tr_idx, tr_idx);
        Phi_te{i} = Phi_all{i}(te_idx, tr_idx);
        Phi_tt{i} = Phi_all{i}(te_idx, te_idx);
    else
        % feature matrix: keep every feature column of this matrix
        Phi_tr{i} = Phi_all{i}(tr_idx, :);
        Phi_te{i} = Phi_all{i}(te_idx, :);
        Phi_tt{i} = [];
    end
end
end
0226

Generated on Sun 20-May-2012 13:24:48 by m2html © 2005