Home > batch > prt_run_model.m

prt_run_model

PURPOSE ^

SYNOPSIS ^

function out = prt_run_model(varargin)

DESCRIPTION ^

 PRoNTo job execution function.
 Takes a harvested job data structure, rearranges its contents into the
 "proper" model structure described below, and passes that structure to
 prt_model together with the loaded PRT.

 INPUT
   job    - harvested job data structure (see matlabbatch help)

 OUTPUT
   out    - structure: out.files{1} = PRT.mat filename, out.mname = model name

   This function assembles a model structure with following fields:

   model.fname:      filename for PRT.mat
   model.model_name: name for this cross-validation structure
   model.type:       'classification' or 'regression'
   model.use_kernel: does this model use kernels or features?
   model.operations: operations to apply before prediction

   model.fs(f).fs_name:     feature set(s) this CV approach is defined for
   model.fs(f).fs_features: feature selection mode ('all' or 'mask')
   model.fs(f).mask_file:   mask for this feature set (fs_features='mask')

   model.class(c).class_name
   model.class(c).group(g).subj(s).num
   model.class(c).group(g).subj(s).modality(m).mod_name
   EITHER: model.class(c).group(g).subj(s).modality(m).conds(c).cond_name
   OR:     model.class(c).group(g).subj(s).modality(m).all_scans
   OR:     model.class(c).group(g).subj(s).modality(m).all_cond

   model.cv.type:     type of cross-validation ('loso','losgo','custom')
   model.cv.mat_file: file specifying CV matrix (if type = 'custom');

   FIXME: add a more flexible interface for specifying custom CV
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 function out = prt_run_model(varargin)
0002 %
0003 % PRoNTo job execution function.
0004 % Takes a harvested job data structure, rearranges its contents into the
0005 % "proper" model structure described below, and passes that structure to
0006 % prt_model together with the loaded PRT.
0007 %
0008 % INPUT
0009 %   job    - harvested job data structure (see matlabbatch help)
0010 %
0011 % OUTPUT
0012 %   out    - structure: out.files{1} = PRT.mat filename, out.mname = model name
0013 %
0014 %   This function assembles a model structure with following fields:
0015 %
0016 %   model.fname:      filename for PRT.mat
0017 %   model.model_name: name for this cross-validation structure
0018 %   model.type:       'classification' or 'regression'
0019 %   model.use_kernel: does this model use kernels or features?
0020 %   model.operations: operations to apply before prediction
0021 %
0022 %   model.fs(f).fs_name:     feature set(s) this CV approach is defined for
0023 %   model.fs(f).fs_features: feature selection mode ('all' or 'mask')
0024 %   model.fs(f).mask_file:   mask for this feature set (fs_features='mask')
0025 %
0026 %   model.class(c).class_name
0027 %   model.class(c).group(g).subj(s).num
0028 %   model.class(c).group(g).subj(s).modality(m).mod_name
0029 %   EITHER: model.class(c).group(g).subj(s).modality(m).conds(c).cond_name
0030 %   OR:     model.class(c).group(g).subj(s).modality(m).all_scans
0031 %   OR:     model.class(c).group(g).subj(s).modality(m).all_cond
0032 %
0033 %   model.cv.type:     type of cross-validation ('loso','losgo','custom')
0034 %   model.cv.mat_file: file specifying CV matrix (if type = 'custom');
0035 %
0036 %   FIXME: add a more flexible interface for specifying custom CV
0037 %__________________________________________________________________________
0038 % Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory
0039 
0040 % Written by A Marquand
0041 % $Id$
0042 
0043 % Job variable
0044 % -------------------------------------------------------------------------
0045 job   = varargin{1};
0046 
0047 % Load PRT.mat
0048 % -------------------------------------------------------------------------
0049 fname = char(job.infile);
0050 if exist('PRT','var')
0051     clear PRT
0052 end
0053 PRT=prt_load(fname);
0054 if ~isempty(PRT)
0055     handles.dat=PRT;
0056 else
0057     beep
0058     disp('Could not load file')
0059     return % NOTE(review): 'out' has not been assigned on this path
0060 end
0061 
0062 % assemble basic fields
0063 model.fname      = fname;
0064 model.model_name = job.model_name;
0065 if ~(prt_checkAlphaNumUnder(model.model_name))
0066     beep
0067     disp('Model name should be entered in alphanumeric format only')
0068     disp('Please correct')
0069     return
0070 end
0071 model.use_kernel = job.use_kernel;
0072 
0073 % insert feature set fields
0074 
0075 model.fs(1).fs_name = job.fsets;
0076 fid = prt_init_fs(PRT,model.fs(1)); % used below as an index into PRT.fs
0077 mods = cellstr(char(PRT.fs(fid).modality(:).mod_name));
0078 
0079 % get the conditions which are common to all subjects from all groups
0080 nm = length(mods);
0081 for i=1:nm % NOTE(review): lcond is rebuilt for each modality; only the last one is kept
0082     flag=1; % cleared as soon as one subject has no design structure
0083     for j=1:length(PRT.group)
0084         for k=1:length(PRT.group(j).subject)
0085             m2= find(strcmpi(PRT.fs(fid).modality(i).mod_name,mods));
0086             if isempty(m2)
0087                 m2= find(strcmpi(PRT.fs(fid).modality(i).mod_name,mods{1}));
0088             end
0089             des=PRT.group(j).subject(k).modality(m2).design;
0090             if isstruct(des) && flag
0091                 if k==1 && j==1 % first subject of first group seeds the list
0092                     lcond={des.conds(:).cond_name};
0093                 else
0094                     tocmp={des.conds(:).cond_name};
0095                     lcond=intersect(lower(lcond),lower(tocmp));
0096                 end
0097             else
0098                 flag=0;
0099                 lcond={};
0100             end
0101         end
0102     end
0103 end
0104 % Insert fields for generating the labels (ie. translate the fields coming
0105 % from matlabbatch to something more consistent for the prt_model function)
0106 % Note that we cycle through the groups to flatten out the structure, since
0107 % we potentially specify multiple subjects per group
0108 if isfield(job.model_type,'classification')
0109     model.type = 'classification';
0110     for c = 1:length(job.model_type.classification.class)
0111         model.class(c).class_name = job.model_type.classification.class(c).class_name;
0112 
0113         for g = 1:length(job.model_type.classification.class(c).group)
0114             scount = 1;
0115             model.class(c).group(g).gr_name = ...
0116                 job.model_type.classification.class(c).group(g).gr_name;
0117 
0118             sids   = job.model_type.classification.class(c).group(g).subj_nums;
0119             for s = 1:length(sids)
0120                 model.class(c).group(g).subj(scount).num = sids(s);
0121                 for m = 1: length(mods)
0122                     model.class(c).group(g).subj(scount).modality(m).mod_name=mods{m};
0123                     if isfield(job.model_type.classification.class(c).group(g).conditions,'all_scans')
0124                         model.class(c).group(g).subj(scount).modality(m).all_scans = true;
0125                     elseif isfield(job.model_type.classification.class(c).group(g).conditions,'all_cond')
0126                         model.class(c).group(g).subj(scount).modality(m).all_cond = true;
0127                         if isempty(lcond)
0128                             beep
0129                             disp('All conditions selected while no conditions were common to all subjects')
0130                             disp('Please review the selection and/or the data and design')
0131                             return
0132                         end
0133                     else
0134                         model.class(c).group(g).subj(scount).modality(m).conds = ...
0135                             job.model_type.classification.class(c).group(g).conditions.conds;
0136                         for cc=1:length(job.model_type.classification.class(c).group(g).conditions.conds)
0137                             cname=job.model_type.classification.class(c).group(g).conditions.conds(cc).cond_name;
0138                             if isempty(intersect(lower({cname}),lower(lcond)))
0139                                 beep
0140                                 disp('This condition is not common to all subjects')
0141                                 disp('Please remove it from the selection')
0142                                 return
0143                             end
0144                         end
0145                     end
0146                 end
0147                 scount = scount+1;
0148             end
0149         end
0150     end
0151     % insert machine fields
0152     if isfield(job.model_type.classification.machine_cl,'svm')
0153         model.machine.function = 'prt_machine_svm_bin';
0154         model.machine.args     = job.model_type.classification.machine_cl.svm.svm_args;
0155         if isfield(job.model_type.classification.machine_cl.svm, 'svm_opt')
0156             if job.model_type.classification.machine_cl.svm.svm_opt
0157                 model.cv.nested = 1;
0158                 model.cv.nested_param = job.model_type.classification.machine_cl.svm.svm_args;
0159             end
0160         end
0161         if isfield(job.model_type.classification.machine_cl.svm, 'cv_type_nested')
0162            [cv_tmp] = get_cv_type(job.model_type.classification.machine_cl.svm.cv_type_nested);
0163            model.cv.type_nested = cv_tmp.type;
0164            model.cv.k_nested = cv_tmp.k;
0165         end
0166     elseif isfield(job.model_type.classification.machine_cl,'gpc')
0167         model.machine.function='prt_machine_gpml';
0168         model.machine.args=job.model_type.classification.machine_cl.gpc.gpc_args;
0169     elseif isfield(job.model_type.classification.machine_cl,'gpclap')
0170         model.machine.function='prt_machine_gpclap';
0171         model.machine.args=job.model_type.classification.machine_cl.gpclap.gpclap_args;
0172     elseif isfield(job.model_type.classification.machine_cl,'rt')
0173         model.machine.function='prt_machine_RT_bin';
0174         model.machine.args=job.model_type.classification.machine_cl.rt.rt_args;
0175     elseif isfield(job.model_type.classification.machine_cl,'sMKL_cla')
0176         model.machine.function='prt_machine_sMKL_cla';
0177         model.machine.args=job.model_type.classification.machine_cl.sMKL_cla.sMKL_cla_args;
0178         if isfield(job.model_type.classification.machine_cl.sMKL_cla, 'sMKL_cla_opt')
0179             if job.model_type.classification.machine_cl.sMKL_cla.sMKL_cla_opt
0180                 model.cv.nested = 1;
0181                 model.cv.nested_param = job.model_type.classification.machine_cl.sMKL_cla.sMKL_cla_args;
0182             end
0183         end
0184         if isfield(job.model_type.classification.machine_cl.sMKL_cla, 'cv_type_nested')
0185            [cv_tmp] = get_cv_type(job.model_type.classification.machine_cl.sMKL_cla.cv_type_nested);
0186            model.cv.type_nested = cv_tmp.type;
0187            model.cv.k_nested = cv_tmp.k;
0188         end
0189 
0190     else
0191         [pat, nam] = fileparts(char(job.model_type.classification.machine_cl.custom_machine.machine_func));
0192         model.machine.function = nam;
0193         model.machine.args = job.model_type.classification.machine_cl.custom_machine.machine_args;
0194     end
0195 
0196 elseif isfield(job.model_type,'regression')
0197     model.type = 'regression';
0198     for g = 1:length(job.model_type.regression.reg_group)
0199         scount = 1;
0200         model.group(g).gr_name = job.model_type.regression.reg_group(g).gr_name;
0201         sids   =  job.model_type.regression.reg_group(g).subj_nums;
0202         for s = 1:length(sids)
0203             model.group(g).subj(scount).num = sids(s);
0204             model.group(g).subj(scount).modality.mod_name =  mods;
0205             scount=scount+1;
0206         end
0207     end
0208 
0209     if isfield(job.model_type.regression.machine_rg,'krr')
0210         model.machine.function = 'prt_machine_krr';
0211         model.machine.args=job.model_type.regression.machine_rg.krr.krr_args;
0212         if isfield(job.model_type.regression.machine_rg.krr, 'krr_opt')
0213             if job.model_type.regression.machine_rg.krr.krr_opt
0214                 model.cv.nested = 1;
0215                 model.cv.nested_param = job.model_type.regression.machine_rg.krr.krr_args;
0216             end
0217         end
0218         if isfield(job.model_type.regression.machine_rg.krr, 'cv_type_nested')
0219            [cv_tmp] = get_cv_type(job.model_type.regression.machine_rg.krr.cv_type_nested);
0220            model.cv.type_nested = cv_tmp.type;
0221            model.cv.k_nested = cv_tmp.k;
0222         end
0223     elseif isfield(job.model_type.regression.machine_rg,'rvr')
0224         model.machine.function='prt_machine_rvr';
0225         model.machine.args=[];
0226     elseif isfield(job.model_type.regression.machine_rg,'gpr')
0227         model.machine.function='prt_machine_gpr';
0228         model.machine.args=job.model_type.regression.machine_rg.gpr.gpr_args;
0229     elseif isfield(job.model_type.regression.machine_rg,'sMKL_reg')
0230         model.machine.function='prt_machine_sMKL_reg';
0231         model.machine.args=job.model_type.regression.machine_rg.sMKL_reg.sMKL_reg_args;
0232         if isfield(job.model_type.regression.machine_rg.sMKL_reg, 'sMKL_reg_opt')
0233             if job.model_type.regression.machine_rg.sMKL_reg.sMKL_reg_opt
0234                 model.cv.nested = 1;
0235                 model.cv.nested_param = job.model_type.regression.machine_rg.sMKL_reg.sMKL_reg_args;
0236             end
0237         end
0238         if isfield(job.model_type.regression.machine_rg.sMKL_reg, 'cv_type_nested')
0239 %            [cv_type, k] = get_cv_type(job.model_type.regression.machine_rg.sMKL_reg.cv_type_nested);
0240 %            model.cv.type_nested = cv_type;
0241 %            model.cv.k_nested = k;
0242            [cv_tmp] = get_cv_type(job.model_type.regression.machine_rg.sMKL_reg.cv_type_nested);
0243            model.cv.type_nested = cv_tmp.type;
0244            model.cv.k_nested = cv_tmp.k;
0245         end
0246 
0247     else
0248         [pat, nam] = fileparts(char(job.model_type.regression.machine_rg.custom_machine.machine_func));
0249         model.machine.function = nam;
0250         model.machine.args = job.model_type.regression.machine_rg.custom_machine.machine_args;
0251     end
0252 else
0253     error('this is not implemented yet');
0254 end
0255 
0256 % assemble structure for performing cross-validation; merge field by field
0257 % so the nested-CV settings stored in model.cv above are not overwritten
0258 cv_outer = get_cv_type(job.cv_type);
0259 cv_names = fieldnames(cv_outer);
0260 for fi = 1:numel(cv_names)
0261     model.cv.(cv_names{fi}) = cv_outer.(cv_names{fi});
0262 end
0263 model.include_allscans = job.include_allscans;
0264 
0265 % specify operations to apply to the data prior to prediction
0266 if isfield(job.sel_ops.use_other_ops,'data_op')
0267     ops = [job.sel_ops.use_other_ops.data_op{:}];
0268 elseif isfield(job.sel_ops.use_other_ops,'no_op')
0269     ops = [];
0270 end
0271 if job.sel_ops.data_op_mc == 1 % operation code 3 is put first when data_op_mc is set
0272     model.operations = [3 ops];
0273 else
0274     model.operations = ops;
0275 end
0276 
0277 prt_model(PRT,model);
0278 
0279 % Function output
0280 % -------------------------------------------------------------------------
0281 out.files{1} = fname;
0282 out.mname = model.model_name;
0283 disp('Model configuration complete.')
0284 end
0285 
0286 
0287 %--------------------------------------------------------------------------
0288 % Private functions
0289 %--------------------------------------------------------------------------
0290 function cv = get_cv_type(cv_struct)
0291 % Translate the cross-validation choice made in the batch into a cv
0292 % structure with fields 'type' and 'k' (plus 'mat_file' for a custom CV;
0293 % the 'loro' scheme carries 'type' only).
0294 
0295 % batch field -> cv type, and whether that batch field provides k_args
0296 schemes = {'cv_loso' ,'loso' ,false; ...
0297            'cv_lkso' ,'loso' ,true ; ...
0298            'cv_losgo','losgo',false; ...
0299            'cv_lksgo','losgo',true ; ...
0300            'cv_lobo' ,'lobo' ,false; ...
0301            'cv_lkbo' ,'lobo' ,true };
0302 
0303 % the first matching batch field wins, in the order listed above
0304 for row = 1:size(schemes,1)
0305     key = schemes{row,1};
0306     if isfield(cv_struct,key)
0307         if schemes{row,3}
0308             k_val = cv_struct.(key).k_args;
0309         else
0310             k_val = 0;
0311         end
0312         cv = struct('type',schemes{row,2},'k',k_val);
0313         return
0314     end
0315 end
0316 
0317 if isfield(cv_struct,'cv_loro') % currently implemented for MCKR only
0318     cv = struct('type','loro');
0319     return
0320 end
0321 
0322 % anything else is treated as a custom CV: cv_custom{1} is stored both as
0323 % 'k' and as 'mat_file'
0324 custom_file = cv_struct.cv_custom{1};
0325 cv = struct('type','custom','k',custom_file,'mat_file',custom_file);
0326 % Not sure if I should keep the field 'k' here...
0327 
0328 end

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005