Home > . > prt_compute_cv_mat.m

prt_compute_cv_mat

PURPOSE ^

Computes the cross-validation matrix for a model and performs error checking on the cross-validation settings.

SYNOPSIS ^

function [CV,ID] = prt_compute_cv_mat(PRT, in, modelid, use_nested_cv)

DESCRIPTION ^

 Computes the cross-validation matrix for a model and performs error checking on the cross-validation settings.

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function [CV,ID] = prt_compute_cv_mat(PRT, in, modelid, use_nested_cv)
% Compute the cross-validation (CV) matrix for a model, with error checking.
%
% Inputs:
%   PRT           - PRT structure. Fields read here: PRT.fs(fid).id_mat,
%                   PRT.group(:).gr_name and PRT.model(modelid).input
%                   (.fs, .cv_k, .samp_idx, .cv_mat). PRT is NOT an output,
%                   so nothing is stored back into it by this function.
%   in            - structure describing the CV. Fields read here:
%                     in.cv.type        'loso' | 'losgo' | 'lobo' | 'locbo' |
%                                       'loro' | 'custom'
%                     in.cv.k           number of folds (0 = leave-one-out,
%                                       1 = half-half); required when
%                                       use_nested_cv is true
%                     in.cv.mat_file    MAT-file holding a variable CV
%                                       ('custom' only)
%                     in.fs             feature set passed to prt_init_fs
%                     in.include_allscans (optional) keep all rows of id_mat
%                     in.class / in.t   class definition or targets ('losgo')
%                     in.ID             id matrix (nested CV only)
%   modelid       - index of the model in PRT.model
%   use_nested_cv - (optional, default false) when true, ID is taken from
%                   in.ID and k from in.cv.k
%
% Outputs:
%   CV - samples x folds matrix; 1 marks a training sample and 2 a test
%        sample in a fold (for 'losgo', rows not assigned to any class keep
%        the initial value 0)
%   ID - id matrix restricted to the samples of the model. Columns used
%        here: 1 (group) and 2 (subject) for 'loso'/'losgo', 3 for 'loro',
%        4 and 5 for 'lobo'.

% Default for the optional flag (also covers an explicitly empty input)
if nargin < 4 || isempty(use_nested_cv)
    use_nested_cv = false;
end

% Resolve the feature set index
if use_nested_cv
    fid = prt_init_fs(PRT, PRT.model(modelid).input.fs(1));
else
    fid = prt_init_fs(PRT, in.fs(1));
end

% Number of folds: for nested CV always taken from the input structure,
% otherwise a value already stored in the model has priority
if use_nested_cv
    k = in.cv.k;
elseif isfield(PRT.model(modelid).input,'cv_k')
    k = PRT.model(modelid).input.cv_k;
elseif isfield(in.cv,'k')
    k = in.cv.k;
else
    k = 0; % leave-one-out CV
end

% k==1 encodes half-half: build a 2-fold CV and keep its first fold only
if k==1
    k = 2;
    flaghh = 1;
else
    flaghh = 0;
end

% Select the id matrix
if use_nested_cv
    % user-provided id matrix
    ID = in.ID;
elseif isfield(in,'include_allscans') && in.include_allscans
    % full id matrix
    ID = PRT.fs(fid).id_mat;
else
    % only the samples selected for this model
    ID = PRT.fs(fid).id_mat(PRT.model(modelid).input.samp_idx,:);
end

