prt_machine

PURPOSE

Run machine function for classification or regression

SYNOPSIS

function output = prt_machine(d,m)

DESCRIPTION

 Run machine function for classification or regression
 FORMAT output = prt_machine(d,m)
 Inputs:
   d            - structure with information about the data, with fields:
    Mandatory fields:
    .train      - training data (cell array of matrices of row vectors,
                  each [Ntr x D]). Each matrix contains one representation
                  of the data. This is useful for approaches such as
                  multiple kernel learning.
    .test       - testing data (cell array of matrices of row vectors, each
                  [Nte x D])
    .tr_targets - training labels (for classification) or values (for
                  regression) (column vector, [Ntr x 1])
    .use_kernel - flag: data is in the form of kernel matrices (true) or
                  features (false)
    .pred_type  - prediction type (string), e.g. 'regression'; required
                  by the input checks
    Optional fields: the machine is responsible for dealing with these
                  optional fields (e.g. d.testcov)
   m            - structure with information about the classification or
                  regression machine to use, with fields:
      .function - function for classification or regression (string)
      .args     - function arguments (either a string, a matrix, or a
                  struct). This is specific to each machine, e.g. for
                  an L2-norm linear SVM this could be the C parameter
 Output:
    output      - output of machine (struct).
       Mandatory fields:
       .predictions - predictions of classification or regression
                      [Nte x D]
       Optional fields: the machine is responsible for returning
       parameters of interest. For example, for an SVM this could be the
       number of support vectors used in the hyperplane weights computation
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory
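
A minimal sketch of the inputs described above, assuming precomputed linear kernels on toy data. The data values are made up for illustration, and `prt_machine_krr` is used only because it is one of the machine names accepted by the regression check in the source; the value given to `m.args` is an assumption, since the argument format is specific to each machine.

```matlab
% Toy data: 6 training and 2 test examples with 3 features each.
Xtr = [1 0 2; 0 1 1; 2 2 0; 1 1 1; 0 2 2; 3 1 0];
Xte = [1 2 1; 2 0 1];

d = struct();
d.train      = {Xtr * Xtr'};        % [Ntr x Ntr] training kernel
d.test       = {Xte * Xtr'};        % [Nte x Ntr] test-vs-train kernel
d.tr_targets = [1; 2; 3; 2; 3; 4];  % [Ntr x 1] regression targets
d.use_kernel = true;                % data are kernels, not features
d.pred_type  = 'regression';        % required by the input checks

m = struct();
m.function = 'prt_machine_krr';     % must be a function file on the path
m.args     = 1;                     % ASSUMPTION: machine-specific argument

% Shapes required by the input checks when d.use_kernel is true
assert(isequal(size(d.train{1}), [6 6]));
assert(size(d.test{1}, 2) == size(d.train{1}, 1));
assert(numel(d.tr_targets) == size(d.train{1}, 1));

% Needs the toolbox that provides prt_machine on the MATLAB path:
% output = prt_machine(d, m);
% output.predictions               % predictions returned by the machine
```

Assigning the cell arrays field by field (rather than through `struct('train', {...})`) avoids MATLAB's behaviour of expanding cell-array arguments of `struct` into a struct array.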

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SUBFUNCTIONS

SOURCE CODE

function output = prt_machine(d,m)
% Run machine function for classification or regression
% FORMAT output = prt_machine(d,m)
% Inputs:
%   d            - structure with information about the data, with fields:
%    Mandatory fields:
%    .train      - training data (cell array of matrices of row vectors,
%                  each [Ntr x D]). Each matrix contains one representation
%                  of the data. This is useful for approaches such as
%                  multiple kernel learning.
%    .test       - testing data (cell array of matrices of row vectors, each
%                  [Nte x D])
%    .tr_targets - training labels (for classification) or values (for
%                  regression) (column vector, [Ntr x 1])
%    .use_kernel - flag: data is in the form of kernel matrices (true) or
%                  features (false)
%    .pred_type  - prediction type (string), e.g. 'regression'; required
%                  by the input checks
%    Optional fields: the machine is responsible for dealing with these
%                  optional fields (e.g. d.testcov)
%   m            - structure with information about the classification or
%                  regression machine to use, with fields:
%      .function - function for classification or regression (string)
%      .args     - function arguments (either a string, a matrix, or a
%                  struct). This is specific to each machine, e.g. for
%                  an L2-norm linear SVM this could be the C parameter
% Output:
%    output      - output of machine (struct).
%       Mandatory fields:
%       .predictions - predictions of classification or regression
%                      [Nte x D]
%       Optional fields: the machine is responsible for returning
%       parameters of interest. For example, for an SVM this could be the
%       number of support vectors used in the hyperplane weights computation
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by M.J.Rosa and J.Richiardi
% $Id$

% TODO: make tr_targets a cell array (?)
% TODO: fix 80-cols limit in source code
% TODO: Multi-kernel learning

SANITYCHECK = true; % can turn off for "speed"

%% INPUT CHECKS
%--------------------------------------------------------------------------
if SANITYCHECK
    % Check machine struct properties
    if ~isempty(m)
        if isstruct(m)
            if isfield(m,'function')
                % TODO: This case maybe needs more cautious handling
                if ~exist(m.function,'file')
                    error('prt_machine:machineFunctionFileNotFound',...
                        'Error: %s function could not be found!',...
                        m.function);
                end
            else
                error('prt_machine:machineFunctionFieldNotFound',...
                    ['Error: machine structure should contain'...
                    ' ''.function'' field!']);
            end
            if ~isfield(m,'args')
                error('prt_machine:argsFieldNotFound',...
                    ['Error: machine structure should contain' ...
                    ' ''.args'' field!']);
            end
        else
            error('prt_machine:machineNotStruct',...
                'Error: machine should be a structure!');
        end
    else
        error('prt_machine:machineStructEmpty',...
            'Error: ''machine'' struct cannot be empty!');
    end
end

% Build the function handle only after the machine struct has been
% checked, so that a malformed m triggers the errors above instead of an
% indexing error here
fnch   = str2func(m.function);

if SANITYCHECK
    %----------------------------------------------------------------------
    % Check data struct properties
    if ~isempty(d)
        % 1: BASIC: check all mandatory fields exist so we can relax later
        if ~isfield(d,'train')
            error('prt_machine:missingField_train',...
                ['Error: ''data'' struct must contain a ''train'' '...
                'field!']);
        end
        if ~isfield(d,'test')
            error('prt_machine:missingField_test',...
                ['Error: ''data'' struct must contain a ''test'' '...
                'field!']);
        end
        if ~isfield(d,'tr_targets')
            error('prt_machine:missingField_tr_targets',...
                ['Error: ''data'' struct must contain a ''tr_targets'' '...
                'field!']);
        end
        if ~isfield(d,'use_kernel')
            error('prt_machine:missingField_use_kernel',...
                ['Error: ''data'' struct must contain a ''use_kernel'' '...
                'field!']);
        end
        if ~isfield(d,'pred_type')
            error('prt_machine:missingField_pred_type',...
                ['Error: ''data'' struct must contain a ''pred_type'' '...
                'field!']);
        end

        % 2: BASIC: check datatype of train/test sets
        if isempty(d.train) || isempty(d.test)
            error('prt_machine:TrAndTeEmpty',...
                'Error: training and testing data cannot be empty!');
        else
            if ~iscell(d.train) || ~iscell(d.test)
                error('prt_machine:TrAndTeEmpty',...
                    'Error: training and testing data should be cell arrays!');
            end
        end

        % 3: BASIC: check datatypes of labels
        if ~isempty(d.tr_targets)
            if isvector(d.tr_targets)
                % force targets to column vectors
                d.tr_targets   = d.tr_targets(:);
                Ntrain_lbs = length(d.tr_targets);
            else
                error('prt_machine:trainingLabelsNotVector',...
                    'Error: training labels should be a vector!');
            end
        else
            error('prt_machine:trainingLabelsEmpty',...
                'Error: training labels cannot be empty!');
        end

        % 4: Check data properties (over cells)
        Nk_train   = length(d.train);

        % 5: Check if data has more than one cell
        if isempty(strfind(char(fnch),'MKL')) && Nk_train > 1
            % Multiple kernels were supplied but the machine is not an MKL
            % machine, so add the kernels together
            has_testcov = isfield(d,'testcov'); % testcov is optional
            tr_tmp = zeros(size(d.train{1}));
            te_tmp = zeros(size(d.test{1}));
            if has_testcov
                tecov_tmp = zeros(size(d.testcov{1}));
            end
            for j=1:Nk_train
                try
                    % add kernels
                    tr_tmp = tr_tmp + d.train{j}; % train set
                    te_tmp = te_tmp + d.test{j};  % test set
                    if has_testcov
                        % test set covariance matrix for GP
                        tecov_tmp = tecov_tmp + d.testcov{j};
                    end
                catch
                    error('prt_cv_model:KernelsWithDifferentDimensions', ...
                        'Kernels cannot be added since they have different dimensions')
                end
            end
            d.train = {tr_tmp};
            d.test = {te_tmp};
            if has_testcov
                d.testcov = {tecov_tmp};
            end
            Nk_train = 1;
            clear tr_tmp te_tmp tecov_tmp
        end

        % 6: Check validity of the chosen machine (e.g. using an SVM for
        % regression is not valid)
        if strcmp(d.pred_type,'regression')
            if ~any(strcmp(m.function,{'prt_machine_krr','prt_machine_rvr',...
                    'prt_machine_gpml','prt_machine_gpr',...
                    'prt_machine_sMKL_reg'}))
                error('prt_machine:RgressionMachineSupport',...
                    ['Error: regression requires one of the KRR, RVR, '...
                    'GPML, GPR or sMKL_reg machines!']);
            end
        end

        % 7: Check datasets properties (within cells)
        for k = 1:Nk_train
            if ~isempty(d.train{k}) && ~isempty(d.test{k})
                if (~prt_ismatrix(d.train{k}) && ~isvector(d.train{k}) ) || ...
                        (~prt_ismatrix(d.test{k}) && ~isvector(d.test{k}) )
                    error('prt_machine:TrAndTeNotMatrices',...
                        ['Error: training and testing datasets should ' ...
                        'be either matrices or vectors!']);
                end
            else
                error('prt_machine:TrAndTeEmpty',...
                    'Error: training and testing datasets cannot be empty!');
            end
            % check dimensions
            [Ntrain, Dtrain] = size(d.train{k});
            [Ntest, Dtest]  = size(d.test{k});
            % a: feature space dimension should be equal
            if ~(Dtrain==Dtest)
                error('prt_machine:DtrNotEqDte',['Error: Training and testing '...
                    'dimensions should match, but Dtrain=%d and Dtest=%d for '...
                    'dataset %d!'],Dtrain,Dtest,k);
            end
            % b: check we have as many training labels as examples
            if ~(Ntrain_lbs==Ntrain)
                error('prt_machine:NtrlbsNotEqNtr',['Error: Number of training '...
                    'examples and training labels should match, but Ntrain_lbs=%d '...
                    'and Ntrain=%d for dataset %d!'],Ntrain_lbs,Ntrain,k);
            end
            % c: if kernel check for kernel properties
            if d.use_kernel
                % training kernel must be square, [Ntr x Ntr]
                if ~(Ntrain==Dtrain)
                    error('prt_machine:NtrainNotEqDtrain',['Error: Training '...
                        'dimensions should match, but Ntr=%d and Dtr=%d for '...
                        'dataset %d!'],Ntrain,Dtrain,k);
                end
                % test kernel must be [Nte x Ntr]
                if ~(Dtest==Ntrain)
                    error('prt_machine:DtestNotEqNtrain',['Error: Testing '...
                        'dimensions should match, but Dte=%d and Ntr=%d for '...
                        'dataset %d!'],Dtest,Ntrain,k);
                end
            end
        end
    else
        error('prt_machine:dataStructEmpty',...
            'Error: data struct cannot be empty!');
    end
end % SANITYCHECK

%% Run model
%--------------------------------------------------------------------------
try
    output = fnch(d,m.args);
catch err
    % err is an MException; dispatch on its identifier (case-insensitive)
    err_ID=lower(err.identifier);
    err_libProblem = strfind(err_ID,'libnotfound');
    err_argsProblem = strfind(err_ID,'argsproblem');
    disp('prt_machine: machine did not run successfully.');
    if ~isempty(err_libProblem)
        error('prt_machine:libNotFound',['Error: the library for '...
            'machine %s could not be found on your path.'],m.function);
    elseif ~isempty(err_argsProblem)
        disp(['Error: the arguments supplied '...
            'are invalid. ' ...
            'SOLUTION: Please follow the advice given by the machine.']);
        error('prt_machine:argsProblem',...
            'Error running machine %s: %s %s', ...
            m.function,err.identifier,err.message);
    else
        % we don't know what more to do here, pass it up
        disp(['SOLUTION: Please read the message below and attempt to' ...
            ' correct the problem, or ask the developers for ' ...
            'assistance by copy-pasting all messages and explaining the'...
            ' exact steps that led to the problem.']);
        disp(['These kinds of issues are typically caused by Matlab '...
            'path problems.']);
        % print the call stack, outermost caller first
        for en=numel(err.stack):-1:1
            e=err.stack(en);
            fprintf('%d : function [%s] in file [%s] at line [%d]\n',...
                en,e.name,e.file,e.line);
        end
        error('prt_machine:otherProblem',...
            'Error running machine %s: %s %s', ...
            m.function,err.identifier,err.message);
    end
end

%% OUTPUT CHECKS
%--------------------------------------------------------------------------
if SANITYCHECK

    % Check output properties
    if ~isfield(output,'predictions')
        error('prt_machine:outputNoPredictions',['Output of machine should '...
            'contain the field ''.predictions''.']);
    else
        % FIXME: multiple kernels / feature representations is unsupported
        % here
        % [afm] removed to test glm approach
        %if (size(output.predictions,1)~= Ntest)
        %    error('prt_machine:outputNpredictionsNotEqNte',['Error: Number '...
        %        'of predictions output and number of test examples should '...
        %        'match, but Npre=%d and Nte=%d !'],...
        %        size(output.predictions,1),Ntest);
        %end
    end

end % SANITYCHECK on output

end

%% local functions
function out = prt_ismatrix(A)
% ismatrix was not a built-in in Matlab 7.1, so do a homebrew
% implementation (based on Dan Vimont's Matlab libs at
% http://www.aos.wisc.edu/~dvimont/matlab but with short-circuit AND for
% "speed")
% TODO: enable stricter check - a struct array should NOT pass.
out=(ndims(A)==2) && (min(size(A)) ~= 1);
end
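
The kernel-summing step in the source relies on `+` throwing an error when two kernels differ in size. In MATLAB releases with implicit expansion (R2016b and later), some mismatched sizes, such as a matrix plus a row vector, are broadcast silently instead. The sketch below shows one way to make the size check explicit; `sum_kernels` is a hypothetical helper name and is not part of the original function.

```matlab
function K = sum_kernels(Ks)
% SUM_KERNELS  Sum a cell array of equally sized matrices.
% Hypothetical helper, shown for illustration only.
%   Ks - non-empty cell array of numeric matrices
%   K  - element-wise sum of all matrices in Ks
K = zeros(size(Ks{1}));
for j = 1:numel(Ks)
    % Compare sizes explicitly rather than relying on '+' to fail
    if ~isequal(size(Ks{j}), size(K))
        error('sum_kernels:KernelsWithDifferentDimensions', ...
            'Kernel %d has size [%s] but [%s] was expected.', j, ...
            num2str(size(Ks{j})), num2str(size(K)));
    end
    K = K + Ks{j};
end
end
```

For example, `sum_kernels({eye(2), ones(2)})` returns `[2 1; 1 2]`, while `sum_kernels({eye(2), ones(1,2)})` raises the size error instead of broadcasting.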

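The "Run model" section of the source decides how to report a failure by searching the lower-cased error identifier for the substrings `libnotfound` and `argsproblem`. A machine can therefore steer the report through the identifier it raises. The sketch below isolates that dispatch; `classify_machine_error` is a hypothetical helper name, and the returned labels simply mirror the identifiers used in the source.

```matlab
function kind = classify_machine_error(err)
% CLASSIFY_MACHINE_ERROR  Map an MException to a failure category.
% Hypothetical helper, shown for illustration only.
%   err  - MException caught from a machine call
%   kind - 'libNotFound', 'argsProblem' or 'otherProblem'
id = lower(err.identifier);              % match case-insensitively
if ~isempty(strfind(id, 'libnotfound'))
    kind = 'libNotFound';                % machine's library is missing
elseif ~isempty(strfind(id, 'argsproblem'))
    kind = 'argsProblem';                % machine rejected its arguments
else
    kind = 'otherProblem';               % anything else is passed up
end
end
```

A machine that raises `error('my_machine:argsProblem', 'C must be positive')` (identifier and message invented for this example) would be classified as `'argsProblem'`.
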
Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005