


Run binary SVM - wrapper for libSVM
FORMAT output = prt_machine_svm_bin(d,args)
Inputs:
d - structure with data information, with mandatory fields:
.train - training data (cell array of matrices of row vectors,
each [Ntr x D]). Each matrix contains one representation
of the data. This is useful for approaches such as
multiple kernel learning.
.test - testing data (cell array of matrices of row vectors, each
[Nte x D])
.tr_targets - training labels (for classification) or values (for
regression) (column vector, [Ntr x 1])
.use_kernel - flag: is the data in the form of kernel matrices (true) or
in the form of features (false)?
args - libSVM arguments (string). Only C-SVC ('-s 0') with a linear
('-t 0') or precomputed ('-t 4') kernel is supported.
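A minimal sketch of how the inputs might be assembled for the precomputed-kernel case, assuming `use_kernel` is true. In that case each training matrix is an [Ntr x Ntr] kernel and each test matrix is the [Nte x Ntr] kernel between test and training samples, which are the shapes the code needs for `d.test*alpha`. The toy features `Xtr` and `Xte` and the cost value `-c 1` are hypothetical. The snippet replays the function's sanity checks and builds the libSVM input with its leading column of sample serial numbers, without calling libSVM itself.

```matlab
% Hypothetical toy features: 6 training and 3 test samples, 2 features each
Xtr = [1 0; 2 1; 0 1; -1 0; -2 -1; 0 -1];
Xte = [1 1; -1 -1; 0.5 0];

% Linear kernels: training [Ntr x Ntr], test-vs-training [Nte x Ntr]
Ktr = Xtr * Xtr';
Kte = Xte * Xtr';

d = struct();
d.train      = {Ktr};                 % a single representation (one kernel)
d.test       = {Kte};
d.tr_targets = [1; 1; 1; 2; 2; 2];    % labels must be in {1,2}
d.use_kernel = true;

% libSVM options: C-SVC (-s 0), precomputed kernel (-t 4), cost C = 1 (-c 1)
args = '-s 0 -t 4 -c 1';

% Replay the checks performed when SANITYCHECK is true
assert(ischar(args));
assert(isempty(regexp(args, '-s\s+[1234]', 'once')));  % C-SVC only
assert(isempty(regexp(args, '-t\s+[123]', 'once')));   % linear or precomputed only
assert(isequal(unique(d.tr_targets(:)), [1; 2]));

% Precomputed-kernel input for svmtrain: serial numbers 1..Ntr, then the kernel
nlbs = length(d.tr_targets);
trainInput = [(1:nlbs)' d.train{:}];  % [Ntr x (Ntr+1)]
```

With a single kernel, `[allids_tr d.train{:}]` in the function produces exactly this `trainInput` matrix.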
Output:
output - output of machine (struct).
* Mandatory fields:
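.predictions - predicted class labels (1 or 2) [Nte x 1]
* Optional fields:
.func_val - value of the decision function [Nte x 1]
.type - which type of machine this is (here, 'classifier')
.alpha - SV coefficients in the original training order [Ntr x 1]
.b - bias term of the decision function
.totalSV - total number of support vectors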
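The mapping from decision values to PRoNTo labels can be shown in isolation. This sketch uses hypothetical decision values and repeats the function's own conversion; note that a value of exactly 0 has sign 0 and is therefore assigned to class 2.

```matlab
% Hypothetical decision values for four test samples
func_val = [0.7; -1.2; 0; 2.5];

% Hard decisions in {-1, 0, +1}
predictions = sign(func_val);

% Map to PRoNTo labels: +1 -> class 1, anything else (including 0) -> class 2
c1PredIdx = predictions == 1;
predictions(c1PredIdx)  = 1;
predictions(~c1PredIdx) = 2;
% predictions is now [1; 2; 2; 1]
```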
__________________________________________________________________________
Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory


function output = prt_machine_svm_bin(d,args)
% Run binary SVM - wrapper for libSVM
% FORMAT output = prt_machine_svm_bin(d,args)
% Inputs:
%   d - structure with data information, with mandatory fields:
%     .train      - training data (cell array of matrices of row vectors,
%                   each [Ntr x D]). Each matrix contains one representation
%                   of the data. This is useful for approaches such as
%                   multiple kernel learning.
%     .test       - testing data (cell array of matrices of row vectors,
%                   each [Nte x D])
%     .tr_targets - training labels (for classification) or values (for
%                   regression) (column vector, [Ntr x 1])
%     .use_kernel - flag: is the data in the form of kernel matrices (true)
%                   or in the form of features (false)?
%   args - libSVM arguments (string). Only C-SVC ('-s 0') with a linear
%          ('-t 0') or precomputed ('-t 4') kernel is supported.
% Output:
%   output - output of machine (struct).
%     * Mandatory fields:
%     .predictions - predicted class labels (1 or 2) [Nte x 1]
%     * Optional fields:
%     .func_val    - value of the decision function [Nte x 1]
%     .type        - which type of machine this is (here, 'classifier')
%     .alpha       - SV coefficients in the original training order [Ntr x 1]
%     .b           - bias term of the decision function
%     .totalSV     - total number of support vectors
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by M.J.Rosa, J.Mourao-Miranda and J.Richiardi
% $Id: prt_machine_svm_bin.m 557 2012-06-12 09:07:53Z mjrosa $


% TODO: make sure the svmtrain we reach is the libsvm one, not the one
% with the same name from the bioinformatics toolbox!
% toolbox/bioinfo/biolearning/


SANITYCHECK=true; % can turn off for "speed". Expert only.

if SANITYCHECK==true
    % args should be a string (empty or otherwise)
    if ~ischar(args)
        error('prt_machine_svm_bin:libSVMargsNotString',['Error: libSVM'...
            ' args should be a string. ' ...
            ' SOLUTION: Please do XXX']);
    end

    % check we can reach the binary library
    if ~exist('svmtrain','file')
        error('prt_machine_svm_bin:libNotFound',['Error:'...
            ' libSVM svmtrain function could not be found!' ...
            ' SOLUTION: Please check your path.']);
    end
    % check it is indeed a two-class classification problem
    uTL=unique(d.tr_targets(:));
    nC=numel(uTL);
    if nC>2
        error('prt_machine_svm_bin:problemNotBinary',['Error:'...
            ' This machine is only for two-class problems but the' ...
            ' current problem has ' num2str(nC) ' classes! ' ...
            'SOLUTION: Please select another machine than ' ...
            'prt_machine_svm_bin in XXX']);
    end
    % check it is indeed labelled correctly (probably should be done
    if ~all(uTL==[1 2]')
        error('prt_machine_svm_bin:LabellingIncorect',['Error:'...
            ' This machine needs labels to be in {1,2} ' ...
            ' but they are ' mat2str(uTL) '! ' ...
            'SOLUTION: Please relabel your classes by changing the '...
            ' ''tr_targets'' argument to prt_machine_svm_bin']);
    end

    % check we are using the C-SVC (exclude types -s 1,2,3,4)
    if ~isempty(regexp(args,'-s\s+[1234]','once'))
        error('prt_machine_svm_bin:argsProblem:onlyCSVCsupport',['Error:'...
            ' This machine only supports a C-SVC formulation ' ...
            ' (''-s 0'' in the ''args'' parameter), but the args ' ...
            ' supplied are ''' args '''! ' ...
            'SOLUTION: Please change the offending part of args to '...
            '''-s 0''']);
    end

    % check we are using linear or precomputed kernels
    % (exclude types -t 1,2,3)
    if ~isempty(regexp(args,'-t\s+[123]','once'))
        error('prt_machine_svm_bin:argsProblem:onlyLinOrPrecomputeSupport',...
            ['Error: This machine only supports linear or precomputed ' ...
            'kernels (''-t 0/4'' in the ''args'' parameter), but the args ' ...
            ' supplied are ''' args '''! ' ...
            'SOLUTION: Please change the offending part of args to '...
            '''-t 0'' or ''-t 4'' as intended']);
    end

end


% Run SVM
%--------------------------------------------------------------------------
nlbs = length(d.tr_targets);
allids_tr = (1:nlbs)';   % sample serial numbers (precomputed-kernel format)

model = svmtrain(d.tr_targets,[allids_tr d.train{:}],args);

% check if training succeeded:
if isempty(model)
    if (ischar(args))
        args_str = args;
    else
        args_str = '';
    end
    error('prt_machine_svm_bin:libSVMsvmtrainUnsuccessful',['Error:'...
        ' libSVM svmtrain function did not run properly!' ...
        ' This could be a problem with the supplied function arguments'...
        ' ' args_str '']);
end


% Get SV coefficients (alpha) in the original order and the bias term (b)
sgn = -1*(2 * model.Label(1) - 3); %variable to account for label convention in PRoNTo
alpha = get_alpha(model,nlbs,sgn);
b = -model.rho *sgn;

% compute prediction directly rather than using svmpredict, which does
% not allow empty test labels
if iscell(d.test)
    func_val = cell2mat(d.test)*alpha+b;
else
    func_val = d.test*alpha+b;
end

% compute hard decisions
predictions = sign(func_val);


% Outputs
%--------------------------------------------------------------------------
% change predictions from 1/-1 to 1/2
c1PredIdx = predictions==1;
predictions(c1PredIdx) = 1; %positive values = 1
predictions(~c1PredIdx) = 2; %negative values = 2

output.predictions = predictions;
output.func_val = func_val;
output.type = 'classifier';
output.alpha = alpha;
output.b = b;
output.totalSV = model.totalSV;

end

% Get SV coefficients
%--------------------------------------------------------------------------
function alpha = get_alpha(model,n,sgn)
% needs a function because examples can be re-ordered by libsvm
alpha = zeros(n,1);

for i = 1:model.totalSV
    ind = model.SVs(i);
    alpha(ind) = model.sv_coef(i);
end

alpha = sgn*alpha;

end
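The sign convention and the reordering done by `get_alpha` are easiest to follow on a small example. This sketch replaces the trained libSVM model with a hand-built struct (all values hypothetical) that has only the fields the function reads: `Label`, `totalSV`, `SVs`, `sv_coef` and `rho`. It assumes, as in libSVM's precomputed-kernel mode, that `SVs` holds the serial numbers of the support vectors, and it uses libSVM's decision value, the sum of `sv_coef` times kernel values minus `rho`, for comparison.

```matlab
% Hypothetical libSVM-style model for Ntr = 6 training samples.
% Label(1) == 2 means libSVM treated class 2 as its "positive" class.
model = struct('Label', [2; 1], 'totalSV', 3, ...
               'SVs', [5; 1; 3], 'sv_coef', [0.4; -0.25; -0.15], ...
               'rho', 0.1);
nlbs = 6;

% Sign flip so that a positive decision value always means PRoNTo class 1
sgn = -1 * (2 * model.Label(1) - 3);   % +1 if Label(1)==1, -1 if Label(1)==2

% Scatter the SV coefficients back to the original training order
alpha = zeros(nlbs, 1);
for i = 1:model.totalSV
    alpha(model.SVs(i)) = model.sv_coef(i);
end
alpha = sgn * alpha;
b = -model.rho * sgn;

% Hypothetical test-vs-training kernel [Nte x Ntr]
Kte = [1 0 0 0 0 0;
       0 0 0 0 1 0];
func_val = Kte * alpha + b;            % approximately [0.35; -0.30]

% libSVM's own decision value (positive -> Label(1)); here it equals -func_val
f_libsvm = Kte(:, model.SVs) * model.sv_coef - model.rho;

% Hard decisions mapped to labels {1,2}
predictions = 2 * ones(size(func_val));
predictions(sign(func_val) == 1) = 1;  % [1; 2]
```

Because `Label(1)` is 2 here, `sgn` is -1, so libSVM's decision function is negated and a positive `func_val` still points to class 1. With a real model, `SVs` may be stored as a sparse matrix, so the indexing above is only a sketch of the same logic.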