
prt_stats

PURPOSE ^

Function to compute prediction machine performance statistics

SYNOPSIS ^

function stats = prt_stats(model, tte, nk)

DESCRIPTION ^

 Function to compute prediction machine performance statistics

 Inputs:
 ----------------
 model.predictions: predictions derived from the predictive model
 model.type:        what type of prediction machine (e.g. 'classifier','regression')

 tte: true targets (test set)
 nk:  number of classes if classification (empty otherwise)
 flag:  'fold' for statistics in each fold
         'model' for statistics in each model (not in the current signature)
 
 Outputs:
-------------------
 Classification:
 stats.con_mat: Confusion matrix (nClasses x nClasses matrix, pred x true)
 stats.acc:     Accuracy (scalar)
 stats.b_acc:   Balanced accuracy, i.e. the mean of the per-class accuracies (scalar)
 stats.c_acc:   Accuracy by class (nClasses x 1 vector)
 stats.c_pv:    Predictive value for each class (nClasses x 1 vector)

 Regression:
 stats.mse:     Mean square error between test and prediction
 stats.corr:    Correlation between test and prediction
 stats.r2:      Squared correlation
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory
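
 A minimal usage sketch (illustrative values only; assumes prt_stats is on the
 MATLAB path and, for the classifier branch, that SPM's spm_invNcdf is available):

 % --- Classification: 6 test samples, 2 classes ---
 model.type        = 'classifier';
 model.predictions = [1 1 2 2 2 1]';     % predicted class labels
 tte               = [1 2 2 2 1 1]';     % true class labels
 cstats = prt_stats(model, tte, 2);
 disp(cstats.con_mat)                    % 2x2 confusion matrix (pred x true)
 disp([cstats.acc cstats.b_acc])         % accuracy and balanced accuracy
 disp([cstats.acc_lb cstats.acc_ub])     % 95% Wilson interval on accuracy

 % --- Regression: continuous targets, nk left empty ---
 model.type        = 'regression';
 model.predictions = [0.9 2.1 2.8 4.2]';
 tte               = [1 2 3 4]';
 rstats = prt_stats(model, tte, []);
 disp([rstats.mse rstats.corr rstats.r2])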

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SUBFUNCTIONS ^

function stats = compute_stats_classifier(model, tte, k)
function stats = compute_stats_regression(model, tte)
function [lb,ub] = computeWilsonBinomialCI(k,n)

SOURCE CODE ^

0001 function stats = prt_stats(model, tte, nk)
0002 % Function to compute prediction machine performance statistics
0003 %
0004 % Inputs:
0005 % ----------------
0006 % model.predictions: predictions derived from the predictive model
0007 % model.type:        what type of prediction machine (e.g. 'classifier','regression')
0008 %
0009 % tte: true targets (test set)
0010 % nk:  number of classes if classification (empty otherwise)
0011 % flag:  'fold' for statistics in each fold
0012 %         'model' for statistics in each model (not in the current signature)
0013 %
0014 % Outputs:
0015 %-------------------
0016 % Classification:
0017 % stats.con_mat: Confusion matrix (nClasses x nClasses matrix, pred x true)
0018 % stats.acc:     Accuracy (scalar)
0019 % stats.b_acc:   Balanced accuracy, i.e. the mean of the per-class accuracies (scalar)
0020 % stats.c_acc:   Accuracy by class (nClasses x 1 vector)
0021 % stats.c_pv:    Predictive value for each class (nClasses x 1 vector)
0022 %
0023 % Regression:
0024 % stats.mse:     Mean square error between test and prediction
0025 % stats.corr:    Correlation between test and prediction
0026 % stats.r2:      Squared correlation
0027 %__________________________________________________________________________
0028 % Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory
0029 
0030 % Written by A. Marquand
0031 % $Id$
0032 
0033 % FIXME: is any code using the 'flags' input argument?
0034 if ~isfield(model,'type')
0035     warning('prt_stats:modelDoesNotProvideTypeField',...
0036         'model.type not specified, defaulting to classifier');
0037     model.type = 'classifier';
0038 end
0039 
0040 switch model.type
0041     case 'classifier'
0042         
0043         stats = compute_stats_classifier(model, tte, nk);
0044         
0045     case 'regression'
0046         
0047         stats = compute_stats_regression(model, tte);
0048         
0049     otherwise
0050         error('prt_stats:unknownTypeSpecified',...
0051             ['No method exists for processing machine: ',model.type]);
0052 end
0053 
0054 end
0055 
0056 % -------------------------------------------------------------------------
0057 % Private functions
0058 % -------------------------------------------------------------------------
0059 
0060 function stats = compute_stats_classifier(model, tte, k)
0061 
0062 k = max(unique(k));        % number of classes
0063 
0064 stats.con_mat = zeros(k,k);
0065 for i = 1:length(tte)
0066     true_lb = tte(i);
0067     pred_lb = model.predictions(i);
0068     stats.con_mat(pred_lb,true_lb) = stats.con_mat(pred_lb,true_lb) + 1;
0069 end
0070 
0071 Cc = diag(stats.con_mat);   % correct predictions for each class
0072 Zc = sum(stats.con_mat)';   % total samples for each class (cols)
0073 nz = Zc ~= 0;               % classes with nonzero totals (cols)
0074 Zcr = sum(stats.con_mat,2); % total predictions for each class (rows)
0075 nzr = Zcr ~= 0;             % classes with nonzero totals (rows)
0076 
0077 stats.acc       = sum(Cc) ./ sum(Zc);
0078 stats.c_acc     = zeros(k,1);
0079 stats.c_acc(nz) = Cc(nz) ./ Zc(nz);
0080 stats.b_acc     = mean(stats.c_acc);
0081 stats.c_pv      = zeros(k,1);
0082 stats.c_pv(nzr) = Cc(nzr) ./ Zcr(nzr); 
0083 
0084 % confidence interval
0085 % TODO: check IID assumption here (chunks in run_permutation.m)
0086 % before applying tests, and give nans if not applicable...
0087 [lb,ub] = computeWilsonBinomialCI(sum(Cc),sum(Zc));
0088 stats.acc_lb=lb;
0089 stats.acc_ub=ub;
0090 end
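% Worked example (illustrative, not part of the original source): for a 2-class
% problem with con_mat = [30 10; 5 55] (pred x true), the quantities above are
%   Cc  = [30; 55], Zc = [35; 65], Zcr = [40; 60]
%   acc   = 85/100         = 0.85
%   c_acc = [30/35; 55/65] = [0.857; 0.846]
%   b_acc = mean(c_acc)    = 0.852
%   c_pv  = [30/40; 55/60] = [0.750; 0.917]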
0091 
0092 function stats = compute_stats_regression(model, tte)
0093 
0094 if numel(tte)<3
0095     stats.corr = NaN;
0096     stats.r2 = NaN;
0097 else
0098     coef = corrcoef(model.predictions,tte);
0099     stats.corr = coef(1,2);
0100     stats.r2 = coef(1,2).^2;
0101 end
0102 stats.mse  = mean((model.predictions-tte).^2);
0103 stats.nmse = mean((model.predictions-tte).^2)/(max(tte)-min(tte)); % MSE normalised by the range of the test targets
0104 end
0105 
0106 function [lb,ub] = computeWilsonBinomialCI(k,n)
0107 % Compute the lower and upper bounds of a 95% confidence interval
0108 % for a binomial proportion using Wilson's 'score interval'
0109 %
0110 % IN
0111 %   k: scalar, number of successes
0112 %   n: scalar, number of samples
0113 %
0114 % OUT
0115 %   lb: lower bound of confidence interval
0116 %   ub: upper bound of confidence interval
0117 %
0118 % REFERENCES
0119 % Brown, Lawrence D., Cai, T. Tony, DasGupta, Anirban, 2001.
0120 %  Interval estimation for a binomial proportion. Stat. Sci. 16, 101-133.
0121 % Edwin B. Wilson, Probable Inference, the Law of Succession, and
0122 %   Statistical Inference, Journal of the American Statistical Association,
0123 %   Vol. 22, No. 158 (Jun., 1927), pp. 209-212
0124 
0125 alpha=0.05;
0126 
0127 l=spm_invNcdf(1-alpha/2,0,1); % standard normal quantile (approx. 1.96 for alpha=0.05)
0128 p=k/n;                    % sample proportion of success
0129 q=1-p;
0130 
0131 % compute terms of formula
0132 firstTerm=(k+(l^2)/2)/(n+l^2);
0133 secondTerm=((l*sqrt(n))/(n+l^2))*sqrt(p*q+((l^2)/(4*n)));
0134 
0135 % compute upper and lower bounds
0136 lb=firstTerm-secondTerm;
0137 ub=firstTerm+secondTerm;
0138 
0139 end
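
For reference, firstTerm and secondTerm above are the centre and half-width of
Wilson's score interval, (k + z^2/2)/(n + z^2) +/- (z*sqrt(n)/(n + z^2)) * sqrt(p*q + z^2/(4*n)).
A standalone sketch (illustrative values only; uses the Statistics Toolbox
norminv in place of SPM's spm_invNcdf):

 k = 80; n = 100;                      % e.g. 80 correct predictions out of 100
 z = norminv(1 - 0.05/2, 0, 1);        % approx. 1.96 for a 95% interval
 p = k/n; q = 1 - p;
 centre    = (k + z^2/2) / (n + z^2);
 halfwidth = (z*sqrt(n)/(n + z^2)) * sqrt(p*q + z^2/(4*n));
 ci = [centre - halfwidth, centre + halfwidth]   % approx. [0.711, 0.867]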

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005