function [outfile]=prt_cv_model(PRT,in)
% PRT_CV_MODEL Run the cross-validation loop for one model and save PRT.mat.
%
% FORMAT outfile = prt_cv_model(PRT, in)
%
% Inputs:
%   PRT - PRT structure. The model selected by prt_init_model(PRT, in) is
%         expected to provide input.cv_mat, input.type, input.operations and
%         either input.targets / input.covar or, when input.include_allscans
%         is set, input.targ_allscans / input.cov_allscans. Optional fields:
%         input.use_nested_cv, input.machine.args.
%   in  - input structure; in.fname is the path of the PRT.mat file (its
%         directory is where the updated PRT.mat is written). in is also
%         passed unchanged to prt_init_model.
%
% Output:
%   outfile - full path of the updated PRT.mat. Empty ('') when modelling is
%             abandoned because a class present in a test fold is absent from
%             that fold's training set; PRT.mat is NOT written in that case.
%
% For each column f of the cross-validation matrix, the model is trained and
% tested by prt_cv_fold (optionally preceded by prt_nested_cv). The fold's test
% targets, predictions (as a column), prt_stats statistics and every other
% field returned by prt_cv_fold are stored in PRT.model(mid).output.fold(f).
% Statistics over the predictions concatenated across all folds are stored in
% PRT.model(mid).output.stats, and PRT is saved to disk.

% Assigned up front so the early "abandon" return below does not fail with
% "Output argument not assigned" in callers that request outfile.
outfile = '';

% Directory holding PRT.mat: strip the trailing file name. The dot is escaped
% and the match anchored so only a trailing 'PRT.mat' is removed.
prt_dir = char(regexprep(in.fname,'PRT\.mat$', ''));

% Index of the model to run
mid = prt_init_model(PRT, in);

% Read-only model settings (input.machine.args may be overwritten inside the
% fold loop, so it is always read from PRT, never from this copy).
minput  = PRT.model(mid).input;
CV      = minput.cv_mat;
n_folds = size(CV,2);
if n_folds == 0
    error('prt_cv_model:noFolds', ...
        'The cross-validation matrix (input.cv_mat) defines no folds.');
end

% Targets and (optionally) covariates. Covariates are only loaded when
% operation id 5 is selected (presumably covariate regression -- confirm).
use_allscans = isfield(minput,'include_allscans') && minput.include_allscans;
use_covar    = any(ismember(minput.operations,5));
covs = [];
if use_allscans
    t = minput.targ_allscans;
    if use_covar
        covs = minput.cov_allscans;
    end
else
    t = minput.targets;
    if use_covar
        covs = minput.covar;
    end
end

% Number of classes is taken as the largest class label (classification only)
is_classif = strcmpi(minput.type,'classification');
if is_classif
    nc = max(unique(t));
else
    nc = [];
end

% Same truth-value rule as "if minput.use_nested_cv" (non-empty, all non-zero)
use_nested = isfield(minput,'use_nested_cv') && ...
    ~isempty(minput.use_nested_cv) && all(minput.use_nested_cv(:));

% Kernel/data for the model and the corresponding sample IDs
[Phi_all,ID] = prt_getKernelModel(PRT,prt_dir,mid);

% Fold-independent data handed to prt_nested_cv / prt_cv_fold
fdata.nc      = nc;
fdata.ID      = ID;
fdata.mid     = mid;
fdata.Phi_all = Phi_all;
fdata.t       = t;
if ~isempty(covs)
    fdata.cov = covs;
end

PRT.model(mid).output = struct();
PRT.model(mid).output.fold = struct();
for f = 1:n_folds
    disp ([' > running CV fold: ',num2str(f),' of ',num2str(n_folds),' ...'])

    fdata.CV = CV(:,f);

    % Optional nested CV: its opt_param becomes the machine arguments used
    % by prt_cv_fold for this fold (and remains in PRT afterwards).
    if use_nested
        out = prt_nested_cv(PRT, fdata);
        PRT.model(mid).output.fold(f).param_effect = out;
        PRT.model(mid).input.machine.args = out.opt_param;
    end

    [model, targets] = prt_cv_fold(PRT,fdata);

    % Every test class must have been seen during training; otherwise stop
    % without saving (outfile stays empty).
    if is_classif && ~all(ismember(unique(targets.test),unique(targets.train)))
        beep
        disp('At least one class is in the test set but not in the training set')
        disp('Abandoning modelling, please correct class selection/cross-validation')
        return
    end

    stats = prt_stats(model, targets.test, nc);

    PRT.model(mid).output.fold(f).targets     = targets.test;
    PRT.model(mid).output.fold(f).predictions = model.predictions(:);
    PRT.model(mid).output.fold(f).stats       = stats;

    % Copy all other fields returned by prt_cv_fold (predictions were
    % already stored above as a column vector).
    flds = fieldnames(model);
    flds = flds(~strcmpi(flds,'predictions'));
    for k = 1:numel(flds)
        PRT.model(mid).output.fold(f).(flds{k}) = model.(flds{k});
    end
end

% Overall statistics on predictions concatenated across folds
ttt = vertcat(PRT.model(mid).output.fold(:).targets);
m.type = PRT.model(mid).output.fold(1).type;
m.predictions = vertcat(PRT.model(mid).output.fold(:).predictions);

stats = prt_stats(m,ttt(:),nc);

PRT.model(mid).output.stats=stats;

% Save PRT (fullfile avoids a doubled separator, and avoids writing to the
% filesystem root when in.fname has no directory part).
outfile = fullfile(prt_dir,'PRT.mat');
disp('Updating PRT.mat.......>>')
if spm_check_version('MATLAB','7') < 0
    % MATLAB versions older than 7: save in v6 MAT-file format
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
end