prt_machine_RT_bin

PURPOSE

Run a binary ensemble of regression trees - wrapper for Pierre Geurts' RT code

SYNOPSIS

function output = prt_machine_RT_bin(d,args)

DESCRIPTION

 Run a binary ensemble of regression trees - wrapper for Pierre Geurts'
 RT code
 FORMAT output =  prt_machine_RT_bin(d,args)
 Inputs:
   d         - structure with data information, with mandatory fields:
     .train      - training data (cell array of matrices of row vectors,
                   each [Ntr x D]). Each matrix contains one representation
                   of the data. This is useful for approaches such as
                   multiple kernel learning.
     .test       - testing data  (cell array of matrices of row vectors,
                   each [Nte x D])
     .tr_targets - training labels (for classification) or values (for
                   regression) (column vector, [Ntr x 1])
     .use_kernel - flag, is data in form of kernel matrices (true) or in
                   form of features (false)
    args    - vector of RT arguments
       args(1) - number of trees (default: 501)
 Output:
    output  - output of machine (struct).
     * Mandatory fields:
      .predictions - predictions of classification or regression [Nte x D]
     * Optional fields:
      .func_val - value of the decision function
      .type     - which type of machine this is (here, 'classifier')
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory
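
The input structure described above can be assembled as in the following minimal sketch. The data values, the sample counts and the choice of 101 trees are made up for illustration; the call itself is guarded so the snippet only invokes the machine when both prt_machine_RT_bin and the rtenslearn_c binary are found on the path.

```matlab
% Toy two-class problem: 20 training and 5 test samples, 3 features each.
% All data values here are random and purely illustrative.
Ntr = 20; Nte = 5; D = 3;
d = struct();
d.train      = {rand(Ntr, D)};                       % one feature representation
d.test       = {rand(Nte, D)};
d.tr_targets = [ones(Ntr/2, 1); 2*ones(Ntr/2, 1)];   % labels must be in {1,2}
d.use_kernel = false;                                % features, not kernel matrices

% Only call the machine when it and the RT binary are on the path.
if exist('prt_machine_RT_bin', 'file') && exist('rtenslearn_c', 'file')
    output = prt_machine_RT_bin(d, 101);             % args(1): number of trees
    disp(output.predictions');
end
```

Passing an empty args is intended to select the documented default of 501 trees.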

SOURCE CODE

function output = prt_machine_RT_bin(d,args)
% Run a binary ensemble of regression trees - wrapper for Pierre Geurts'
% RT code
% FORMAT output =  prt_machine_RT_bin(d,args)
% Inputs:
%   d         - structure with data information, with mandatory fields:
%     .train      - training data (cell array of matrices of row vectors,
%                   each [Ntr x D]). Each matrix contains one representation
%                   of the data. This is useful for approaches such as
%                   multiple kernel learning.
%     .test       - testing data  (cell array of matrices of row vectors,
%                   each [Nte x D])
%     .tr_targets - training labels (for classification) or values (for
%                   regression) (column vector, [Ntr x 1])
%     .use_kernel - flag, is data in form of kernel matrices (true) or in
%                   form of features (false)
%    args    - vector of RT arguments
%       args(1) - number of trees (default: 501)
% Output:
%    output  - output of machine (struct).
%     * Mandatory fields:
%      .predictions - predictions of classification or regression [Nte x D]
%     * Optional fields:
%      .func_val - value of the decision function
%      .type     - which type of machine this is (here, 'classifier')
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

%--------------------------------------------------------------------------
% Written by J.Richiardi
% $Id: prt_machine_RT_bin.m 401 2011-11-23 10:44:20Z jrichiar $

% FIXME: support for multiple kernels / feature representations
% is not yet tested, there might be transposition or dimensionality errors.

% TODO: this machine supports regression AND classification
% TODO: this machine supports multi-class

SANITYCHECK=true; % can turn off for "speed". Expert only.

% default number of trees. This is set outside the sanity checks so that
% args(1) is defined even when SANITYCHECK is off, and before the vector
% check below because isvector([]) is false.
if isempty(args)
    args(1)=501;
    disp('prt_machine_RT_bin: defaulting to 501 trees');
end

if SANITYCHECK==true
    % args should be a vector
    if ~isvector(args)
        error('prt_machine_RT_bin:RTargsNotVec',['Error: RT'...
            ' args should be a vector. ' ...
            ' SOLUTION: Please check your code. ']);
    end
    
    % check we can reach the binary library
    if ~exist('rtenslearn_c','file')
        error('prt_machine_RT_bin:libNotFound',['Error:'...
            ' RT function rtenslearn_c could not be found !' ...
            ' SOLUTION: Please check your path.']);
    end
    
    % check it is indeed a two-class classification problem
    uTL=unique(d.tr_targets(:)); % unique training labels
    nC=numel(uTL);
    if nC>2
        error('prt_machine_RT_bin:problemNotBinary',['Error:'...
            ' This machine is only for two-class problems but the' ...
            ' current problem has ' num2str(nC) ' ! ' ...
            'SOLUTION: Please select another machine than ' ...
            'prt_machine_RT_bin in XXX']);
    end
    
    % check it is indeed labelled correctly (probably should be done
    % above?). isequal also copes with empty or single-class label sets.
    if ~isequal(uTL,[1 2]')
        error('prt_machine_RT_bin:LabellingIncorect',['Error:'...
            ' This machine needs labels to be in {1,2} ' ...
            ' but they are ' mat2str(uTL) ' ! ' ...
            'SOLUTION: Please relabel your classes by changing the '...
            '''tr_targets'' field passed to prt_machine_RT_bin']);
    end
    
    % check we are not setting a ridiculous number of trees
    if args(1)>10000
        error('prt_machine_RT_bin:argsProblem:maxTrees',['Error:'...
            ' Setting a high number of trees is not supported ' ...
            ' without some modifications of the wrapper code. ' ...
            ' Expert only! ' ...
            'SOLUTION: Please change the offending args to '...
            'a value of at most 10000.']);
    end
end % SANITYCHECK


% Run RT
%--------------------------------------------------------------------------
rtParams=init_rf(); % random forests
rtParams.nbterms=args(1); % number of trees
tridx=int32(1:numel(d.tr_targets));  % (WARNING: int32 format is mandatory)
verbose=1;   % TODO: make this a machine arg

[output.func_val, output.w, trees]=rtenslearn_c(single(d.train{1}),...
    single(d.tr_targets),tridx,[],rtParams,single(d.test{1}),verbose);

% check if training succeeded: output itself is never empty once its
% fields have been assigned, so test the returned decision values instead
if isempty(output.func_val)
    error('prt_machine_RT_bin:RTtrainUnsuccessful',['Error:'...
        ' RT rtenslearn_c function did not run properly!' ...
        ' This could be a problem with the supplied function arguments: '...
        mat2str(args)]);
end

% compute hard decisions
output.predictions=round(output.func_val);

if d.use_kernel==false
    % normalise importance to norm 1
    output.w=output.w/norm(output.w,1);
else
    % do nothing - we can't compute primal weights from inside here
end

% prepare output
output.type        = 'classifier';

end
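
The two post-processing steps at the end of the function (hard decisions and importance normalisation) can be tried in isolation. The sketch below uses made-up numbers and assumes the decision values lie between the two labels 1 and 2; the variable names func_val, w and w_norm are illustrative and nothing here calls rtenslearn_c.

```matlab
% Made-up decision values for four test samples, assumed to lie in [1,2].
func_val = [1.1; 1.7; 1.5; 2.0];
% Hard decisions: round to the nearest label (1.5 rounds up to 2).
predictions = round(func_val);

% Made-up raw feature importances for four features.
w = [3; 1; 0; 4];
% L1 normalisation, so that the absolute importances sum to 1.
w_norm = w / norm(w, 1);
```

With labels in {1,2}, rounding amounts to thresholding the decision value at 1.5, and the L1 normalisation makes importances comparable across runs with different numbers of trees.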

Generated on Mon 03-Sep-2012 18:07:18 by m2html © 2005