switch in.cv.type
    case 'loso'
        % leave-one-subject-out
        % give each subject a unique id across groups
        % NOTE(review): the d2(g):d1(g) ranges assume the rows of each
        % group are contiguous in ID - confirm with the caller
        [gids,d1] = unique(ID(:,1),'last');
        [gids,d2] = unique(ID(:,1),'first');
        gc  = 0;                        % running number of subjects
        ns  = zeros(length(gids),1);    % number of subjects per group
        dID = ID;
        for g = 1:length(gids)
            ns(g) = length(unique(ID(d2(g):d1(g),2)));
            gidx  = ID(:,1) == gids(g);
            dID(gidx,2) = dID(gidx,2) + gc;
            gc = gc + ns(g);
        end
        % Fold label of each subject
        if k>1 % k-fold CV
            nsf = floor(gc/k);
            if length(unique(dID(:,2))) < 2*nsf
                error('prt_model:losoSelectedWithTooLargeK',...
                    'More than 50%% of data in testing set, reduce k');
            end
            dk = nsf*ones(1,k);
            dk(end) = dk(end) + mod(gc,k); % remainder goes to the last fold
            sk = local_fold_labels(dk);
        else % Leave-One-Subject-Out
            sk = 1:gc;
        end
        % Number of samples of each subject
        snums = [];
        for g = 1:length(gids)
            rows  = d2(g):d1(g);
            snums = [snums; histc(dID(rows,2),unique(dID(rows,2)))]; %#ok<AGROW>
        end
        if length(snums) == 1
            error('prt_model:losoSelectedWithOneSubject',...
                'LOSO CV selected but only one subject is included');
        end
        CV = local_build_cv(sk, snums);
        if flaghh
            CV = CV(:,1);
        end

    case 'losgo'
        % leave-one-subject-per-group-out
        % vcl(:,1): class index of each sample, vcl(:,2): subject counter
        % within that class
        vcl = zeros(size(ID,1),2);
        if isfield(in,'class')
            for ic = 1:length(in.class)
                nsg = 1;
                for ig = 1:length(in.class(ic).group)
                    gnames = {PRT.group(:).gr_name};
                    [d,ng] = ismember(in.class(ic).group(ig).gr_name,gnames); %#ok<ASGLU>
                    inds   = find(ID(:,1)==ng);
                    for is = 1:length(in.class(ic).group(ig).subj)
                        indss = find(ID(inds,2)==is);
                        vcl(inds(indss),1) = ic;
                        vcl(inds(indss),2) = nsg;
                        nsg = nsg+1;
                    end
                end
            end
            % number of subjects per class
            % NOTE(review): if some samples belong to no class, gids also
            % contains 0 and ns is then not aligned with the class indices
            % used in the loop below - confirm this cannot happen
            [gids,d1] = unique(vcl(:,1),'last');
            [gids,d2] = unique(vcl(:,1),'first');
            ns = zeros(length(gids),1);
            for ig = 1:length(gids)
                ns(ig) = length(unique(vcl(d2(ig):d1(ig),2)));
            end
        elseif isfield(in,'t')
            ntar = unique(in.t);
            nsg  = 1;
            ns   = zeros(length(ntar),1);
            for ic = 1:length(ntar)
                % match the actual target value, not its position, so that
                % targets which are not exactly 1..n are handled as well
                inds = find(in.t == ntar(ic));
                ns(ic) = length(inds);
                vcl(inds,1) = ic;
                ngi = unique(ID(inds,1));
                for ig = 1:length(ngi)
                    igi   = find(ID(inds,1)==ngi(ig));
                    indss = unique(ID(inds(igi),2));
                    for is = 1:length(indss)
                        inss = find(ID(inds(igi),2) == indss(is));
                        vcl(inds(igi(inss)),2) = nsg;
                        nsg = nsg + 1;
                    end
                end
            end
        else
            error('prt_cv:losgoMissingClassInfo',...
                'LOSGO CV requires either the ''class'' or the ''t'' field in the input structure');
        end

        sids = max(ns);
        if sids == 1
            error('prt_model:losgoSelectedWithOneSubject',...
                'LOSGO CV selected but only one subject is included');
        end
        nsf = floor(min(ns/k));
        if k==0
            CV = zeros(size(ID,1),sids);
        else
            CV = zeros(size(ID,1),k);
        end
        if k>1 && nsf==1
            disp('Performing Leave-One Subject per Group-Out')
        end
        for g = 1:length(ns)
            is = vcl(:,1)==g;
            if k>1 && nsf>1 % k-fold CV
                nsfg = floor(ns(g)/k);
                if nsfg<1
                    error('prt_model:losgoSelectedWithTooLargeK',...
                        ['Number of subjects in group ',num2str(g),' smaller than k']);
                elseif nsfg*2>ns(g)
                    error('prt_model:losgoSelectedWithTooLargeK2',...
                        ['Leaving more than 50%% of subjects in group ',num2str(g),' out']);
                end
                mns = mod(ns(g),nsfg);
                dk  = nsfg*ones(1,floor(length(unique(vcl(is,2)))/nsfg));
                if mns>0
                    dk(end) = dk(end)+mns;
                end
                sk = local_fold_labels(dk);
            else % Leave-One-Subject per Group-Out
                sk = 1:ns(g);
            end
            snums = histc(vcl(is,2),unique(vcl(is,2)));
            CV(is,1:max(sk)) = local_build_cv(sk, snums);
            nfg = length(unique(sk));
            if nfg < size(CV,2) % smaller group, fill with 'train'
                CV(is,nfg+1:size(CV,2)) = ...
                    ones(length(find(is)),length(nfg+1:size(CV,2)));
            end
        end
        % half-half: keep the first fold only. Done once after all groups
        % are filled; the first column is the same as when truncating
        % inside the loop, without shrinking and regrowing CV every time.
        if flaghh
            CV = CV(:,1);
        end

    case 'lobo'
        % leave-one-block-out - limited to one single subject for the
        % moment
        % give each block a unique id across the values of ID(:,4)
        [cids,d1] = unique(ID(:,4),'last');
        [cids,d2] = unique(ID(:,4),'first');
        gc  = 0;                        % running number of blocks
        nb  = zeros(length(cids),1);    % number of blocks per ID(:,4) value
        dID = ID;
        for c = 1:length(cids)
            nb(c) = length(unique(ID(d2(c):d1(c),5)));
            cidx  = ID(:,4) == cids(c);
            dID(cidx,5) = dID(cidx,5) + gc;
            gc = gc + nb(c);
        end
        % Fold label of each block
        if k>1 % k-fold CV
            nsb = floor(gc/k);
            if length(unique(dID(:,5))) < 2*nsb
                error('prt_model:loboSelectedWithTooLargeK',...
                    'More than 50%% of data in testing set, reduce k');
            end
            dk = nsb*ones(1,k);
            dk(end) = dk(end) + mod(gc,k); % remainder goes to the last fold
            sk = local_fold_labels(dk);
        else % Leave-One-Block-Out
            sk = 1:gc;
        end
        % Number of samples of each block
        snums = [];
        for g = 1:length(cids)
            rows  = d2(g):d1(g);
            snums = [snums; histc(dID(rows,5),unique(dID(rows,5)))]; %#ok<AGROW>
        end
        if length(snums) == 1
            % identifier kept as in previous versions so that code catching
            % it by ID keeps working; only the message text is corrected
            error('prt_model:logoSelectedWithOneSubject',...
                'LOBO CV selected but only one block is included');
        end
        CV = local_build_cv(sk, snums);
        if flaghh
            CV = CV(:,1);
        end

    case 'locbo'
        % leave-one-condition-per-block-out
        error('leave-one-condition-per-block-out not yet implemented');

    case 'loro'
        % leave-one-run-out: one fold per distinct value of ID(:,3)
        mids = unique(ID(:,3));
        CV   = zeros(size(ID,1),length(mids));
        for m = 1:length(mids)
            midx    = ID(:,3) == mids(m);
            CV(:,m) = double(midx) + 1;
        end

    case 'custom'
        if isfield(in.cv,'mat_file') && ~isempty(in.cv.mat_file)
            % Load into a structure: variables stored in the file can then
            % never overwrite local variables of this function (ID, in, ...)
            S = load(in.cv.mat_file);
            if ~isstruct(S) || ~isfield(S,'CV')
                error('No CV variable found in the mat file provided')
            end
            CV = S.CV;
            if size(CV,1) ~= size(ID,1)
                error('CV does not comprise the same number of samples as selected')
            end
            % check that each fold contains test and train data
            nfo  = size(CV,2);
            macv = max(CV,[],1);
            if length(find(macv==2)) ~= nfo % test data in all folds
                error('One (or more) fold does not contain test data')
            end
            if sum(any(CV==1,1)) ~= nfo % train data in all folds
                error('One (or more) fold does not contain train data')
            end
            if any(CV(:)>2) || any(CV(:)<0)
                error('Values larger than 2 or smaller than 0 found in CV')
            end
        elseif isfield(PRT.model(modelid).input,'cv_mat') && ...
                ~isempty(PRT.model(modelid).input.cv_mat) % custom CV specified by GUI
            CV = PRT.model(modelid).input.cv_mat;
        elseif isfield(in.cv,'k')
            % custom CV with only number of folds specified
            CV = ones(size(ID,1),in.cv.k);
        else
            % previously CV was simply left unassigned in this situation
            error('prt_cv:customCVNotSpecified',...
                'Custom CV selected but no mat file, CV matrix or number of folds provided');
        end

    otherwise
        error('prt_cv:unknownTypeSpecified',...
            'Unknown type specified for CV structure (%s)',in.cv.type);
end

end

function sk = local_fold_labels(dk)
% Expand the number of units per fold into a row vector of fold labels,
% e.g. dk = [2 3] gives sk = [1 1 2 2 2].
sk = [];
for ii = 1:length(dk)
    sk = [sk, ii*ones(1,dk(ii))]; %#ok<AGROW>
end
end

function CV = local_build_cv(sk, snums)
% Build the block-diagonal CV matrix (1 = train, 2 = test) from the fold
% label of each unit (sk) and the number of samples of each unit (snums).
nfo = length(unique(sk));
G   = cell(nfo,1);
for s = 1:nfo
    G{s} = ones(sum(snums(sk==s)),1);
end
CV = blkdiag(G{:}) + 1;
end

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005