Home > . > prt_model.m

prt_model

PURPOSE ^

Function to configure and build the PRT.model data structure

SYNOPSIS ^

function PRT = prt_model(PRT,in)

DESCRIPTION ^

 Function to configure and build the PRT.model data structure

 Input:
 ------
   PRT fields:
   model.fs(f).fs_name:     feature set(s) this CV approach is defined for
   model.fs(f).fs_features: feature selection mode ('all' or 'mask')
   model.fs(f).mask_file:   mask for this feature set (fs_features='mask')

   in.fname:      filename for PRT.mat
   in.model_name: name for this cross-validation structure
   in.type:       'classification' or 'regression'
   in.use_kernel: does this model use kernels or features?
   in.operations: operations to apply before prediction

   in.fs(f).fs_name:     feature set(s) this CV approach is defined for

   in.class(c).class_name
   in.class(c).group(g).subj(s).num
   in.class(c).group(g).subj(s).modality(m).mod_name
   EITHER: in.class(c).group(g).subj(s).modality(m).conds(c).cond_name
   OR:     in.class(c).group(g).subj(s).modality(m).all_scans
   OR:     in.class(c).group(g).subj(s).modality(m).all_cond

   in.cv.type:     type of cross-validation ('loso','losgo','lobo','loro';
                   'locbo' and 'custom' are not yet implemented)
   in.cv.mat_file: file specifying CV matrix (if type='custom')

 Output:
 -------

   This function performs the following steps:
      1. populates basic fields in PRT.model(m).input
      2. computes PRT.model(m).input.targets based on in.class(c)...
      3. computes PRT.model(m).input.samp_idx based on targets
      4. computes PRT.model(m).input.cv_mat based on the labels and CV spec
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function PRT = prt_model(PRT,in)
% Function to configure and build the PRT.model data structure
%
% Input:
% ------
%   PRT fields:
%   model.fs(f).fs_name:     feature set(s) this CV approach is defined for
%   model.fs(f).fs_features: feature selection mode ('all' or 'mask')
%   model.fs(f).mask_file:   mask for this feature set (fs_features='mask')
%
%   in.fname:            filename for PRT.mat (the updated PRT is saved here)
%   in.model_name:       name for this cross-validation structure
%   in.type:             'classification' or 'regression'
%   in.use_kernel:       does this model use kernels or features?
%   in.operations:       operations to apply before prediction
%   in.include_allscans: (optional) if true, samp_idx contains every scan of
%                        the selected groups/subjects/modalities, not only
%                        the scans assigned a target (default: false)
%
%   in.fs(f).fs_name:     feature set(s) this CV approach is defined for
%
%   Classification (in.type = 'classification'):
%   in.class(c).class_name
%   in.class(c).group(g).gr_name
%   in.class(c).group(g).subj(s).num
%   in.class(c).group(g).subj(s).modality(m).mod_name
%   EITHER: in.class(c).group(g).subj(s).modality(m).conds(c).cond_name
%   OR:     in.class(c).group(g).subj(s).modality(m).all_scans
%   OR:     in.class(c).group(g).subj(s).modality(m).all_cond
%
%   Regression (any other in.type):
%   in.group(g).gr_name
%   in.group(g).subj(s).num
%   in.group(g).subj(s).modality(1).mod_name  (only one modality is used)
%
%   in.cv.type:     type of cross-validation: 'loso' (leave one subject
%                   out), 'losgo' (leave one subject per group out), 'lobo'
%                   (leave one block out), 'loro' (leave one run out).
%                   'locbo' and 'custom' are not yet implemented and raise
%                   an error.
%   in.cv.mat_file: file specifying CV matrix (if type='custom'; currently
%                   unused because 'custom' is not implemented)
%
% Output:
% -------
%
%   This function performs the following steps:
%      1. populates basic fields in PRT.model(m).input
%      2. computes PRT.model(m).input.targets based on in.class(c) (or
%         in.group(g) for regression)
%      3. computes PRT.model(m).input.samp_idx based on targets
%      4. computes PRT.model(m).input.cv_mat based on the labels and CV spec
%      5. saves the updated PRT structure to in.fname
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by A Marquand
% $Id: prt_model.m 524 2012-05-10 16:06:12Z schrouff $

% Populate basic fields in PRT.mat
% -------------------------------------------------------------------------
[modelid, PRT] = prt_init_model(PRT,in);

% specify model type and feature sets
PRT.model(modelid).input.type = in.type;
if strcmp(in.type,'classification')
    % Assign the whole struct array: element-wise assignment fails when an
    % existing input.class has a different set of fields, and would keep
    % any existing entries beyond numel(in.class).
    PRT.model(modelid).input.class = in.class;
