Home > batch > prt_run_model.m

prt_run_model

PURPOSE ^

SYNOPSIS ^

function out = prt_run_model(varargin)

DESCRIPTION ^

 PRoNTo job execution function
 Takes a harvested job data structure, rearranges its data into the
 "proper" model data structure and then passes that structure, together
 with the PRT structure loaded from job.infile, to prt_model.

 INPUT
   job    - harvested job data structure (see matlabbatch help)

 OUTPUT
   out    - structure with out.files{1}, the filename of the PRT.mat that
            was loaded, and out.names{1}, the name of the configured model.

   This function assembles a model structure with following fields:

   model.fname:      filename for PRT.mat
   model.model_name: name for this cross-validation structure
   model.type:       'classification' or 'regression'
   model.use_kernel: does this model use kernels or features?
   model.operations: operations to apply before prediction

   model.fs(f).fs_name:     feature set(s) this CV approach is defined for
   model.fs(f).fs_features: feature selection mode ('all' or 'mask')
   model.fs(f).mask_file:   mask for this feature set (fs_features='mask')

   model.class(c).class_name
   model.class(c).group(g).subj(s).num
   model.class(c).group(g).subj(s).modality(m).mod_name
   EITHER: model.class(c).group(g).subj(s).modality(m).conds(c).cond_name
   OR:     model.class(c).group(g).subj(s).modality(m).all_scans
   OR:     model.class(c).group(g).subj(s).modality(m).all_cond

   model.cv.type:     type of cross-validation ('loso','losgo','lobo','loro','custom')
   model.cv.mat_file: file specifying CV matrix (if type = 'custom')

   FIXME: add a more flexible interface for specifying custom CV
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function out = prt_run_model(varargin)
% PRT_RUN_MODEL PRoNTo job execution function for model specification.
%
% Takes a harvested job data structure, rearranges its data into the model
% structure described below, and passes that structure, together with the
% PRT structure loaded from job.infile, to prt_model.
%
% INPUT
%   job    - harvested job data structure (see matlabbatch help), passed as
%            the first element of varargin
%
% OUTPUT
%   out    - structure with fields
%            out.files{1}: filename of the PRT.mat that was loaded
%            out.names{1}: name of the configured model
%
%   This function assembles a model structure with following fields:
%
%   model.fname:            filename for PRT.mat
%   model.model_name:       name for this cross-validation structure
%   model.type:             'classification' or 'regression'
%   model.use_kernel:       does this model use kernels or features?
%   model.machine.function: name of the machine function to use
%   model.machine.args:     arguments for that machine
%   model.include_allscans: copied from job.include_allscans
%   model.operations:       operations to apply before prediction
%
%   model.fs(f).fs_name:     feature set(s) this CV approach is defined for
%   (model.fs(f).fs_features and model.fs(f).mask_file are not set here)
%
%   Classification:
%   model.class(c).class_name
%   model.class(c).group(g).gr_name
%   model.class(c).group(g).subj(s).num
%   model.class(c).group(g).subj(s).modality(m).mod_name
%   EITHER: model.class(c).group(g).subj(s).modality(m).conds(c).cond_name
%   OR:     model.class(c).group(g).subj(s).modality(m).all_scans
%   OR:     model.class(c).group(g).subj(s).modality(m).all_cond
%
%   Regression:
%   model.group(g).gr_name
%   model.group(g).subj(s).num
%   model.group(g).subj(s).modality.mod_name  (cell array of modality names)
%
%   model.cv.type:     type of cross-validation
%                      ('loso','losgo','lobo','loro','custom')
%   model.cv.mat_file: CV specification (if type = 'custom')
%
%   An error is raised when PRT.mat cannot be loaded, when the model name
%   is not alphanumeric, or when the selected conditions are not common to
%   all subjects of all groups.
%
%   FIXME: add a more flexible interface for specifying custom CV
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by A Marquand
% $Id: prt_run_model.m 501 2012-04-06 14:16:38Z amarquan $

% Job variable
% -------------------------------------------------------------------------
job = varargin{1};

% Load PRT.mat
% -------------------------------------------------------------------------
fname = char(job.infile);
PRT   = prt_load(fname);
if isempty(PRT)
    % error (not return): 'out' would otherwise be left unassigned
    error('prt_run_model:loadFailed', 'Could not load file %s', fname);
