Home > machines > prt_machine_svm_bin.m

prt_machine_svm_bin

PURPOSE ^

Run binary SVM - wrapper for libSVM

SYNOPSIS ^

function output = prt_machine_svm_bin(d,args)

DESCRIPTION ^

 Run binary SVM - wrapper for libSVM
 FORMAT output = prt_machine_svm_bin(d,args)
 Inputs:
   d         - structure with data information, with mandatory fields:
     .train      - training data (cell array of matrices of row vectors,
                   each [Ntr x D]). each matrix contains one representation
                   of the data. This is useful for approaches such as
                   multiple kernel learning.
     .test       - testing data (cell array of matrices of row vectors, each
                   [Nte x D])
     .tr_targets - training labels (for classification) or values (for
                   regression) (column vector, [Ntr x 1])
     .use_kernel - flag, is data in form of kernel matrices (true) or in
                   form of features (false)
    args     - libSVM arguments
 Output:
    output  - output of machine (struct).
     * Mandatory fields:
      .predictions - predicted class labels, 1 or 2 [Nte x 1]
     * Optional fields:
      .func_val - value of the decision function
      .type     - which type of machine this is (here, 'classifier')
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function output = prt_machine_svm_bin(d,args)
% Run binary SVM - wrapper for libSVM
% FORMAT output = prt_machine_svm_bin(d,args)
% Inputs:
%   d         - structure with data information, with mandatory fields:
%     .train      - training data (cell array of matrices of row vectors,
%                   each [Ntr x D]). Each matrix contains one representation
%                   of the data. This is useful for approaches such as
%                   multiple kernel learning.
%     .test       - testing data (cell array of matrices of row vectors,
%                   each [Nte x D])
%     .tr_targets - training labels, must be exactly {1,2}
%                   (column vector, [Ntr x 1])
%     .use_kernel - flag, is data in form of kernel matrices (true) or in
%                   form of features (false)
%    args     - libSVM arguments (char string, passed to svmtrain as-is)
% Output:
%    output  - output of machine (struct).
%     * Mandatory fields:
%      .predictions - predicted class labels, 1 or 2 [Nte x 1]
%     * Optional fields:
%      .func_val - value of the decision function [Nte x 1]
%      .type     - which type of machine this is (here, 'classifier')
%      .alpha    - signed SV coefficients in training-sample order [Ntr x 1]
%      .b        - bias term
%      .totalSV  - number of support vectors reported by libSVM
%
% Notes:
%   - Training data are passed to svmtrain as [ids d.train{:}], where the
%     first column holds the sample serial numbers 1..Ntr, and the test
%     decision values are computed as d.test*alpha + b.
%     NOTE(review): both of these look like libSVM's precomputed-kernel
%     convention ('-t 4', d.test being an [Nte x Ntr] test kernel); with
%     '-t 0' on raw features the id column would act as an extra feature.
%     Confirm whether linear-on-features is really supported.
%   - Test samples whose decision value is not strictly positive
%     (including zero) are assigned to class 2.
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by M.J.Rosa, J.Mourao-Miranda and J.Richiardi
% $Id: prt_machine_svm_bin.m 557 2012-06-12 09:07:53Z mjrosa $


% TODO: make sure the svmtrain we reach is the libsvm one, not the one
% with the same name from the bioinformatics toolbox!
% toolbox/bioinfo/biolearning/


SANITYCHECK = true; % can turn off for "speed". Expert only.