end

for f = 1:length(in.fs)
    fid = prt_init_fs(PRT,in.fs(f));

    % a multi-modality feature set cannot be combined with other feature
    % sets in the same model
    if length(PRT.fs(fid).modality) > 1 && length(in.fs) > 1
        error('prt_model:multipleFeatureSetsAppliedAsSamplesAndAsFeatures',...
            ['Feature set ',in.fs(f).fs_name,' contains multiple modalities ',...
            'and job specifies that multiple feature sets should be ',...
            'supplied to the machine. This usage is not supported.']);
    end

    PRT.model(modelid).input.fs(f).fs_name = in.fs(f).fs_name;
end

% compute targets and samp_idx
% -------------------------------------------------------------------------
if strcmp(in.type,'classification')
    [targets, samp_idx, t_allscans, samp_allscans] = compute_targets(PRT, in);
else
    [targets, samp_idx, t_allscans] = compute_target_reg(PRT, in);
    % Regression already selects every scan of each chosen
    % subject/modality, so the 'all scans' sample set is samp_idx itself.
    % (Previously samp_allscans was undefined here and include_allscans
    % crashed for regression models.)
    samp_allscans = samp_idx;
end
%[afm]
if isfield(in,'include_allscans') && in.include_allscans
    PRT.model(modelid).input.samp_idx = samp_allscans;
    PRT.model(modelid).input.include_allscans = in.include_allscans;
else
    PRT.model(modelid).input.samp_idx = samp_idx;
    PRT.model(modelid).input.include_allscans = false;
end
PRT.model(modelid).input.targets          = targets;
PRT.model(modelid).input.targ_allscans    = t_allscans;

% compute cross-validation matrix and specify operations to apply
% -------------------------------------------------------------------------
% (compute_cv_mat reads input.samp_idx, so it must be set above first)
PRT.model(modelid).input.cv_mat     = compute_cv_mat(PRT,in, modelid);
PRT.model(modelid).input.operations = in.operations;

% Added by Carlton
PRT.model(modelid).input.cv_type = in.cv.type;

% Save PRT.mat
% -------------------------------------------------------------------------
disp('Updating PRT.mat.......>>')
if spm_matlab_version_chk('7') >= 0
    save(in.fname,'-V7','PRT');
else
    save(in.fname,'-V6','PRT');
end

end
0103 
0104 % -------------------------------------------------------------------------
0105 % Private Functions
0106 % -------------------------------------------------------------------------
0107 
function [targets, samp_idx, t_all, samp_all] = compute_targets(PRT, in)
% Function to compute the classification targets. Also does error checking
%
% Inputs:
%   PRT - PRT data structure (uses PRT.fs, PRT.masks and PRT.group)
%   in  - model specification; uses in.fs and in.class (see prt_model)
%
% Outputs:
%   targets  - class index c (position in in.class) of each selected sample
%   samp_idx - row indices into the reference feature set's id_mat of the
%              samples assigned to a class (ascending)
%   t_all    - class index for every row of id_mat (0 = no class assigned)
%   samp_all - row indices of every scan belonging to a selected
%              group/subject/modality, whether or not a class was assigned
%              (ascending)
%
% Raises an error when feature sets differ in number of samples, when a
% group, modality or condition cannot be found in PRT, when conditions are
% selected but the design has none, or when 'all scans' is requested for a
% feature set built with 'all conditions'.

% Set the reference feature set: its id_mat defines the sample ordering
fid = prt_init_fs(PRT, in.fs(1));
ID  = PRT.fs(fid).id_mat;
n   = size(ID,1);

% Check the feature sets have the same number of samples (eg for MKL).
% A separate variable is used so that fid keeps referring to the
% reference feature set for the checks further down.
for f = 2:length(in.fs)
    fid_f = prt_init_fs(PRT, in.fs(f));
    if size(PRT.fs(fid_f).id_mat,1) ~= n
        error('prt_model:sizeOfFeatureSetsDiffer',...
            ['Multiple feature sets included, but they have different ',...
            'numbers of samples']);
    end
end

modalities = {PRT.masks(:).mod_name};
groups     = {PRT.group(:).gr_name};

