function [CV,ID] = prt_compute_cv_mat(PRT, in, modelid, use_nested_cv)
% PRT_COMPUTE_CV_MAT Build the cross-validation (CV) matrix of a model.
%
% FORMAT [CV,ID] = prt_compute_cv_mat(PRT, in, modelid, use_nested_cv)
%
% INPUTS
%   PRT           - PRT structure. Only read here: it is passed by value
%                   and not returned, so nothing is written back to it.
%   in            - model input structure. Fields read here:
%                     in.cv.type   CV scheme: 'loso', 'losgo', 'lobo',
%                                  'locbo' (not implemented), 'loro' or
%                                  'custom'
%                     in.cv.k      number of folds. 0 (or missing, in
%                                  non-nested mode) leaves one unit out per
%                                  fold; 1 requests a single half/half split
%                                  (LOSO, LOSGO and LOBO only)
%                     in.fs(1)     feature set (non-nested mode)
%                     in.include_allscans  (optional, non-nested mode) use
%                                  every row of the feature set's id_mat
%                                  instead of the model's samp_idx rows
%                     in.class / in.t  class definitions or targets (LOSGO)
%                     in.ID        sample IDs (nested mode)
%                     in.cv.mat_file   (custom) .mat file holding a CV variable
%   modelid       - index of the model in PRT.model
%   use_nested_cv - (optional, default false) true for an inner (nested) CV
%                   loop: ID is then taken from in.ID and k from in.cv.k
%
% OUTPUTS
%   CV - one row per row of ID and one column per fold. 1 marks a training
%        sample and 2 a test sample of that fold. 0 marks a sample that is
%        not assigned to the fold (e.g. LOSGO rows that belong to no class).
%   ID - sample ID matrix the rows of CV refer to. Column 1 indexes
%        PRT.group and column 2 the subject. LOBO counts blocks in column 5
%        within each value of column 4, and LORO uses column 3.
%
% RAISES
%   prt_model:*SelectedWithOneSubject, prt_model:*SelectedWithTooLargeK*,
%   prt_cv:unknownTypeSpecified, prt_cv:customCVNotDefined and plain
%   errors for an invalid custom CV matrix.

if nargin < 4 || isempty(use_nested_cv)
    use_nested_cv = false;
end

% Feature-set index. fid is only used when use_nested_cv is false, but the
% lookup is done in both modes. prt_init_fs is defined elsewhere, and any
% other effect it may have is not visible here.
if use_nested_cv
    fid = prt_init_fs(PRT, PRT.model(modelid).input.fs(1));
else
    fid = prt_init_fs(PRT, in.fs(1));
end

% Number of folds. In non-nested mode a value stored in the model takes
% precedence over in.cv.k.
if use_nested_cv
    k = in.cv.k;
elseif isfield(PRT.model(modelid).input, 'cv_k')
    k = PRT.model(modelid).input.cv_k;
elseif isfield(in.cv, 'k')
    k = in.cv.k;
else
    k = 0;
end

% k == 1 requests a single half/half split: build a 2-fold matrix and keep
% only its first fold (LOSO, LOSGO and LOBO).
half_split = (k == 1);
if half_split
    k = 2;
end

% Sample IDs covered by the CV matrix.
if use_nested_cv
    ID = in.ID;
elseif isfield(in, 'include_allscans') && in.include_allscans
    ID = PRT.fs(fid).id_mat;
else
    ID = PRT.fs(fid).id_mat(PRT.model(modelid).input.samp_idx, :);
end

switch in.cv.type
    case 'loso'
        % Leave subjects (column 2) out, counted within each group (column 1).
        CV = leave_units_out(ID, 1, 2, k, 'subjects', ...
            'prt_model:losoSelectedWithOneSubject', ...
            'LOSO CV selected but only one subject is included', ...
            'prt_model:losoSelectedWithTooLargeK');
        if half_split
            CV = CV(:,1);
        end

    case 'losgo'
        CV = leave_subject_per_group_out(PRT, in, ID, k);
        if half_split
            CV = CV(:,1);
        end

    case 'lobo'
        % Leave blocks (column 5) out, counted within each value of column 4.
        CV = leave_units_out(ID, 4, 5, k, 'blocks', ...
            'prt_model:logoSelectedWithOneSubject', ...
            'LOBO CV selected but only one block is included', ...
            'prt_model:loboSelectedWithTooLargeK');
        if half_split
            CV = CV(:,1);
        end

    case 'locbo'
        error('leave-one-condition-per-block-out not yet implemented');

    case 'loro'
        % One fold per distinct value of ID column 3: rows with that value
        % are tested (2), all other rows are trained on (1).
        mids = unique(ID(:,3));
        CV = zeros(size(ID,1), numel(mids));
        for m = 1:numel(mids)
            CV(:,m) = double(ID(:,3) == mids(m)) + 1;
        end

    case 'custom'
        CV = custom_cv(PRT, in, modelid, size(ID,1));

    otherwise
        error('prt_cv:unknownTypeSpecified', ...
            'Unknown type specified for CV structure (%s)', in.cv.type);
