Home > machines > prt_machine_svm_bin.m

prt_machine_svm_bin

PURPOSE ^

Run binary SVM - wrapper for libSVM

SYNOPSIS ^

function output = prt_machine_svm_bin(d,args)

DESCRIPTION ^

 Run binary SVM - wrapper for libSVM
 FORMAT output = prt_machine_svm_bin(d,args)
 Inputs:
   d         - structure with data information, with mandatory fields:
     .train      - training data (cell array of matrices of row vectors,
                   each [Ntr x D]). each matrix contains one representation
                   of the data. This is useful for approaches such as
                   multiple kernel learning.
     .test       - testing data (cell array of matrices of row vectors,
                   each [Nte x D])
     .tr_targets - training labels (for classification) or values (for
                   regression) (column vector, [Ntr x 1])
     .use_kernel - flag, is data in form of kernel matrices (true) or in
                form of features (false)
    args     - libSVM arguments
 Output:
    output  - output of machine (struct).
     * Mandatory fields:
      .predictions - predictions of classification or regression [Nte x D]
     * Optional fields:
      .func_val - value of the decision function
      .type     - which type of machine this is (here, 'classifier')
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function output = prt_machine_svm_bin(d,args)
% Run binary SVM - wrapper for libSVM
% FORMAT output = prt_machine_svm_bin(d,args)
% Inputs:
%   d         - structure with data information, with mandatory fields:
%     .train      - training data (cell array of matrices of row vectors,
%                   each [Ntr x D]). each matrix contains one representation
%                   of the data. This is useful for approaches such as
%                   multiple kernel learning. A plain matrix is accepted
%                   as well.
%     .test       - testing data (cell array of matrices of row vectors,
%                   each [Nte x D], or a plain matrix)
%     .tr_targets - training labels; this machine requires them to be
%                   exactly the two values 1 and 2 (checked below)
%     .use_kernel - flag, is data in form of kernel matrices (true) or in
%                   form of features (false). NOTE: this function does not
%                   read the flag; kernel mode is detected from '-t 4' in
%                   args.
%    args     - libSVM arguments (string, may be empty)
% Output:
%    output  - output of machine (struct).
%     * Mandatory fields:
%      .predictions - predicted labels in {1,2}, one per test row
%     * Optional fields:
%      .func_val - value of the decision function (positive => class 1)
%      .type     - which type of machine this is (here, 'classifier')
%      .alpha    - SV coefficients, sign-adjusted so that positive
%                  decision values correspond to class 1
%      .b        - bias term, with the same sign adjustment
%      .totalSV  - number of support vectors reported by libSVM
%      .w        - primal weight vector (feature input only)
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by M.J.Rosa, J.Mourao-Miranda and J.Richiardi
% $Id: prt_machine_svm_bin.m 516 2012-04-11 14:34:30Z jmourao $

% FIXME: support for multiple kernels / feature representations
% is not yet tested, there might be transposition or dimensionality errors.

% TODO: maybe also check the prt_machine .usebf argument for compatibility
% with libSVM syntax ?

% TODO: make sure the svmtrain we reach is the libsvm one, not the one
% with the same name from the bioinformatics toolbox!
% toolbox/bioinfo/biolearning/


SANITYCHECK=true; % can turn off for "speed". Expert only.