end

% assemble basic fields
model.fname      = fname;
model.model_name = job.model_name;
if ~prt_checkAlphaNumUnder(model.model_name)
    error('prt_run_model:badModelName', ...
        'Model name should be entered in alphanumeric format only. Please correct');
end
model.use_kernel = job.use_kernel;

% insert feature set fields
model.fs(1).fs_name = job.fsets;
fid  = prt_init_fs(PRT, model.fs(1));
mods = {PRT.fs(fid).modality(:).mod_name};

% Conditions common to all subjects of all groups, one cell per entry of
% mods (same order); used to validate the condition selections below.
lcond = common_conds(PRT, mods);

% Insert fields for generating the labels (ie. translate the fields coming
% from matlabbatch to something more consistent for the prt_model function)
% Note that we cycle through the groups to flatten out the structure, since
% we potentially specify multiple subjects per group
if isfield(job.model_type, 'classification')
    model.type = 'classification';
    cl = job.model_type.classification;
    for c = 1:length(cl.class)
        model.class(c).class_name = cl.class(c).class_name;
        for g = 1:length(cl.class(c).group)
            grp = cl.class(c).group(g);
            model.class(c).group(g).gr_name = grp.gr_name;
            % the selection is identical for every subject of the group,
            % so validate it once here
            check_conds(grp.conditions, lcond, mods);
            sids = grp.subj_nums;
            for s = 1:length(sids)
                model.class(c).group(g).subj(s).num = sids(s);
                for m = 1:length(mods)
                    model.class(c).group(g).subj(s).modality(m).mod_name = mods{m};
                    if isfield(grp.conditions, 'all_scans')
                        model.class(c).group(g).subj(s).modality(m).all_scans = true;
                    elseif isfield(grp.conditions, 'all_cond')
                        model.class(c).group(g).subj(s).modality(m).all_cond = true;
                    else
                        model.class(c).group(g).subj(s).modality(m).conds = ...
                            grp.conditions.conds;
                    end
                end
            end
        end
    end
    % insert machine fields
    mc = cl.machine_cl;
    if isfield(mc, 'svm')
        model.machine.function = 'prt_machine_svm_bin';
        model.machine.args     = mc.svm.svm_args;
    elseif isfield(mc, 'gpc')
        model.machine.function = 'prt_machine_gpml';
        model.machine.args     = mc.gpc.gpc_args;
    elseif isfield(mc, 'gpclap')
        model.machine.function = 'prt_machine_gpclap';
        model.machine.args     = mc.gpclap.gpclap_args;
    elseif isfield(mc, 'rt')
        model.machine.function = 'prt_machine_RT_bin';
        model.machine.args     = mc.rt.rt_args;
    else
        % custom machine: only the function name (no path/extension) is kept
        [pat, nam] = fileparts(char(mc.custom_machine.machine_func));
        model.machine.function = nam;
        model.machine.args     = mc.custom_machine.machine_args;
    end

elseif isfield(job.model_type, 'regression')
    model.type = 'regression';
    rg = job.model_type.regression;
    for g = 1:length(rg.reg_group)
        model.group(g).gr_name = rg.reg_group(g).gr_name;
        sids = rg.reg_group(g).subj_nums;
        % subject index restarts in every group (as in the classification
        % branch) so that no group gets leading empty subj entries
        for s = 1:length(sids)
            model.group(g).subj(s).num = sids(s);
            model.group(g).subj(s).modality.mod_name = mods;
        end
    end

    mr = rg.machine_rg;
    if isfield(mr, 'krr')
        model.machine.function = 'prt_machine_krr';
        model.machine.args     = mr.krr.krr_args;
    elseif isfield(mr, 'rvr')
        model.machine.function = 'prt_machine_rvr';
        model.machine.args     = [];
    elseif isfield(mr, 'gpr')
        model.machine.function = 'prt_machine_gpr';
        model.machine.args     = mr.gpr.gpr_args;
    else
        [pat, nam] = fileparts(char(mr.custom_machine.machine_func));
        model.machine.function = nam;
        model.machine.args     = mr.custom_machine.machine_args;
    end