end

end


function CV = leave_units_out(ID, outer_col, unit_col, k, unit_name, ...
    one_id, one_msg, k_id)
% Leave-units-out CV matrix (LOSO / LOBO).
% A unit is a distinct value of ID(:,unit_col) within one value of
% ID(:,outer_col). For k > 1, units are split into k folds of
% floor(nunits/k) units, the remainder going to the last fold. Otherwise
% every unit is its own fold.
% NOTE: only the first..last row range of each outer value is scanned, so
% rows sharing an outer value are assumed to be contiguous.

[outer_ids, first] = unique(ID(:,outer_col), 'first');
[unused, last] = unique(ID(:,outer_col), 'last'); %#ok<ASGLU>
nouter = numel(outer_ids);

% Shift unit labels so that units of different outer values get distinct
% labels. NOTE(review): this assumes unit labels within each outer value
% are 1..n. Otherwise shifted labels can collide, which only affects the
% 50% check below.
dID = ID;
nunits = 0;
for g = 1:nouter
    sel = ID(:,outer_col) == outer_ids(g);
    dID(sel,unit_col) = dID(sel,unit_col) + nunits;
    nunits = nunits + numel(unique(ID(first(g):last(g), unit_col)));
end

% Number of rows (samples) belonging to each unit, in unit order.
counts = cell(nouter, 1);
for g = 1:nouter
    vals = dID(first(g):last(g), unit_col);
    counts{g} = histc(vals, unique(vals));
end
snums = vertcat(counts{:});

if numel(snums) == 1
    error(one_id, one_msg);
end

if k > 1
    per_fold = floor(nunits/k);
    if numel(unique(dID(:,unit_col))) < 2*per_fold
        error(k_id, 'More than 50%% of data in testing set, reduce k');
    end
    if per_fold < 1
        % Without this check every unit would land in fold k, producing
        % an empty CV matrix.
        error(k_id, 'k (%d) is larger than the number of %s (%d), reduce k', ...
            k, unit_name, nunits);
    end
    sizes = per_fold * ones(1,k);
    sizes(end) = sizes(end) + mod(nunits, k);
    sk = expand_sizes(sizes);
else
    sk = 1:nunits;
end

CV = blocks_to_cv(snums, sk);

end


function CV = leave_subject_per_group_out(PRT, in, ID, k)
% Leave-subject(s)-per-group-out CV matrix (LOSGO).
% Subjects are split into folds separately within each class, so every
% fold tests subjects from every class. Rows that belong to no class stay 0.

nsamp = size(ID,1);

% vcl(:,1): class index of each row (0 = no class).
% vcl(:,2): subject label, distinct within a class.
vcl = zeros(nsamp, 2);
if isfield(in, 'class')
    nclass = numel(in.class);
    gnames = {PRT.group(:).gr_name};
    for ic = 1:nclass
        nsg = 1;
        for ig = 1:numel(in.class(ic).group)
            [found, ng] = ismember(in.class(ic).group(ig).gr_name, gnames); %#ok<ASGLU>
            grows = find(ID(:,1) == ng);
            for is = 1:numel(in.class(ic).group(ig).subj)
                rows = grows(ID(grows,2) == is);
                vcl(rows,1) = ic;
                vcl(rows,2) = nsg;
                nsg = nsg + 1;
            end
        end
    end
elseif isfield(in, 't')
    % Each distinct target value is treated as a class.
    targets = unique(in.t);
    nclass = numel(targets);
    nsg = 1;
    for ic = 1:nclass
        inds = find(in.t == targets(ic));
        vcl(inds,1) = ic;
        grp = unique(ID(inds,1));
        for ig = 1:numel(grp)
            igi = find(ID(inds,1) == grp(ig));
            subj = unique(ID(inds(igi),2));
            for is = 1:numel(subj)
                vcl(inds(igi(ID(inds(igi),2) == subj(is))), 2) = nsg;
                nsg = nsg + 1;
            end
        end
    end
else
    error('prt_model:losgoSelectedWithoutClasses', ...
        'LOSGO CV selected but neither classes nor targets are defined');
