Home > . > prt_cv_model.m

prt_cv_model

PURPOSE ^

Function to run the cross-validation scheme defined for a given model

SYNOPSIS ^

function [outfile]=prt_cv_model(PRT,in)

DESCRIPTION ^

 Function to run the cross-validation scheme defined for a given model

 Inputs:
 -------
 PRT:           data structure
 in.fname:      filename for PRT.mat (string)
 in.model_name: name for this model (string)

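 Outputs:
 --------
 outfile:       path of the saved PRT file (without the '.mat' extension)

 Writes the following fields in the PRT data structure and saves it back
 to PRT.mat in the directory of in.fname:
 
 PRT.model(m).output.fold(i).targets:     targets for fold(i)
 PRT.model(m).output.fold(i).predictions: predictions for fold(i)
 PRT.model(m).output.fold(i).stats:       statistics for fold(i)
 PRT.model(m).output.fold(i).{custom}:    optional fields
 PRT.model(m).output.stats:               statistics across all folds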

 Notes: 
 ------
 The PRT.model(m).input fields are set by prt_init_model, not by
 this function

__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: prt_init_model, prt_init_fs, prt_cv_fold, prt_stats, spm_matlab_version_chk. This function is called by: (not listed)

SOURCE CODE ^

function [outfile]=prt_cv_model(PRT,in)
% Function to run the cross-validation scheme defined for a given model
%
% Inputs:
% -------
% PRT:           data structure
% in.fname:      filename for PRT.mat (string)
% in.model_name: name for this model (string)
%
% Outputs:
% --------
% outfile:       path of the saved PRT file (without the '.mat' extension)
%
% Writes the following fields in the PRT data structure and saves it back
% to PRT.mat in the directory of in.fname:
%
% PRT.model(m).output.fold(i).targets:     targets for fold(i)
% PRT.model(m).output.fold(i).predictions: predictions for fold(i)
% PRT.model(m).output.fold(i).stats:       statistics for fold(i)
% PRT.model(m).output.fold(i).{custom}:    optional fields
% PRT.model(m).output.stats:               statistics across all folds
%
% Notes:
% ------
% The PRT.model(m).input fields are set by prt_init_model, not by
% this function. Only kernel-based training (input.use_kernel) is
% supported; feature-based training raises an error.
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by A Marquand
% $Id: prt_cv_model.m 551 2012-05-31 08:26:54Z amarquan $

% Directory containing PRT.mat. fileparts strips only the file name; a
% regexprep on 'PRT.mat' would treat '.' as a wildcard and could also
% match inside directory names.
prt_dir = fileparts(char(in.fname));

% Get index of specified model
mid = prt_init_model(PRT, in);

% configure some variables
CV       = PRT.model(mid).input.cv_mat;     % CV matrix (one column per fold)
n_folds  = size(CV,2);                      % number of CV folds
n_Phi    = length(PRT.model(mid).input.fs); % number of data matrices
samp_idx = PRT.model(mid).input.samp_idx;   % which samples are in the model

% Fail early with clear messages instead of obscure errors further down
if n_Phi < 1
    error('prt_cv_model:noFeatureSet', ...
        'No feature set is specified for this model.');
end
if n_folds < 1
    error('prt_cv_model:noFolds', ...
        'The cross-validation matrix does not define any folds.');
end
if ~PRT.model(mid).input.use_kernel
    error('training with features not implemented yet');
end

% targets
if isfield(PRT.model(mid).input,'include_allscans') && ...
   PRT.model(mid).input.include_allscans
    t = PRT.model(mid).input.targ_allscans;
else
    t = PRT.model(mid).input.targets;
end

% load data files and configure ID matrix
disp('Loading data files.....>>');
Phi_all = cell(1,n_Phi);
for i = 1:n_Phi
    fid = prt_init_fs(PRT, PRT.model(mid).input.fs(i));

    if i == 1
        % the ID matrix is taken from the first feature set
        ID = PRT.fs(fid).id_mat(samp_idx,:);
    end

    % Load into a struct so the file cannot overwrite local variables
    kfile = fullfile(prt_dir, PRT.fs(fid).k_file);
    kern  = load(kfile, 'Phi');
    if ~isfield(kern, 'Phi')
        error('prt_cv_model:missingKernel', ...
            'Kernel file %s does not contain a variable named Phi.', kfile);
    end
    Phi_all{i} = kern.Phi(samp_idx,samp_idx);
end
clear kern % release the full (unsubsampled) kernel before the CV loop

% Begin cross-validation loop
% -------------------------------------------------------------------------
% Inputs to prt_cv_fold that do not change between folds (field order is
% kept as ID, mid, CV, Phi_all, t).
fdata         = struct();
fdata.ID      = ID;
fdata.mid     = mid;
fdata.CV      = [];        % set per fold below
fdata.Phi_all = Phi_all;
fdata.t       = t;

PRT.model(mid).output.fold = struct();
for f = 1:n_folds
    disp ([' > running CV fold: ',num2str(f),' of ',num2str(n_folds),' ...'])
    fdata.CV = CV(:,f);

    % compute the model for this CV fold
    [model, targets] = prt_cv_fold(PRT,fdata);

    % compute stats
    stats = prt_stats(model, targets.test, targets.train);

    % update PRT
    PRT.model(mid).output.fold(f).targets     = targets.test;
    PRT.model(mid).output.fold(f).predictions = model.predictions(:);
    PRT.model(mid).output.fold(f).stats       = stats;
    % copy the remaining fields returned by the machine (predictions were
    % already stored above as a column vector)
    flds = fieldnames(model);
    for k = 1:numel(flds)
        if ~strcmpi(flds{k},'predictions')
            PRT.model(mid).output.fold(f).(flds{k}) = model.(flds{k});
        end
    end
end

% Model level statistics (across folds)
% Per-fold targets are stacked as columns so that row-vector targets of
% different lengths can still be concatenated.
fold_out      = PRT.model(mid).output.fold;
t_cell        = cellfun(@(x) x(:), {fold_out.targets}, 'UniformOutput', false);
t             = vertcat(t_cell{:});
m.type        = fold_out(1).type;
m.predictions = vertcat(fold_out.predictions);
% NOTE(review): the pooled test targets are also passed as the third
% ("training targets") argument, as in the original code - confirm this
% is what prt_stats expects at model level.
stats         = prt_stats(m,t,t);

PRT.model(mid).output.stats=stats;

% Save PRT containing machine output
% -------------------------------------------------------------------------
outfile = fullfile(prt_dir, 'PRT');
disp('Updating PRT.mat.......>>')
if spm_matlab_version_chk('7') >= 0
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
end

Generated on Mon 03-Sep-2012 18:07:18 by m2html © 2005