if SANITYCHECK
    % args should be a string (empty or otherwise)
    if ~ischar(args)
        error('prt_machine_svm_bin:libSVMargsNotString',['Error: libSVM'...
            ' args should be a string. ' ...
            ' SOLUTION: Please supply args as a character string of' ...
            ' libSVM options, e.g. ''-s 0 -t 4''.']);
    end

    % check we can reach the binary library
    if ~exist('svmtrain','file')
        error('prt_machine_svm_bin:libNotFound',['Error:'...
            ' libSVM svmtrain function could not be found !' ...
            ' SOLUTION: Please check your path.']);
    end

    % check it is indeed a two-class classification problem
    uTL = unique(d.tr_targets(:));
    nC  = numel(uTL);
    if nC > 2
        error('prt_machine_svm_bin:problemNotBinary',['Error:'...
            ' This machine is only for two-class problems but the' ...
            ' current problem has ' num2str(nC) ' classes! ' ...
            'SOLUTION: Please select a multi-class machine instead of ' ...
            'prt_machine_svm_bin.']);
    end

    % check both classes are present and labelled exactly {1,2}.
    % isequal (rather than ~all(uTL==[1 2]')) also copes with an empty
    % target vector, which would otherwise fail with a dimension mismatch.
    if ~isequal(uTL, [1;2])
        error('prt_machine_svm_bin:LabellingIncorect',['Error:'...
            ' This machine needs labels to be in {1,2} ' ...
            ' but they are ' mat2str(uTL) ' ! ' ...
            'SOLUTION: Please relabel your classes by changing the '...
            ' ''tr_targets'' argument to prt_machine_svm_bin']);
    end

    % check we are using the C-SVC (exclude types -s 1,2,3,4)
    if ~isempty(regexp(args,'-s\s+[1234]','once'))
        error('prt_machine_svm_bin:argsProblem:onlyCSVCsupport',['Error:'...
            ' This machine only supports a C-SVC formulation ' ...
            ' (''-s 0'' in the ''args'' parameter), but the args ' ...
            ' supplied are ''' args ''' ! ' ...
            'SOLUTION: Please change the offending part of args to '...
            '''-s 0''']);
    end

    % check we are using linear or precomputed kernels
    % (exclude types -t 1,2,3)
    if ~isempty(regexp(args,'-t\s+[123]','once'))
        error('prt_machine_svm_bin:argsProblem:onlyLinOrPrecomputeSupport',...
            ['Error: This machine only supports linear or precomputed ' ...
            'kernels (''-t 0/4'' in the ''args'' parameter), but the args ' ...
            ' supplied are ''' args ''' ! ' ...
            'SOLUTION: Please change the offending part of args to '...
            '''-t 0'' or ''-t 4'' as intended']);
    end
end


% Run SVM
%--------------------------------------------------------------------------
nlbs      = length(d.tr_targets);
allids_tr = (1:nlbs)';

% prepend the sample serial numbers; get_alpha later reads model.SVs as
% these serial numbers to restore the original training order
model = svmtrain(d.tr_targets,[allids_tr d.train{:}],args);

% check if training succeeded:
if isempty(model)
    if ischar(args)
        args_str = args;
    else
        args_str = '';
    end
    error('prt_machine_svm_bin:libSVMsvmtrainUnsuccessful',['Error:'...
        ' libSVM svmtrain function did not run properly!' ...
        ' This could be a problem with the supplied function arguments'...
        ' ' args_str '']);
end


% Get SV coefficients (alpha) in the original order and the bias term (b).
% sgn is +1 when model.Label(1)==1 and -1 when model.Label(1)==2, so that
% a positive decision value is mapped to class 1 below (PRoNTo convention).
sgn   = -1*(2 * model.Label(1) - 3);
alpha = get_alpha(model,nlbs,sgn);
b     = -model.rho*sgn;

% compute decision values directly rather than using svmpredict, which
% does not allow empty test labels. Cell-array test data are concatenated
% horizontally, exactly like the training data ([d.train{:}]) above.
if iscell(d.test)
    Xte = [d.test{:}];
else
    Xte = d.test;
end
func_val = Xte*alpha + b;

% compute hard decisions
predictions = sign(func_val);


% Outputs
%--------------------------------------------------------------------------
% change predictions from 1/-1 to 1/2 (zero decision values go to class 2)
c1PredIdx               = predictions==1;
predictions(c1PredIdx)  = 1; % positive values = 1
predictions(~c1PredIdx) = 2; % non-positive values = 2

output.predictions = predictions;
output.func_val    = func_val;
output.type        = 'classifier';
output.alpha       = alpha;
output.b           = b;
output.totalSV     = model.totalSV;

end
0148 
% Get SV coefficients
%--------------------------------------------------------------------------
function alpha = get_alpha(model,n,sgn)
% Scatter libSVM's support-vector coefficients back into training order.
% FORMAT alpha = get_alpha(model,n,sgn)
% Inputs:
%   model - libSVM model struct; only .totalSV, .SVs and .sv_coef are read.
%           The first totalSV (linear) entries of .SVs are used as row
%           indices into the training set, and the first totalSV entries
%           of .sv_coef are the matching coefficients.
%   n     - number of training samples (length of the output)
%   sgn   - scalar factor applied to every coefficient (the caller passes
%           +1 or -1 to fix the label convention)
% Output:
%   alpha - [n x 1] signed coefficients; zero for non-support vectors
% A helper is needed because libSVM may re-order the training examples.
nSV     = model.totalSV;
svRows  = full(model.SVs(1:nSV));   % training-sample index of each SV
coefs   = model.sv_coef(1:nSV);

alpha         = zeros(n,1);
alpha(svRows) = coefs;
alpha         = sgn*alpha;

end
0163

Generated on Mon 03-Sep-2012 18:07:18 by m2html © 2005