end

% Number of distinct subjects in each class.
ns = zeros(nclass, 1);
for ic = 1:nclass
    ns(ic) = numel(unique(vcl(vcl(:,1) == ic, 2)));
end

sids = max(ns);
if isempty(sids) || sids < 1
    error('prt_model:losgoSelectedWithNoSubject', ...
        'LOSGO CV selected but no subject is included');
elseif sids == 1
    error('prt_model:losgoSelectedWithOneSubject',...
        'LOSGO CV selected but only one subject is included');
end

nonempty = ns > 0;
nsf = floor(min(ns(nonempty) / k));   % Inf when k == 0
split_folds = k > 1 && nsf > 1;
if k > 1 && ~split_folds
    disp('Performing Leave-One Subject per Group-Out')
end

% Fix the width upfront: k folds when splitting, otherwise one fold per
% subject of the largest class. Classes with fewer folds are trained on (1)
% in the remaining columns.
if split_folds
    CV = zeros(nsamp, k);
else
    CV = zeros(nsamp, sids);
end

for g = find(nonempty(:))'
    is = vcl(:,1) == g;
    subj = vcl(is,2);
    if split_folds
        nsfg = floor(ns(g)/k);
        if nsfg < 1
            error('prt_model:losgoSelectedWithTooLargeK',...
                ['Number of subjects in group ',num2str(g),' smaller than k']);
        elseif nsfg*2 > ns(g)
            error('prt_model:losgoSelectedWithTooLargeK2',...
                ['Leaving more than 50%% of subjects in group ',num2str(g),' out']);
        end
        % Exactly k folds; the remainder goes to the last fold.
        sizes = nsfg * ones(1,k);
        sizes(end) = sizes(end) + mod(ns(g), k);
        sk = expand_sizes(sizes);
    else
        sk = 1:ns(g);
    end
    CVg = blocks_to_cv(histc(subj, unique(subj)), sk);
    nfold = size(CVg, 2);
    CV(is, 1:nfold) = CVg;
    CV(is, nfold+1:end) = 1;
end

end


function CV = custom_cv(PRT, in, modelid, nsamp)
% User-defined CV matrix: from in.cv.mat_file, else from the model's
% cv_mat, else an all-training matrix with in.cv.k columns.

if isfield(in.cv,'mat_file') && ~isempty(in.cv.mat_file)
    % Load into a struct so variables in the file cannot overwrite locals.
    loaded = load(in.cv.mat_file);
    if ~isfield(loaded, 'CV')
        error('No CV variable found in the mat file provided')
    end
    CV = loaded.CV;
    check_custom_cv(CV, nsamp);
elseif isfield(PRT.model(modelid).input,'cv_mat') && ...
        ~isempty(PRT.model(modelid).input.cv_mat)
    CV = PRT.model(modelid).input.cv_mat;
elseif isfield(in.cv,'k')
    CV = ones(nsamp, in.cv.k);
else
    error('prt_cv:customCVNotDefined', ...
        'Custom CV selected but no CV matrix or number of folds was provided');
end

end


function check_custom_cv(CV, nsamp)
% Validate a user-supplied CV matrix: one row per sample, every fold
% (column) has test (2) and training (1) samples, values within [0, 2].

if size(CV,1) ~= nsamp
    error('CV does not comprise the same number of samples as selected')
end
nfo = size(CV,2);
% max along dim 1 explicitly, so a single-row CV is still handled per column.
if sum(max(CV, [], 1) == 2) ~= nfo
    error('One (or more) fold does not contain test data')
end
if sum(any(CV == 1, 1)) ~= nfo
    error('One (or more) fold does not contain train data')
end
if any(CV(:) > 2) || any(CV(:) < 0)
    error('Values larger than 2 or smaller than 0 found in CV')
end

end


function labels = expand_sizes(sizes)
% Fold label of each unit: sizes(i) consecutive units get label i.
labels = zeros(1, sum(sizes));
pos = 0;
for ii = 1:numel(sizes)
    labels(pos+1:pos+sizes(ii)) = ii;
    pos = pos + sizes(ii);
end
end


function CV = blocks_to_cv(snums, sk)
% CV matrix from per-unit row counts (snums) and per-unit fold labels (sk,
% labels 1..n). Fold s tests (2) the rows of its units and trains (1) on
% all other rows. Rows are assumed ordered unit by unit.
nfold = numel(unique(sk));
G = cell(nfold, 1);
for s = 1:nfold
    G{s} = ones(sum(snums(sk == s)), 1);
end
CV = blkdiag(G{:}) + 1;
end