prt_machine

PURPOSE

Run machine function for classification or regression

SYNOPSIS

function output = prt_machine(d,m)

DESCRIPTION

 Run machine function for classification or regression
 FORMAT output = prt_machine(d,m)
 Inputs:
   d            - structure with information about the data, with fields:
    Mandatory fields:
    .train      - training data (cell array of matrices, each [Ntr x D],
                  one example per row). Each matrix contains one
                  representation of the data, which is useful for
                  approaches such as multiple kernel learning. Only a
                  single cell is currently supported.
    .test       - testing data (cell array of matrices, each [Nte x D],
                  one example per row)
    .tr_targets - training labels (for classification) or values (for
                  regression) (column vector, [Ntr x 1])
    .use_kernel - flag, is data in form of kernel matrices (true) or in
                  form of features (false)
    .pred_type  - prediction type (string). If 'regression', m.function
                  must be one of the supported regression machines
    Optional fields: the machine is responsible for dealing with these
                  optional fields (e.g. d.testcov)
   m            - structure with information about the classification or
                  regression machine to use, with fields:
      .function - function for classification or regression (string)
      .args     - function arguments (either a string, a matrix, or a
                  struct). This is specific to each machine, e.g. for
                  an L2-norm linear SVM this could be the C parameter
 Output:
    output      - output of machine (struct).
       Mandatory fields:
       .predictions - predictions of classification or regression
                      [Nte x D]
       Optional fields: the machine is responsible for returning
       parameters of interest. For example, for an SVM this could be the
       number of support vectors used in the hyperplane weights computation
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory
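
The input layout described above can be sketched as follows for kernel-form data. This is a minimal illustration, not code from the toolbox: the toy matrices, the linear kernel, and the value given to m.args are all assumptions, and the call to prt_machine is left commented out because it needs the toolbox on the Matlab path.

```matlab
% Toy feature matrices (one example per row); sizes are arbitrary.
Ntr = 6; Nte = 3; D = 4;
Xtr = reshape(1:Ntr*D, Ntr, D) / 10;   % [Ntr x D] training features
Xte = reshape(1:Nte*D, Nte, D) / 5;    % [Nte x D] testing features

% Linear kernels (an assumption): train is [Ntr x Ntr], test is [Nte x Ntr].
Ktr = Xtr * Xtr';
Kte = Xte * Xtr';

% Data structure with the fields that prt_machine checks for.
d = struct();
d.train      = {Ktr};              % a single cell (no multi-kernel learning)
d.test       = {Kte};
d.tr_targets = (1:Ntr)' / Ntr;     % [Ntr x 1] regression targets
d.use_kernel = true;
d.pred_type  = 'regression';

% Machine structure; the args value is a placeholder, not a recommendation.
m = struct('function', 'prt_machine_krr', 'args', 1);

% Shape rules that apply when use_kernel is true.
[Ntrain, Dtrain] = size(d.train{1});
[Ntest,  Dtest]  = size(d.test{1});
shapes_ok = (Dtrain == Dtest) && (Ntrain == Dtrain) && ...
            (numel(d.tr_targets) == Ntrain);

% With the toolbox on the path, the call would be:
% output = prt_machine(d, m);
```

With feature-form data (use_kernel set to false), only the column counts of the train and test matrices and the number of targets need to agree.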

SUBFUNCTIONS

function out = prt_ismatrix(A)

SOURCE CODE

function output = prt_machine(d,m)
% Run machine function for classification or regression
% FORMAT output = prt_machine(d,m)
% Inputs:
%   d            - structure with information about the data, with fields:
%    Mandatory fields:
%    .train      - training data (cell array of matrices, each [Ntr x D],
%                  one example per row). Each matrix contains one
%                  representation of the data, which is useful for
%                  approaches such as multiple kernel learning. Only a
%                  single cell is currently supported.
%    .test       - testing data (cell array of matrices, each [Nte x D],
%                  one example per row)
%    .tr_targets - training labels (for classification) or values (for
%                  regression) (column vector, [Ntr x 1])
%    .use_kernel - flag, is data in form of kernel matrices (true) or in
%                  form of features (false)
%    .pred_type  - prediction type (string). If 'regression', m.function
%                  must be one of the supported regression machines
%    Optional fields: the machine is responsible for dealing with these
%                  optional fields (e.g. d.testcov)
%   m            - structure with information about the classification or
%                  regression machine to use, with fields:
%      .function - function for classification or regression (string)
%      .args     - function arguments (either a string, a matrix, or a
%                  struct). This is specific to each machine, e.g. for
%                  an L2-norm linear SVM this could be the C parameter
% Output:
%    output      - output of machine (struct).
%       Mandatory fields:
%       .predictions - predictions of classification or regression
%                      [Nte x D]
%       Optional fields: the machine is responsible for returning
%       parameters of interest. For example, for an SVM this could be the
%       number of support vectors used in the hyperplane weights computation
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by M.J.Rosa and J.Richiardi
% $Id: prt_machine.m 498 2012-04-05 13:26:23Z amarquan $

% TODO: make tr_targets a cell array (?)
% TODO: fix 80-cols limit in source code
% TODO: Multi-kernel learning

SANITYCHECK = true; % can turn off for "speed"

%% INPUT CHECKS
%--------------------------------------------------------------------------
if SANITYCHECK==true
    % Check machine struct properties
    if ~isempty(m)
        if isstruct(m)
            if isfield(m,'function')
                % TODO: This case maybe needs more cautious handling
                if ~exist(m.function,'file')
                    error('prt_machine:machineFunctionFileNotFound',...
                        'Error: %s function could not be found!',...
                        m.function);
                end
            else
                error('prt_machine:machineFunctionFieldNotFound',...
                    ['Error: machine structure should contain'...
                    ' ''.function'' field!']);
            end
            if ~isfield(m,'args')
                error('prt_machine:argsFieldNotFound',...
                    ['Error: machine structure should contain' ...
                    ' ''.args'' field!']);
            end
        else
            error('prt_machine:machineNotStruct',...
                'Error: machine should be a structure!');
        end
    else
        error('prt_machine:machineStructEmpty',...
            'Error: ''machine'' struct cannot be empty!');
    end
    
    %----------------------------------------------------------------------
    % Check data struct properties
    if ~isempty(d)
        % 1: BASIC: check all mandatory fields exist so we can relax later
        if ~isfield(d,'train')
            error('prt_machine:missingField_train',...
                ['Error: ''data'' struct must contain a ''train'' '...
                'field!']);
        end
        if ~isfield(d,'test')
            error('prt_machine:missingField_test',...
                ['Error: ''data'' struct must contain a ''test'' '...
                'field!']);
        end
        if ~isfield(d,'tr_targets')
            error('prt_machine:missingField_tr_targets',...
                ['Error: ''data'' struct must contain a ''tr_targets'' '...
                'field!']);
        end
        if ~isfield(d,'use_kernel')
            error('prt_machine:missingField_use_kernel',...
                ['Error: ''data'' struct must contain a ''use_kernel'' '...
                'field!']);
        end
        if ~isfield(d,'pred_type')
            error('prt_machine:missingField_pred_type',...
                ['Error: ''data'' struct must contain a ''pred_type'' '...
                'field!']);
        end
        
        
        % 2: BASIC: check datatype of train/test sets
        if isempty(d.train) || isempty(d.test)
            error('prt_machine:TrAndTeEmpty',...
                'Error: training and testing data cannot be empty!');
        else
            if ~iscell(d.train) || ~iscell(d.test)
                error('prt_machine:TrAndTeEmpty',...
                    'Error: training and testing data should be cell arrays!');
            end
        end
        
        % 3: BASIC: check datatypes of labels
        if ~isempty(d.tr_targets)
            if isvector(d.tr_targets)
                % force targets to column vectors
                d.tr_targets   = d.tr_targets(:);
                Ntrain_lbs = length(d.tr_targets);
            else
                error('prt_machine:trainingLabelsNotVector',...
                    'Error: training labels should be a vector!');
            end
        else
            error('prt_machine:trainingLabelsEmpty',...
                'Error: training labels cannot be empty!');
        end
        
        % 4: Check data properties (over cells)
        Nk_train   = length(d.train);
        
        % 5: Check if data has more than one cell
        if Nk_train > 1
            error('prt_machine:MKLnotSupported',...
                'Error: Multi-kernel learning not supported yet!');
        end
        
        % 6: Check validity of the machine chosen (e.g. using an SVM to do
        % regression is not valid)
        if strcmp(d.pred_type,'regression')
            if ~any(strcmp(m.function,{'prt_machine_krr','prt_machine_rvr',...
                                       'prt_machine_gpml','prt_machine_gpr'}))
                error('prt_machine:RgressionMachineSupport',...
                    ['Error: Regression can only use KRR, RVR, GPML or '...
                    'GPR machines']);
            end
        end
        
        % 7: Check datasets properties (within cells)
        for k = 1:Nk_train
            if ~isempty(d.train{k}) && ~isempty(d.test{k})
                if (~prt_ismatrix(d.train{k}) && ~isvector(d.train{k}) ) || ...
                        (~prt_ismatrix(d.test{k}) && ~isvector(d.test{k}) )
                    error('prt_machine:TrAndTeNotMatrices',...
                        ['Error: training and testing datasets should ' ...
                        'be either matrices or vectors!']);
                end
            else
                error('prt_machine:TrAndTeEmpty',...
                    'Error: training and testing datasets cannot be empty!');
            end
            % check dimensions
            [Ntrain, Dtrain] = size(d.train{k});
            [Ntest, Dtest]   = size(d.test{k});
            % a: feature space dimension should be equal
            if ~(Dtrain==Dtest)
                error('prt_machine:DtrNotEqDte',['Error: Training and testing '...
                    'dimensions should match, but Dtrain=%d and Dtest=%d for '...
                    'dataset %d!'],Dtrain,Dtest,k);
            end
            % b: check we have as many training labels as examples
            if ~(Ntrain_lbs==Ntrain)
                error('prt_machine:NtrlbsNotEqNtr',['Error: Number of training '...
                    'examples and training labels should match, but Ntrain_lbs=%d '...
                    'and Ntrain=%d for dataset %d!'],Ntrain_lbs,Ntrain,k);
            end
            % c: if kernel check for kernel properties
            if d.use_kernel
                if ~(Ntrain==Dtrain)
                    error('prt_machine:NtrainNotEqDtrain',['Error: Training '...
                        'dimensions should match, but Ntr=%d and Dtr=%d for '...
                        'dataset %d!'],Ntrain,Dtrain,k);
                end
                if ~(Dtest==Ntrain)
                    error('prt_machine:DtestNotEqNtrain',['Error: Testing '...
                        'dimensions should match, but Dte=%d and Ntr=%d for '...
                        'dataset %d!'],Dtest,Ntrain,k);
                end
            end
        end
    else
        error('prt_machine:dataStructEmpty',...
            'Error: data struct cannot be empty!');
    end
end % SANITYCHECK

%% Run model
%--------------------------------------------------------------------------
fnch   = str2func(m.function);

try
    output = fnch(d,m.args);
catch
    err = lasterror;
    err_ID=lower(err.identifier);
    err_libProblem = strfind(err_ID,'libnotfound');
    err_argsProblem = strfind(err_ID,'argsproblem');
    disp('prt_machine: machine did not run successfully.');
    if ~isempty(err_libProblem)
        error('prt_machine:libNotFound',['Error: the library for '...
            'machine %s could not be found on your path.'],m.function);
    elseif ~isempty(err_argsProblem)
        disp(['Error: the arguments supplied '...
            'are invalid. ' ...
            'SOLUTION: Please follow the advice given by the machine.']);
        error('prt_machine:argsProblem',...
            'Error running machine %s: %s %s', ...
            m.function,err.identifier,err.message);
    else
        % we don't know what more to do here, pass it up
        disp(['SOLUTION: Please read the message below and attempt to' ...
            ' correct the problem, or ask the developers for ' ...
            'assistance by copy-pasting all messages and explaining the'...
            ' exact steps that led to the problem.']);
        disp(['These kinds of issues are typically caused by Matlab '...
            'path problems.']);
        for en=numel(err.stack):-1:1
            e=err.stack(en);
            fprintf('%d : function [%s] in file [%s] at line [%d]\n',...
                en,e.name,e.file,e.line);
        end
        error('prt_machine:otherProblem',...
            'Error running machine %s: %s %s', ...
            m.function,err.identifier,err.message);
    end
end

%% OUTPUT CHECKS
%--------------------------------------------------------------------------
if SANITYCHECK==true
    
    % Check output properties
    if ~isfield(output,'predictions')
        error('prt_machine:outputNoPredictions',['Output of machine should '...
            'contain the field ''.predictions''.']);
    else
        % FIXME: multiple kernels / feature representations is unsupported
        % here
        % [afm] removed to test glm approach
        %if (size(output.predictions,1)~= Ntest)
        %    error('prt_machine:outputNpredictionsNotEqNte',['Error: Number '...
        %        'of predictions output and number of test examples should '...
        %        'match, but Npre=%d and Nte=%d !'],...
        %        size(output.predictions,1),Ntest);
        %end
    end
    
end % SANITYCHECK on output

end

%% local functions
function out = prt_ismatrix(A)
% ismatrix was not a built-in in Matlab 7.1, so do a homebrew
% implementation (based on Dan Vimont's Matlab libs at
% http://www.aos.wisc.edu/~dvimont/matlab but with short-circuit AND for
% "speed")
out=(ndims(A)==2) && (min(size(A)) ~= 1); % enable stricter check - a struct array should NOT pass.
end
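
The contract between prt_machine and the machine it dispatches to can be sketched as below. Everything named example_machine, and the shrink argument, is hypothetical and used only for illustration. The machine is written as an anonymous function so that the snippet runs on its own. A real machine would have to be saved as a function file on the path, because prt_machine tests exist(m.function,'file'). With pred_type set to 'regression' it would also have to be one of the four machine names accepted by check 6.

```matlab
% Feature-form inputs (use_kernel = false); all values are arbitrary.
d = struct();
d.train      = {[1 2; 3 4; 5 6]};   % [Ntr x D]
d.test       = {[7 8; 9 10]};       % [Nte x D]
d.tr_targets = [1; 2; 3];           % [Ntr x 1]
d.use_kernel = false;
d.pred_type  = 'regression';

% Hypothetical machine: predicts the shrunken mean of the training targets
% for every test example and also returns the unshrunken mean.
example_machine = @(d, args) struct( ...
    'predictions', args.shrink * mean(d.tr_targets) * ...
                   ones(size(d.test{1}, 1), 1), ...
    'train_mean',  mean(d.tr_targets));

% Dispatch in the same way as the "Run model" step: fnch(d, m.args).
m_args = struct('shrink', 0.5);
output = example_machine(d, m_args);

% Mandatory-field check applied to the machine output.
has_predictions = isfield(output, 'predictions');

% Identifier convention: after a failed call, prt_machine searches the
% lower-cased error identifier for 'argsproblem' (or 'libnotfound').
try
    error('example_machine:argsProblem', 'shrink must be >= 0, got %d.', -1);
catch err
    id_lower = lower(err.identifier);
end
is_args_problem = ~isempty(strfind(id_lower, 'argsproblem'));
```

A machine that raises its errors with identifiers ending in argsProblem or libNotFound lets prt_machine report the more specific message instead of falling through to the generic otherProblem branch.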

Generated on Sun 20-May-2012 13:24:48 by m2html © 2005