t_all    = zeros(n,1);
samp_all = false(n,1);
for c = 1:length(in.class)

    % groups
    for g = 1:length(in.class(c).group)
        gr_name = in.class(c).group(g).gr_name;
        gid = find(strcmpi(gr_name,groups));
        if isempty(gid)
            error('prt_model:groupNotFoundInPRT',...
                ['Group ',gr_name,' not found in PRT.mat']);
        end

        % subjects
        for s = 1:length(in.class(c).group(g).subj)
            subj = in.class(c).group(g).subj(s);
            sid  = subj.num;

            % modalities
            for m = 1:length(subj.modality)
                modspec  = subj.modality(m);
                mod_name = modspec.mod_name;
                mid = find(strcmpi(mod_name,modalities));
                if isempty(mid)
                    error('prt_model:modalityNotFoundInPRT',...
                        ['Modality ',mod_name,' not found in PRT.mat']);
                end

                % every row of ID for this group/subject/modality
                s_idx_mod = ID(:,1) == gid & ID(:,2) == sid & ID(:,3) == mid;

                if isfield(modspec, 'all_scans')
                    % check whether this was included in the feature set
                    % using 'all conditions' (which is invalid)
                    % NOTE(review): indexes the feature set's modalities by
                    % the class-spec position m, assuming both list the
                    % modalities in the same order - confirm
                    if strcmpi(PRT.fs(fid).modality(m).mode,'all_cond')
                        error('prt_model:fsIsAllCondModelisAllScans',...
                            ['''All scans'' selected for subject ',num2str(s),...
                            ', group ',num2str(g), ', modality ', num2str(m),...
                            ' but the feature set was constructed using ',...
                            '''All conditions''. This syntax is invalid. ',...
                            'Please use ''All Conditions'' instead.']);
                    end

                    % otherwise add all scans for each subject
                    t_all(s_idx_mod) = c;
                else % conditions have been specified
                    design = PRT.group(gid).subject(sid).modality(mid).design;

                    % check whether conditions were specified in the design
                    % (must happen before design.conds is read)
                    if ~isfield(design,'conds')
                        error('prt_model:conditionsSpecifiedButNoneInDesign',...
                            ['Conditions selected for subject ',num2str(s),...
                            ', class ',num2str(c),', group ',num2str(g), ...
                            ', modality ', num2str(m),' but there are none in the design. ',...
                            'Please use ''All Scans'' or adjust design.']);
                    end
                    conds = {design.conds(:).cond_name};

                    if isfield(modspec, 'all_cond')
                        % all conditions: condition ids are 1..numel(conds)
                        for cid = 1:length(conds)
                            t_all(s_idx_mod & ID(:,4) == cid) = c;
                        end
                    else % loop over the named conditions
                        for cond = 1:length(modspec.conds)
                            cond_name = modspec.conds(cond).cond_name;
                            cid = find(strcmpi(cond_name,conds));
                            if isempty(cid)
                                error('prt_model:conditionNotFoundInPRT',...
                                    ['Condition ',cond_name,' not found in PRT.mat']);
                            end
                            t_all(s_idx_mod & ID(:,4) == cid) = c;
                        end
                    end
                end

                % every scan of this subject/modality belongs to the
                % 'all scans' sample set, whichever selection mode was used
                samp_all(s_idx_mod) = true;
            end
        end
    end
end

samp_idx = find(t_all);
samp_all = find(samp_all);
targets  = t_all(samp_idx);

end
0220 
function CV = compute_cv_mat(PRT, in, modelid)
% Function to compute the cross-validation matrix. Also does error checking
%
% Inputs:
%   PRT     - PRT data structure; PRT.model(modelid).input.samp_idx must
%             already be set
%   in      - model specification; uses in.fs(1), in.cv.type and the
%             optional in.include_allscans
%   modelid - index of the model in PRT.model
%
% Output:
%   CV - matrix with one row per sample (row of ID below) and one column
%        per fold; CV(i,k) is 2 if sample i is left out in fold k and 1
%        otherwise.
%
% Supported in.cv.type values: 'loso', 'losgo', 'lobo', 'loro'. 'locbo'
% and 'custom' are not implemented and raise an error, as does any
% unknown type.

fid = prt_init_fs(PRT, in.fs(1));

if isfield(in,'include_allscans') && in.include_allscans
    % use the full id matrix
    % NOTE(review): the full id_mat can have more rows than
    % PRT.model(modelid).input.samp_idx - confirm callers expect CV rows
    % to match id_mat rather than samp_idx in this mode
    ID = PRT.fs(fid).id_mat;
else
    % id matrix only contains samples within the CV structure
    % it is initialised in prt_init_fs. The columns contents are described
    % in PRT.fs(fid).id_col_names
    % ('group','subject','modality','condition','block','scan')
    ID = PRT.fs(fid).id_mat(PRT.model(modelid).input.samp_idx,:);
end

