function output = prt_machine_svm_bin(d,args)
% PRT_MACHINE_SVM_BIN Train a binary libSVM C-SVC and apply it to test data.
%
% FORMAT output = prt_machine_svm_bin(d,args)
%
% Inputs:
%   d    - structure; the fields read here are:
%            d.tr_targets - training labels; must take exactly the values 1 and 2
%            d.train      - cell array whose contents are horizontally
%                           concatenated and passed to svmtrain after a
%                           column of sample serial numbers 1..N (libSVM's
%                           precomputed-kernel input layout) - presumably the
%                           N x N training kernel; TODO confirm with callers
%            d.test       - matrix (or cell array, joined with cell2mat) that
%                           is multiplied by the N x 1 weight vector alpha -
%                           presumably the test-by-train kernel; TODO confirm
%   args - libSVM option string. Any non-string value is converted with
%          num2str and appended to the 'libsvmargs' field returned by
%          prt_get_defaults('model').
%
% Output (structure):
%   output.predictions - 1 where func_val > 0, otherwise 2
%   output.func_val    - decision values, d.test*alpha + b
%   output.type        - 'classifier'
%   output.alpha       - N x 1 weights (non-zero only at support vectors),
%                        signed so that positive decision values mean class 1
%   output.b           - bias term
%   output.totalSV     - number of support vectors reported by libSVM
%
% Errors (identifiers) are raised by the input checks (see check_inputs) and
% 'prt_machine_svm_bin:libSVMsvmtrainUnsuccessful' if svmtrain returns empty.

SANITYCHECK=true; % set to false to skip input validation

% A non-string 'args' (e.g. a numeric value) is turned into text and appended
% to the default libSVM option string, so 'args' is always a string below.
if ~ischar(args)
    def = prt_get_defaults('model');
    args = [def.libsvmargs, num2str(args)];
end

if SANITYCHECK==true
    check_inputs(d,args);
end

% Precomputed-kernel layout expected by libSVM: the first column holds the
% 1-based serial number of each training sample.
nlbs      = length(d.tr_targets);
allids_tr = (1:nlbs)';

model = svmtrain(d.tr_targets,[allids_tr d.train{:}],args);

if isempty(model)
    error('prt_machine_svm_bin:libSVMsvmtrainUnsuccessful',['Error:'...
        ' libSVM svmtrain function did not run properly!' ...
        ' This could be a problem with the supplied function arguments'...
        ' ' args '']);
end

% libSVM's decision values are positive for model.Label(1). Flip the sign when
% Label(1)==2 (sgn = +1 for Label(1)==1, -1 for Label(1)==2) so that positive
% values always mean class 1.
sgn   = -1*(2 * model.Label(1) - 3);
alpha = get_alpha(model,nlbs,sgn);
b     = -model.rho *sgn;

if iscell(d.test)
    func_val = cell2mat(d.test)*alpha+b;
else
    func_val = d.test*alpha+b;
end

% Positive decision value -> class 1; zero, negative (or NaN) -> class 2.
predictions = sign(func_val);
c1PredIdx = predictions==1;
predictions(c1PredIdx) = 1;
predictions(~c1PredIdx) = 2;

output.predictions = predictions;
output.func_val    = func_val;
output.type        = 'classifier';
output.alpha       = alpha;
output.b           = b;
output.totalSV     = model.totalSV;

end


function check_inputs(d,args)
% CHECK_INPUTS Validate the libSVM options and training labels; raises an
% error (with a 'prt_machine_svm_bin:*' identifier) on the first problem.

if ~ischar(args)
    error('prt_machine_svm_bin:libSVMargsNotString',['Error: libSVM'...
        ' args should be a string. ' ...
        ' SOLUTION: Please do XXX']);
end

if ~exist('svmtrain','file')
    error('prt_machine_svm_bin:libNotFound',['Error:'...
        ' libSVM svmtrain function could not be found !' ...
        ' SOLUTION: Please check your path.']);
end

uTL=unique(d.tr_targets(:));
nC=numel(uTL);
% Reject anything but exactly two classes. Testing only nC>2 would let an
% empty or single-class target vector reach the comparison below. That
% comparison fails with a generic size-mismatch error for an empty vector,
% and reports a misleading labelling error for a single class.
if nC~=2
    error('prt_machine_svm_bin:problemNotBinary',['Error:'...
        ' This machine is only for two-class problems but the' ...
        ' current problem has ' num2str(nC) ' classes! ' ...
        'SOLUTION: Please select another machine than ' ...
        'prt_machine_svm_bin in XXX']);
end

% unique() returns sorted values, so the labels must be exactly [1;2].
if ~all(uTL==[1 2]')
    error('prt_machine_svm_bin:LabellingIncorect',['Error:'...
        ' This machine needs labels to be in {1,2} ' ...
        ' but they are ' mat2str(uTL) ' ! ' ...
        'SOLUTION: Please relabel your classes by changing the '...
        ' ''tr_targets'' argument to prt_machine_svm_bin']);
end

% Only C-SVC ('-s 0') is supported.
if ~isempty(regexp(args,'-s\s+[1234]','once'))
    error('prt_machine_svm_bin:argsProblem:onlyCSVCsupport',['Error:'...
        ' This machine only supports a C-SVC formulation ' ...
        ' (''-s 0'' in the ''args'' parameter), but the args ' ...
        ' supplied are ''' args ''' ! ' ...
        'SOLUTION: Please change the offending part of args to '...
        '''-s 0''']);
end

% Only linear ('-t 0') or precomputed ('-t 4') kernels are accepted.
% NOTE(review): the caller always builds precomputed-kernel input (serial
% number column first). With '-t 0' that column would be treated as a
% feature. Confirm that callers actually pass '-t 4'.
if ~isempty(regexp(args,'-t\s+[123]','once'))
    error('prt_machine_svm_bin:argsProblem:onlyLinOrPrecomputeSupport',...
        ['Error: This machine only supports linear or precomputed ' ...
        'kernels (''-t 0/4'' in the ''args'' parameter), but the args ' ...
        ' supplied are ''' args ''' ! ' ...
        'SOLUTION: Please change the offending part of args to '...
        '''-t 0'' or ''-t 4'' as intended']);
end

end

function alpha = get_alpha(model,n,sgn)
% GET_ALPHA Scatter libSVM support-vector coefficients into a dense vector.
%
%   alpha = get_alpha(model,n,sgn) builds an n x 1 vector of zeros, writes
%   model.sv_coef(k) at position model.SVs(k) for k = 1..model.totalSV, and
%   returns the result multiplied by sgn. When the model was trained on a
%   precomputed kernel, model.SVs holds training-sample serial numbers, so
%   entry i is then the (signed) coefficient of training sample i.

nSV   = model.totalSV;
svIdx = full(model.SVs(1:nSV));      % SVs may be stored as a sparse matrix
alpha = zeros(n,1);
alpha(svIdx) = model.sv_coef(1:nSV);
alpha = sgn*alpha;

end