else
    error('this is not implemented yet');
end

% assemble structure for performing cross-validation
model.cv = get_cv(job.cv_type);

model.include_allscans = job.include_allscans;

% specify operations to apply to the data prior to prediction
model.operations = get_operations(job.sel_ops);

prt_model(PRT, model);

% Function output
% -------------------------------------------------------------------------
out.files{1} = fname;
out.names{1} = model.model_name;
disp('Model configuration complete.')
end

function lcond = common_conds(PRT, mods)
% COMMON_CONDS Per-modality condition names shared by all subjects.
%   lcond{i} holds the cond_name values present in the design of modality
%   mods{i} for every subject of every group in PRT, or {} when some
%   subject lacks that modality or has no design structure for it.
lcond = cell(1, numel(mods));
for i = 1:numel(mods)
    lcond{i} = common_conds_mod(PRT, mods{i});
end
end

function common = common_conds_mod(PRT, mod_name)
% COMMON_CONDS_MOD Condition names common to all subjects for one modality.
%   The subject's modality is looked up by (case-insensitive) name rather
%   than by its position in the feature set.
common = {};
first  = true;
for j = 1:numel(PRT.group)
    for k = 1:numel(PRT.group(j).subject)
        smods = PRT.group(j).subject(k).modality;
        mid   = find(strcmpi(mod_name, {smods(:).mod_name}), 1);
        if isempty(mid) || ~isstruct(smods(mid).design)
            common = {};
            return
        end
        names = {smods(mid).design.conds(:).cond_name};
        if first
            common = names;
            first  = false;
        else
            common = intersect(common, names);
        end
    end
end
end

function check_conds(cspec, lcond, mods)
% CHECK_CONDS Raise an error if a group's condition selection is invalid.
%   cspec is the 'conditions' field of a batch class group. An 'all_scans'
%   selection is always accepted. 'all_cond' requires at least one common
%   condition in every modality; an explicit conds list requires every
%   selected cond_name to be common to all subjects in every modality.
if isfield(cspec, 'all_scans')
    return
end
for m = 1:numel(lcond)
    if isfield(cspec, 'all_cond')
        if isempty(lcond{m})
            error('prt_run_model:noCommonConds', ...
                ['All conditions selected while no conditions were common to all subjects ' ...
                 '(modality %s). Please review the selection and/or the data and design'], ...
                mods{m});
        end
    else
        for cc = 1:numel(cspec.conds)
            cname = cspec.conds(cc).cond_name;
            if ~any(strcmp(cname, lcond{m}))
                error('prt_run_model:condNotCommon', ...
                    ['Condition %s is not common to all subjects (modality %s). ' ...
                     'Please remove it from the selection'], cname, mods{m});
            end
        end
    end
end
end

function cv = get_cv(cv_type)
% GET_CV Translate the batch cv_type choice into the model.cv structure.
if isfield(cv_type, 'cv_loso')
    cv.type = 'loso';
elseif isfield(cv_type, 'cv_losgo')
    cv.type = 'losgo';
elseif isfield(cv_type, 'cv_lobo')
    % a previously disabled check restricted this to within-subject models
    cv.type = 'lobo';
elseif isfield(cv_type, 'cv_loro') % currently implemented for MCKR only
    cv.type = 'loro';
else
    cv.type     = 'custom';
    % NOTE(review): the whole cv_type branch is stored, not a bare file
    % name - confirm what prt_model expects in mat_file.
    cv.mat_file = cv_type;
end
end

function ops = get_operations(sel_ops)
% GET_OPERATIONS Operation ids to apply to the data prior to prediction.
%   Collects sel_ops.use_other_ops.data_op when present (otherwise, e.g.
%   for 'no_op', no extra operations) and prepends operation 3 when
%   sel_ops.data_op_mc == 0.
ops = [];
if isfield(sel_ops.use_other_ops, 'data_op')
    ops = [sel_ops.use_other_ops.data_op{:}];
end
if sel_ops.data_op_mc == 0
    % NOTE(review): the meaning of operation id 3 is defined by the
    % operation implementation elsewhere - confirm.
    ops = [3 ops];
end
end

Generated on Mon 03-Sep-2012 18:07:18 by m2html © 2005