% Fold membership is taken from each sample's own ID values, so the result
% does not depend on the rows of ID being sorted by fold (samp_idx is not
% sorted for regression models).
switch in.cv.type
    case 'loso'
        % leave-one-subject-out
        % subject numbers restart in each group, so a fold is one
        % (group, subject) pair
        [subjects, first_idx, fold] = unique(ID(:,1:2), 'rows');
        if size(subjects,1) == 1
            error('prt_model:losoSelectedWithOneSubject',...
                'LOSO CV selected but only one subject is included');
        end
        CV = fold_membership_cv(fold, size(subjects,1));

    case 'losgo'
        % leave-one-subject-per-group-out: one fold per subject number,
        % leaving out that subject number in every group at once
        [sids, first_idx, fold] = unique(ID(:,2));
        if length(sids) == 1
            error('prt_model:losgoSelectedWithOneSubject',...
                'LOSGO CV selected but only one subject is included');
        end
        CV = fold_membership_cv(fold, length(sids));

    case 'lobo'
        % leave-one-block-out - limited to one single subject for the
        % moment
        % blocks already have a unique ID
        [bids, first_idx, fold] = unique(ID(:,5));
        CV = fold_membership_cv(fold, length(bids));

    case 'locbo'
        % leave-one-condition-per-block-out
        error('prt_model:locboNotImplemented',...
            'leave-one-condition-per-block-out not yet implemented');

    case 'loro'
        % leave-one-run-out: one fold per value of the modality column
        [mids, first_idx, fold] = unique(ID(:,3));
        CV = fold_membership_cv(fold, length(mids));

    case 'custom'
        error('prt_model:customCVNotImplemented',...
            'custom CV not implemented yet');

    otherwise
        % report the CV type (not the model type); the original used
        % in.type', whose transpose made the concatenation itself fail
        error('prt_cv:unknownTypeSpecified',...
            ['Unknown type specified for CV structure (',in.cv.type,')']);
end

end

function CV = fold_membership_cv(fold, nfolds)
% Build a CV matrix from per-sample fold indices: CV(i,k) is 2 when
% fold(i) == k (sample i is left out in fold k) and 1 otherwise.
CV = double(bsxfun(@eq, fold(:), 1:nfolds)) + 1;
end
0311 
function [targets, samp_idx, targ_allscans] = compute_target_reg(PRT, in)
% Function to compute the regression targets. Not much error checking yet
%
% Inputs:
%   PRT - PRT data structure (uses PRT.fs, PRT.masks and PRT.group)
%   in  - model specification; uses in.fs(1) and
%         in.group(g).gr_name, in.group(g).subj(s).num and
%         in.group(g).subj(s).modality(1).mod_name
%
% Outputs:
%   targets       - target of each selected sample, read from
%                   PRT.group(gid).subject(sid).modality(mid).rt_subj and
%                   repeated for every scan of that subject/modality, so it
%                   is aligned element-by-element with samp_idx
%   samp_idx      - row indices into the reference feature set's id_mat, in
%                   the order groups and subjects appear in in.group (not
%                   necessarily ascending)
%   targ_allscans - target for every row of id_mat (0 where not selected)
%
% Raises an error when a group or modality cannot be found in PRT.

% Set the reference feature set
fid = prt_init_fs(PRT, in.fs(1));
ID  = PRT.fs(fid).id_mat;

modalities = {PRT.masks(:).mod_name};
groups     = {PRT.group(:).gr_name};

targ_allscans = zeros(size(ID,1),1);
samp_idx = [];
targets  = [];
for g = 1:length(in.group)
    gr_name = in.group(g).gr_name;
    gid = find(strcmpi(gr_name,groups));
    if isempty(gid)
        error('prt_model:groupNotFoundInPRT',...
            ['Group ',gr_name,' not found in PRT.mat']);
    end

    % subjects
    for s = 1:length(in.group(g).subj)
        % modalities, currently only the first one is used
        mod_name = in.group(g).subj(s).modality(1).mod_name;
        mid = find(strcmpi(mod_name,modalities));
        if isempty(mid)
            error('prt_model:modalityNotFoundInPRT',...
                ['Modality ',mod_name,' not found in PRT.mat']);
        end

        sid  = in.group(g).subj(s).num;
        sidx = find(ID(:,1) == gid & ID(:,2) == sid & ID(:,3) == mid);
        rt   = PRT.group(gid).subject(sid).modality(mid).rt_subj;

        % One target per sample: the original stored one target per
        % subject, which misaligned with samp_idx (and made the final
        % assignment fail) whenever a subject had more than one scan.
        % A subject with no rows in id_mat contributes no samples.
        samp_idx = [samp_idx; sidx];
        targets  = [targets; repmat(rt, numel(sidx), 1)];
    end
end
targ_allscans(samp_idx) = targets;
end

Generated on Sun 20-May-2012 13:24:48 by m2html © 2005