Home > . > prt_model.m

prt_model

PURPOSE ^

Function to configure and build the PRT.model data structure

SYNOPSIS ^

function [PRT, CV, ID] = prt_model(PRT,in)

DESCRIPTION ^

 Function to configure and build the PRT.model data structure

 Input:
 ------
   PRT fields:
   model.fs(f).fs_name:     feature set(s) this CV approach is defined for
   model.fs(f).fs_features: feature selection mode ('all' or 'mask')
   model.fs(f).mask_file:   mask for this feature set (fs_features='mask')

   in.fname:      filename for PRT.mat
   in.model_name: name for this cross-validation structure
   in.type:       'classification' or 'regression'
   in.use_kernel: does this model use kernels or features?
   in.operations: operations to apply before prediction

   in.fs(f).fs_name:     feature set(s) this CV approach is defined for

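   in.class(c).class_name
   in.class(c).group(g).gr_name
   in.class(c).group(g).subj(s).num
   in.class(c).group(g).subj(s).modality(m).mod_name
   EITHER: in.class(c).group(g).subj(s).modality(m).conds(k).cond_name
   OR:     in.class(c).group(g).subj(s).modality(m).all_scans
   OR:     in.class(c).group(g).subj(s).modality(m).all_cond
   (The in.class fields apply to classification. For regression, the
   subjects are given instead as in.group(g).gr_name,
   in.group(g).subj(s).num and in.group(g).subj(s).modality(m).mod_name.)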

   in.cv.type:     type of cross-validation ('loso','losgo','custom')
   in.cv.mat_file: file specifying CV matrix (if type='custom');

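 Output:
 -------
   PRT: updated structure (also saved to in.fname)
   CV:  cross-validation matrix computed by prt_compute_cv_mat
   ID:  second output of prt_compute_cv_mat

   This function performs the following steps:
      1. populates basic fields in PRT.model(m).input
      2. computes PRT.model(m).input.targets based on in.class(c)...
         (classification) or on the subjects' regression targets (regression)
      3. computes PRT.model(m).input.samp_idx based on targets
      4. computes PRT.model(m).input.cv_mat based on the labels and CV spec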
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

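SUBFUNCTIONS ^

compute_targets(PRT, in):    classification targets and sample indices
compute_target_reg(PRT, in): regression targets and sample indices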

SOURCE CODE ^

function [PRT, CV, ID] = prt_model(PRT,in)
% Function to configure and build the PRT.model data structure
%
% Input:
% ------
%   PRT fields:
%   model.fs(f).fs_name:     feature set(s) this CV approach is defined for
%   model.fs(f).fs_features: feature selection mode ('all' or 'mask')
%   model.fs(f).mask_file:   mask for this feature set (fs_features='mask')
%
%   in.fname:      filename for PRT.mat
%   in.model_name: name for this cross-validation structure
%   in.type:       'classification' or 'regression'
%   in.use_kernel: does this model use kernels or features?
%   in.operations: operations to apply before prediction (when it contains
%                  5, covariates are read from
%                  PRT.group(g).subject(s).modality(m).covar)
%   in.include_allscans: (optional) if true, samp_idx is the 'all scans'
%                  sample index of the selected subjects/modalities instead
%                  of only the samples that received a target
%
%   in.fs(f).fs_name:     feature set(s) this CV approach is defined for
%
%   Classification (in.type = 'classification'):
%   in.class(c).class_name
%   in.class(c).group(g).gr_name
%   in.class(c).group(g).subj(s).num
%   in.class(c).group(g).subj(s).modality(m).mod_name
%   EITHER: in.class(c).group(g).subj(s).modality(m).conds(k).cond_name
%   OR:     in.class(c).group(g).subj(s).modality(m).all_scans
%   OR:     in.class(c).group(g).subj(s).modality(m).all_cond
%
%   Regression (any other in.type):
%   in.group(g).gr_name
%   in.group(g).subj(s).num
%   in.group(g).subj(s).modality(m).mod_name
%
%   in.cv.type:     type of cross-validation ('loso','losgo','custom')
%   in.cv.mat_file: file specifying CV matrix (if type='custom');
%   in.cv.k:        (optional) stored as input.cv_k (0 when absent)
%   in.cv.type_nested, in.cv.k_nested, in.cv.nested_param:
%                   (optional) nested CV settings, copied when non-empty
%
% Output:
% -------
%   PRT: updated structure (also saved to in.fname)
%   CV:  cross-validation matrix computed by prt_compute_cv_mat
%   ID:  second output of prt_compute_cv_mat
%
%   This function performs the following steps:
%      1. populates basic fields in PRT.model(m).input
%      2. computes PRT.model(m).input.targets based on in.class(c)...
%         (classification) or on the subjects' regression targets
%      3. computes PRT.model(m).input.samp_idx based on targets
%      4. computes PRT.model(m).input.cv_mat based on the labels and CV spec
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by A Marquand
% $Id$

% Populate basic fields in PRT.mat
% -------------------------------------------------------------------------
[modelid, PRT] = prt_init_model(PRT,in);

% specify model type and feature sets
PRT.model(modelid).input.type = in.type;
if strcmp(in.type,'classification')
    for c = 1:length(in.class)
        PRT.model(modelid).input.class(c) = in.class(c);
    end
end

for f = 1:length(in.fs)
    fid = prt_init_fs(PRT,in.fs(f));

    % A feature set built from several modalities cannot be combined with
    % other feature sets in the same model
    if length(PRT.fs(fid).modality) > 1 && length(in.fs) > 1
        error('prt_model:multipleFeatureSetsAppliedAsSamplesAndAsFeatures',...
            ['Feature set ',in.fs(f).fs_name,' contains multiple modalities ',...
            'and job specifies that multiple feature sets should be ',...
            'supplied to the machine. This usage is not supported.']);
    end

    PRT.model(modelid).input.fs(f).fs_name = in.fs(f).fs_name;
end

% compute targets and samp_idx
% -------------------------------------------------------------------------
if strcmp(in.type,'classification')
    [targets, samp_idx, t_allscans, samp_allscans, covar, cov_all] = ...
        compute_targets(PRT, in);
else
    [targets, samp_idx, t_allscans, covar, cov_all] = compute_target_reg(PRT, in);
    % The regression sample index already holds every id_mat row of each
    % selected subject/modality, so it doubles as the 'all scans' index.
    % Without this, include_allscans = true failed on an undefined variable.
    samp_allscans = samp_idx;
end
%[afm]
if isfield(in,'include_allscans') && in.include_allscans
    PRT.model(modelid).input.samp_idx = samp_allscans;
    PRT.model(modelid).input.include_allscans = in.include_allscans;
else
    PRT.model(modelid).input.samp_idx = samp_idx;
    PRT.model(modelid).input.include_allscans = false;
end
PRT.model(modelid).input.targets          = targets;
PRT.model(modelid).input.targ_allscans    = t_allscans;
PRT.model(modelid).input.covar            = covar;
PRT.model(modelid).input.cov_allscans     = cov_all;

% compute cross-validation matrix and specify operations to apply
% -------------------------------------------------------------------------
if isfield(in.cv,'k')
    PRT.model(modelid).input.cv_k = in.cv.k;
else
    PRT.model(modelid).input.cv_k = 0;
end
[CV,ID] = prt_compute_cv_mat(PRT,in, modelid);
PRT.model(modelid).input.cv_mat  = CV;
PRT.model(modelid).input.cv_type = in.cv.type;
% Deal with nested CV parameters (only copied when provided and non-empty)
if isfield(in.cv,'type_nested') && ~isempty(in.cv.type_nested)
    PRT.model(modelid).input.cv_type_nested = in.cv.type_nested;
end
if isfield(in.cv,'k_nested') && ~isempty(in.cv.k_nested)
    PRT.model(modelid).input.cv_k_nested = in.cv.k_nested;
end
if isfield(in.cv,'nested_param') && ~isempty(in.cv.nested_param)
    PRT.model(modelid).input.nested_param = in.cv.nested_param;
end

PRT.model(modelid).input.operations = in.operations;

% Save PRT.mat
% -------------------------------------------------------------------------
disp('Updating PRT.mat.......>>')
if spm_check_version('MATLAB','7') >= 0
    save(in.fname,'-V7','PRT');
else
    save(in.fname,'-V6','PRT');
end

end
0121 
0122 %% -------------------------------------------------------------------------
0123 % Private Functions
0124 % -------------------------------------------------------------------------
0125 
function [targets, samp_idx, t_all, samp_all, covar, cov_all] = compute_targets(PRT, in)
% Function to compute the classification targets. Also does some error
% checking of the group/subject/modality/condition specification.
%
% The rows of the reference feature set's id_mat (feature set in.fs(1))
% are the candidate samples. Columns used here: 1 = group, 2 = subject
% number, 3 = modality, 4 = condition.
%
% Outputs:
%   targets  - class index of each selected sample (t_all(samp_idx))
%   samp_idx - id_mat rows that were assigned to a class
%   t_all    - n x 1 class index of every id_mat row (0 = not selected)
%   samp_all - id_mat rows of every selected group/subject/modality,
%              whether or not they were assigned to a class
%   covar    - covariate of each selected sample (cov_all(samp_idx))
%   cov_all  - n x 1 covariates; only filled for 'all scans' selections
%              when in.operations contains 5

% Set the reference feature set
fid = prt_init_fs(PRT, in.fs(1));
ID  = PRT.fs(fid).id_mat;
n   = size(ID,1);

% Check the feature sets have the same number of samples (eg for MKL).
% NOTE(review): this loop leaves fid pointing at the last feature set,
% which is the one consulted by the 'all_cond' check below - confirm.
if length(in.fs) > 1
    for f = 1:length(in.fs)
        fid = prt_init_fs(PRT, in.fs(f));
        if size(PRT.fs(fid).id_mat,1) ~= n
            error('prt_model:sizeOfFeatureSetsDiffer',...
                ['Multiple feature sets included, but they have different ',...
                'numbers of samples']);
        end
    end
end

modalities = {PRT.masks(:).mod_name};
groups     = {PRT.group(:).gr_name};

t_all    = zeros(n,1);
samp_all = zeros(n,1);
cov_all  = zeros(n,1);
for c = 1:length(in.class)

    % groups
    for g = 1:length(in.class(c).group)
        gr_name = in.class(c).group(g).gr_name;
        gid = find(strcmpi(gr_name,groups));
        if isempty(gid)
            error('prt_model:groupNotFoundInPRT',...
                ['Group ',gr_name,' not found in PRT.mat']);
        end

        % subjects
        for s = 1:length(in.class(c).group(g).subj)
            sid = in.class(c).group(g).subj(s).num;
            % modalities
            for m = 1:length(in.class(c).group(g).subj(s).modality)
                mod_spec = in.class(c).group(g).subj(s).modality(m);
                mod_name = mod_spec.mod_name;
                mid = find(strcmpi(mod_name,modalities));
                if isempty(mid)
                    error('prt_model:modalityNotFoundInPRT',...
                        ['Modality ',mod_name,' not found in PRT.mat']);
                end

                % every id_mat row of this group/subject/modality
                s_idx_mod = ID(:,1) == gid & ID(:,2) == sid & ID(:,3) == mid;

                if isfield(mod_spec, 'all_scans')
                    % check whether this was included in the feature set
                    % using 'all conditions' (which is invalid)
                    % NOTE(review): m indexes the job's modality list, not
                    % the feature set's - confirm they always line up.
                    if strcmpi(PRT.fs(fid).modality(m).mode,'all_cond')
                        error('prt_model:fsIsAllCondModelisAllScans',...
                            ['''All scans'' selected for subject ',num2str(s),...
                            ', group ',num2str(g), ', modality ', num2str(m),...
                            ' but the feature set was constructed using ',...
                            '''All conditions''. This syntax is invalid. ',...
                            'Please use ''All Conditions'' instead.']);
                    end

                    % otherwise add all scans for each subject
                    t_all(s_idx_mod) = c;
                    if any(ismember(in.operations, 5)) %Get covariates
                        cov_all(s_idx_mod) = PRT.group(gid).subject(sid).modality(mid).covar;
                    end
                else % conditions have been specified
                    % check whether conditions were specified in the design
                    if ~isfield(PRT.group(gid).subject(sid).modality(mid).design,'conds')
                        error('prt_model:conditionsSpecifiedButNoneInDesign',...
                            ['Conditions selected for subject ',num2str(s),...
                            ', class ',num2str(c),', group ',num2str(g), ...
                            ', modality ', num2str(m),' but there are none in the design. ',...
                            'Please use ''All Scans'' or adjust design.']);
                    end
                    conds = {PRT.group(gid).subject(sid).modality(mid).design.conds(:).cond_name};
                    % NOTE(review): covariates (in.operations contains 5) are
                    % not read for condition-based selections, so cov_all
                    % stays 0 for these samples - confirm this is intended.

                    if isfield(mod_spec, 'all_cond')
                        % all conditions
                        for cid = 1:length(conds)
                            t_all(s_idx_mod & ID(:,4) == cid) = c;
                        end
                    else % loop over conditions
                        for cond = 1:length(mod_spec.conds)
                            cond_name = mod_spec.conds(cond).cond_name;
                            cid = find(strcmpi(cond_name,conds));
                            if isempty(cid)
                                error('prt_model:conditionNotFoundInPRT',...
                                    ['Condition ',cond_name,' not found in PRT.mat']);
                            end
                            t_all(s_idx_mod & ID(:,4) == cid) = c;
                        end
                    end
                end

                % Every scan of a selected subject/modality is a candidate
                % for include_allscans. Previously this was only recorded
                % for condition-based selections, so 'all scans'
                % selections were dropped from samp_all.
                samp_all(s_idx_mod) = 1;
            end
        end
    end
end

samp_idx = find(t_all);
samp_all = find(samp_all);
targets  = t_all(samp_idx);
covar    = cov_all(samp_idx);
end
0243 
0244 
function [targets, samp_idx, targ_allscans, covar, cov_all] = compute_target_reg(PRT, in)
% Function to compute the regression targets: one target per selected
% subject, read from PRT.group(g).subject(s).modality(m).rt_subj of the
% subject's first listed modality.
%
% The rows of the reference feature set's id_mat (feature set in.fs(1))
% are the candidate samples. Columns used here: 1 = group, 2 = subject
% number, 3 = modality.
%
% Outputs:
%   targets       - target of each selected sample, in samp_idx order
%   samp_idx      - id_mat row of each selected subject
%   targ_allscans - n x 1, targets placed at samp_idx (0 elsewhere)
%   covar         - covariate of each selected sample (0 unless
%                   in.operations contains 5)
%   cov_all       - n x 1, covar placed at samp_idx (0 elsewhere)

% Set the reference feature set
fid = prt_init_fs(PRT, in.fs(1));
ID  = PRT.fs(fid).id_mat;
n   = size(ID,1);

modalities = {PRT.masks(:).mod_name};
groups     = {PRT.group(:).gr_name};

targ_allscans = zeros(n,1);
cov_all       = zeros(n,1);
samp_idx = [];
targ_g   = [];
covar    = [];
for g = 1:length(in.group)
    gr_name = in.group(g).gr_name;
    gid = find(strcmpi(gr_name,groups));
    if isempty(gid)
        error('prt_model:groupNotFoundInPRT',...
            ['Group ',gr_name,' not found in PRT.mat']);
    end

    nsubj   = length(in.group(g).subj);
    targets = zeros(1,nsubj); % one regression target per subject
    cov_g   = zeros(1,nsubj); % named cov_g to avoid shadowing builtin cov()
    % subjects
    for s = 1:nsubj
        % modalities: every name is validated, but only the first one
        % provides the subject's target and sample
        for m = 1:length(in.group(g).subj(s).modality)
            mod_name = in.group(g).subj(s).modality(m).mod_name;
            mid = find(strcmpi(mod_name,modalities));
            if isempty(mid)
                error('prt_model:modalityNotFoundInPRT',...
                    ['Modality ',mod_name,' not found in PRT.mat']);
            end
            if m ~= 1
                continue
            end

            sid = in.group(g).subj(s).num;
            rt  = PRT.group(gid).subject(sid).modality(mid).rt_subj;
            if isempty(rt)
                error('prt_model:NoRegressionTarget',...
                    ['No regression target found for subject ',num2str(sid),...
                    ' of group ',gr_name,'. Please correct the data specification.']);
            end
            targets(s) = rt;

            % Exactly one sample per subject keeps samp_idx aligned with
            % the targets; a mismatch used to crash later or, when counts
            % cancelled out, silently misassign targets.
            rows = find(ID(:,1) == gid & ID(:,2) == sid & ID(:,3) == mid);
            if numel(rows) ~= 1
                error('prt_model:regressionSampleCountMismatch',...
                    ['Expected exactly one sample for subject ',num2str(sid),...
                    ' of group ',gr_name,', modality ',mod_name,...
                    ' in the feature set, found ',num2str(numel(rows)),...
                    '. Regression uses one target per subject.']);
            end
            samp_idx = [samp_idx; rows]; %#ok<AGROW>

            if any(ismember(in.operations, 5)) %Get covariates
                cov_g(s) = PRT.group(gid).subject(sid).modality(mid).covar;
            end
        end
    end
    targ_g = [targ_g; targets(:)]; %#ok<AGROW>
    covar  = [covar; cov_g(:)];    %#ok<AGROW>
end
targ_allscans(samp_idx) = targ_g;
targets = targ_g;
cov_all(samp_idx) = covar;
end

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005