if SANITYCHECK==true
    % args should be a string (empty or otherwise)
    if ~ischar(args)
        error('prt_machine_svm_bin:libSVMargsNotString',['Error: libSVM'...
            ' args should be a string. ' ...
            ' SOLUTION: Please do XXX']);
    end

    % check we can reach the binary library
    if ~exist('svmtrain','file')
        error('prt_machine_svm_bin:libNotFound',['Error:'...
            ' libSVM svmtrain function could not be found !' ...
            ' SOLUTION: Please check your path.']);
    end
    % check it is indeed a two-class classification problem
    uTL=unique(d.tr_targets(:));
    nC=numel(uTL);
    if nC>2
        error('prt_machine_svm_bin:problemNotBinary',['Error:'...
            ' This machine is only for two-class problems but the' ...
            ' current problem has ' num2str(nC) ' ! ' ...
            'SOLUTION: Please select another machine than ' ...
            'prt_machine_svm_bin in XXX']);
    end
    % check it is indeed labelled correctly (probably should be done
    % upstream). isequal is used rather than an element-wise == so that
    % empty or single-class targets reach this error message instead of
    % failing on a size mismatch.
    if ~isequal(uTL(:),[1;2])
        error('prt_machine_svm_bin:LabellingIncorect',['Error:'...
            ' This machine needs labels to be in {1,2} ' ...
            ' but they are ' mat2str(uTL) ' ! ' ...
            'SOLUTION: Please relabel your classes by changing the '...
            ' ''tr_targets'' argument to prt_machine_svm_bin']);
    end

    % check we are using the C-SVC (exclude types -s 1,2,3,4)
    if ~isempty(regexp(args,'-s\s+[1234]','once'))
        error('prt_machine_svm_bin:argsProblem:onlyCSVCsupport',['Error:'...
            ' This machine only supports a C-SVC formulation ' ...
            ' (''-s 0'' in the ''args'' parameter), but the args ' ...
            ' supplied are ''' args ''' ! ' ...
            'SOLUTION: Please change the offending part of args to '...
            '''-s 0''']);
    end

    % check we are using linear or precomputed kernels
    % (exclude types -t 1,2,3)
    if ~isempty(regexp(args,'-t\s+[123]','once'))
        error('prt_machine_svm_bin:argsProblem:onlyLinOrPrecomputeSupport',...
            ['Error: This machine only supports linear or precomputed ' ...
            'kernels (''-t 0/4'' in the ''args'' parameter), but the args ' ...
            ' supplied are ''' args ''' ! ' ...
            'SOLUTION: Please change the offending part of args to '...
            '''-t 0'' or ''-t 4'' as intended']);
    end

end

% '-t 4' in args selects libSVM's precomputed-kernel mode
hasPrecomputedKernel = ~isempty(regexp(args,'-t\s+4','once'));

% Run SVM
%--------------------------------------------------------------------------
nlbs  = length(d.tr_targets);
if hasPrecomputedKernel
    % precomputed kernels are passed with a leading column of sample ids
    allids_tr = (1:nlbs)';
else
    allids_tr = [];
end
% accept a plain matrix as well as a cell array, mirroring d.test below
if iscell(d.train)
    trainData = [d.train{:}];
else
    trainData = d.train;
end
model = svmtrain(d.tr_targets,[allids_tr trainData],args);

% check if training succeeded:
if isempty(model)
    if (ischar(args))
        args_str = args;
    else
        args_str = '';
    end
    error('prt_machine_svm_bin:libSVMsvmtrainUnsuccessful',['Error:'...
        ' libSVM svmtrain function did not run properly!' ...
        ' This could be a problem with the supplied function arguments'...
        ' ' args_str '']);
end

% sgn is +1 when model.Label(1)==1 and -1 when model.Label(1)==2. It is
% applied to BOTH the coefficients and the bias so that a positive
% decision value always means class 1.
sgn = -1*(2 * model.Label(1) - 3);
b     = -model.rho *sgn;

if hasPrecomputedKernel
    alpha = get_alpha(model,nlbs,sgn);
else
    % recover alphas directly; the sign correction must match the one
    % applied to b, otherwise the decision function is inconsistent
    % whenever model.Label(1)==2
    alpha = sgn*model.sv_coef;
    SVs   = model.SVs;          % recover also the SV's themselves
end

% compute prediction directly rather than using svmpredict, which does
% not allow empty test labels
if iscell(d.test)
    testData = cell2mat(d.test);
else
    testData = d.test;
end
if hasPrecomputedKernel
    func_val = testData*alpha+b;
else
    % compute primal weight vector
    w = SVs'*alpha;
    % compute function
    func_val = testData*w+b;
end

% compute hard decisions
predictions = sign(func_val);

% % REMOVEME compare with libsvm svmpredict results
% if hasPrecomputedKernel
%     allids_te=(1:size(cell2mat(test),1))';
% else
%     allids_te=[];
% end
% [foo_preds, foo_acc, foo_decision] = svmpredict([ones(10,1); ones(10,1)*2],[allids_te cell2mat(test)], model);
% [func_val foo_decision]

% TODO: convert labels to chosen output format

% Outputs
%--------------------------------------------------------------------------
% change predictions from 1/-1 to 1/2; anything that is not strictly
% positive (including a decision value of exactly 0) goes to class 2
c1PredIdx               = predictions==1;
%predictions(c1PredIdx)  = model.Label(1);
%predictions(~c1PredIdx) = model.Label(2);
predictions(c1PredIdx)  = 1; %positive values = 1
predictions(~c1PredIdx) = 2; %negative values = 2

output.predictions = predictions;
output.func_val    = func_val;
output.type        = 'classifier';
output.alpha       = alpha;
output.b           = b;
output.totalSV     = model.totalSV;
% w only exists when the feature (non-kernel) branch ran
if exist('w','var')==1
    output.w = w;
end

end
0190 
0191 % Get SV coefficients
0192 %--------------------------------------------------------------------------
function alpha = get_alpha(model,n,sgn)
% Place libSVM's SV coefficients back at their training-sample positions.
% libsvm can re-order examples, so each coefficient is written to the
% position given by the first model.totalSV entries of model.SVs (linear
% indexing). Samples that are not support vectors keep a coefficient of 0.
%   model - struct with fields totalSV, SVs and sv_coef
%   n     - length of the returned column vector
%   sgn   - scalar multiplied into every coefficient
svCount = model.totalSV;
alpha   = zeros(n,1);

if svCount > 0
    % scatter all coefficients in a single indexed assignment
    origPos        = full(model.SVs(1:svCount));
    alpha(origPos) = model.sv_coef(1:svCount);
end

alpha = sgn*alpha;

end
0207

Generated on Sun 20-May-2012 13:24:48 by m2html © 2005