Home > . > prt_stats.m

prt_stats

PURPOSE ^

Function to compute prediction machine performance statistics

SYNOPSIS ^

function stats = prt_stats(model, tte, ttr)

DESCRIPTION ^

 Function to compute prediction machine performance statistics

 Inputs:
 ----------------
 model.predictions: predictions derived from the predictive model
 model.type:        what type of prediction machine (e.g. 'classifier','regression')

 tte: true targets (test set)
 ttr: true targets (training set - needed to get the number of classes)
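 flag:  'fold' for statistics in each fold
         'model' for statistics in each model
         (note: prt_stats does not currently accept a flag argument;
         its signature is prt_stats(model, tte, ttr))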
 
 Outputs:
-------------------
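 Classification:
 stats.con_mat: Confusion matrix (nClasses x nClasses matrix, pred x true)
 stats.acc:     Accuracy (scalar)
 stats.b_acc:   Balanced accuracy (scalar: mean of the per-class accuracies)
 stats.c_acc:   Accuracy by class (nClasses x 1 vector)
 stats.c_pv:    Predictive value for each class (nClasses x 1 vector)
 stats.acc_lb:  Lower bound of the 95% Wilson confidence interval on acc
 stats.acc_ub:  Upper bound of the 95% Wilson confidence interval on acc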

 Regression:
 stats.mse:     Mean squared error between test targets and predictions
 stats.corr:    Correlation between test targets and predictions
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function stats = prt_stats(model, tte, ttr)
% Function to compute prediction machine performance statistics
%
% Inputs:
% ----------------
% model.predictions: predictions derived from the predictive model
% model.type:        what type of prediction machine ('classifier' or
%                    'regression'). If the field is missing, a warning is
%                    issued and 'classifier' is assumed.
%
% tte: true targets (test set)
% ttr: true targets (training set - needed to get the number of classes;
%      only used for classifiers)
%
% Note: earlier versions of this header documented a 'flag' input
% ('fold' / 'model'); this function takes no such argument.
%
% Outputs:
%-------------------
% Classification:
% stats.con_mat: Confusion matrix (nClasses x nClasses matrix, pred x true)
% stats.acc:     Accuracy (scalar)
% stats.b_acc:   Balanced accuracy (scalar)
% stats.c_acc:   Accuracy by class (nClasses x 1 vector)
% stats.c_pv:    Predictive value for each class (nClasses x 1 vector)
% stats.acc_lb:  Lower bound of the 95% Wilson confidence interval on acc
% stats.acc_ub:  Upper bound of the 95% Wilson confidence interval on acc
%
% Regression:
% stats.mse:     Mean squared error between test targets and predictions
% stats.corr:    Correlation between test targets and predictions
%                (NaN when there are fewer than 3 test samples)
%
% Errors:
% prt_stats:unknownTypeSpecified  if model.type is neither 'classifier'
%                                 nor 'regression'
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by A. Marquand
% $Id: prt_stats.m 508 2012-04-07 16:16:21Z amarquan $

if ~isfield(model,'type')
    warning('prt_stats:modelDoesNotProvideTypeField',...
        'model.type not specified, defaulting to classifier');
    model.type = 'classifier';
end

switch model.type
    case 'classifier'
        stats = compute_stats_classifier(model, tte, ttr);

    case 'regression'
        stats = compute_stats_regression(model, tte);

    otherwise
        % model.type is passed as a format argument (not concatenated into
        % the format string) so any '%' or '\' it contains is printed
        % literally.
        error('prt_stats:unknownTypeSpecified',...
            'No method exists for processing machine: %s', model.type);
end

end
0054 
0055 % -------------------------------------------------------------------------
0056 % Private functions
0057 % -------------------------------------------------------------------------
0058 
function stats = compute_stats_classifier(model, tte, ttr)
% Compute classification statistics from integer class labels.
%
% Class labels are used directly as row/column indices of the confusion
% matrix, so they must be positive integers (1..nClasses).
%
% IN
%   model.predictions: predicted labels, one per test sample
%   tte:               true labels of the test samples
%   ttr:               true labels of the training samples
%
% OUT
%   stats.con_mat: confusion matrix (nClasses x nClasses, pred x true)
%   stats.acc:     overall accuracy (NaN when there are no test samples)
%   stats.c_acc:   per-class accuracy (recall); 0 for classes that have no
%                  test samples
%   stats.b_acc:   balanced accuracy, the mean of c_acc over the classes
%                  that occur in tte (NaN when tte is empty)
%   stats.c_pv:    per-class predictive value; 0 for classes never predicted
%   stats.acc_lb, stats.acc_ub: Wilson score interval bounds on acc

pred = model.predictions(:);
tte  = tte(:);
if numel(pred) ~= numel(tte)
    error('prt_stats:predictionSizeMismatch', ...
        'model.predictions and tte must have the same number of elements');
end

% Number of classes: the largest label seen anywhere. Using ttr alone would
% let a test/predicted label outside the training labels silently enlarge
% the confusion matrix into a non-square shape.
k = max([ttr(:); tte; pred]);

% con_mat(p,t) counts test samples predicted as class p with true class t.
stats.con_mat = accumarray([pred tte], 1, [k k]);

Cc  = diag(stats.con_mat);      % correct predictions for each class
Zc  = sum(stats.con_mat, 1)';   % true samples per class (column sums)
nz  = Zc ~= 0;                  % classes present in the test targets
Zcr = sum(stats.con_mat, 2);    % predictions per class (row sums)
nzr = Zcr ~= 0;                 % classes that were predicted at least once

stats.acc       = sum(Cc) ./ sum(Zc);
stats.c_acc     = zeros(k,1);
stats.c_acc(nz) = Cc(nz) ./ Zc(nz);
% A class with no test samples has no defined recall; averaging it in as 0
% would penalise folds that happen to contain only some classes.
stats.b_acc     = mean(stats.c_acc(nz));
stats.c_pv      = zeros(k,1);
stats.c_pv(nzr) = Cc(nzr) ./ Zcr(nzr);

% confidence interval
% TODO: check IID assumption here (chunks in run_permutation.m)
% before applying tests, and give nans if not applicable...
[lb,ub] = computeWilsonBinomialCI(sum(Cc),sum(Zc));
stats.acc_lb = lb;
stats.acc_ub = ub;
end
0090 
function stats = compute_stats_regression(model, tte)
% Compute regression statistics between predictions and true targets.
%
% IN
%   model.predictions: predicted values, one per test sample
%   tte:               true target values of the test samples
%
% OUT
%   stats.corr: Pearson correlation between predictions and targets
%               (NaN when there are fewer than 3 test samples)
%   stats.mse:  mean squared error between predictions and targets
%
% Both inputs are flattened to column vectors first, so a row vector of
% predictions compared with a column vector of targets (or vice versa) is
% handled element-wise rather than broadcast into an n x n matrix.

pred = model.predictions(:);
tte  = tte(:);

if numel(tte) < 3
    stats.corr = NaN;
else
    coef = corrcoef(pred, tte);
    stats.corr = coef(1,2);
end
stats.mse = mean((pred - tte).^2);
end
0101 
function [lb,ub] = computeWilsonBinomialCI(k,n,alpha)
% Compute the lower and upper bounds of a two-sided (1-alpha) confidence
% interval for a binomial proportion using Wilson's 'score interval'.
% With the default alpha = 0.05 this is a 95% interval.
%
% IN
%   k:     scalar, number of successes
%   n:     scalar, number of samples
%   alpha: (optional) significance level, default 0.05
%
% OUT
%   lb: lower bound of confidence interval
%   ub: upper bound of confidence interval
%
% When n = 0 the sample proportion is 0/0, so both bounds are NaN.
%
% REFERENCES
% Brown, Lawrence D., Cai, T. Tony, DasGupta, Anirban, 2001.
%   Interval estimation for a binomial proportion. Stat. Sci. 16, 101-133.
% Edwin B. Wilson, Probable Inference, the Law of Succession, and
%   Statistical Inference, Journal of the American Statistical Association,
%   Vol. 22, No. 158 (Jun., 1927), pp. 209-212

if nargin < 3 || isempty(alpha)
    alpha = 0.05;
end

% Two-sided standard-normal quantile z = Phi^-1(1 - alpha/2).
% Since Phi^-1(p) = -sqrt(2)*erfcinv(2p) and erfcinv(2-x) = -erfcinv(x),
% this equals sqrt(2)*erfcinv(alpha); base MATLAB, no SPM dependency.
z = sqrt(2) * erfcinv(alpha);

p = k/n;    % sample proportion of success
q = 1-p;

% centre and half-width of the score interval
centre    = (k + (z^2)/2) / (n + z^2);
halfWidth = ((z*sqrt(n)) / (n + z^2)) * sqrt(p*q + (z^2)/(4*n));

lb = centre - halfWidth;
ub = centre + halfWidth;

end

Generated on Sun 20-May-2012 13:24:48 by m2html © 2005