function [outfile]=prt_cv_model(PRT,in)
% PRT_CV_MODEL  Run cross-validation for one model and save the results.
%
% FORMAT outfile = prt_cv_model(PRT, in)
%
% Inputs:
%   PRT - PRT data structure. The model entry whose index is returned by
%         prt_init_model must provide input.cv_mat, input.fs,
%         input.samp_idx, input.use_kernel and input.targets (or
%         input.targ_allscans when input.include_allscans is true).
%   in  - input structure passed on to prt_init_model. in.fname is the
%         path of the PRT.mat file; its directory is used to locate the
%         kernel files and to write the updated PRT structure. Any other
%         fields are consumed by prt_init_model (see that function).
%
% Output:
%   outfile - path of the saved PRT file, without the .mat extension.
%
% For every column f of input.cv_mat, the precomputed kernels of all
% feature sets are handed to prt_cv_fold. Per-fold statistics come from
% prt_stats. Every field returned by the machine is stored in
% PRT.model(mid).output.fold(f). Predictions and test targets are then
% pooled across folds, and summary statistics are stored in
% PRT.model(mid).output.stats. The updated PRT is saved back next to
% in.fname.
%
% Only kernel-based training is supported (input.use_kernel true);
% otherwise an error is raised.

% Directory holding PRT.mat, keeping any trailing separator. The dot is
% escaped and the match anchored, so only the trailing file name is removed.
prt_dir = char(regexprep(in.fname,'PRT\.mat$', ''));

% Index of the model to run, as returned by prt_init_model
mid = prt_init_model(PRT, in);

model_in = PRT.model(mid).input;
CV       = model_in.cv_mat;          % one column per CV fold
n_folds  = size(CV,2);
n_Phi    = length(model_in.fs);
samp_idx = model_in.samp_idx;

% Targets: either all scans or the default target vector
if isfield(model_in,'include_allscans') && model_in.include_allscans
    t = model_in.targ_allscans;
else
    t = model_in.targets;
end

% Load the kernel of each feature set, restricted to the selected samples
disp('Loading data files.....>>');
Phi_all = cell(1,n_Phi);
for i = 1:n_Phi
    fid = prt_init_fs(PRT, model_in.fs(i));

    % Sample IDs are taken from the first feature set only
    if i == 1
        ID = PRT.fs(fid).id_mat(samp_idx,:);
    end

    if model_in.use_kernel
        k_path = fullfile(prt_dir, PRT.fs(fid).k_file);
        % Load only 'Phi' into a struct. The file then cannot overwrite local
        % variables, and a missing 'Phi' cannot silently reuse the kernel
        % loaded in a previous iteration.
        S = load(k_path, 'Phi');
        if ~isfield(S, 'Phi')
            error('prt_cv_model:missingKernel', ...
                'Kernel file %s does not contain a variable named Phi', k_path);
        end
        Phi_all{i} = S.Phi(samp_idx,samp_idx);
    else
        error('training with features not implemented yet');
    end
end

% Data shared by every fold; only the CV column changes inside the loop
fdata.ID      = ID;
fdata.mid     = mid;
fdata.CV      = [];                  % filled in per fold
fdata.Phi_all = Phi_all;
fdata.t       = t;

PRT.model(mid).output.fold = struct();
for f = 1:n_folds
    disp ([' > running CV fold: ',num2str(f),' of ',num2str(n_folds),' ...'])

    fdata.CV = CV(:,f);

    [model, targets] = prt_cv_fold(PRT,fdata);

    stats = prt_stats(model, targets.test, targets.train);

    PRT.model(mid).output.fold(f).targets     = targets.test;
    PRT.model(mid).output.fold(f).predictions = model.predictions(:);
    PRT.model(mid).output.fold(f).stats       = stats;

    % Copy every other machine output (including 'type', read below).
    % Predictions were already stored above as a column vector.
    flds = fieldnames(model);
    for fld = 1:numel(flds)
        fldnm = flds{fld};
        if ~strcmpi(fldnm,'predictions')
            PRT.model(mid).output.fold(f).(fldnm) = model.(fldnm);
        end
    end
end

% Pool test targets and predictions across folds for overall statistics.
% The pooled test targets are also passed in the training-target slot.
t_all         = vertcat(PRT.model(mid).output.fold(:).targets);
m.type        = PRT.model(mid).output.fold(1).type;
m.predictions = vertcat(PRT.model(mid).output.fold(:).predictions);

stats = prt_stats(m,t_all(:),t_all(:));
PRT.model(mid).output.stats = stats;

% Save the updated PRT next to the original PRT.mat
outfile = [prt_dir, 'PRT'];
disp('Updating PRT.mat.......>>')
if spm_matlab_version_chk('7') >= 0
    % presumably true on MATLAB 7 or later: force the v6 MAT-file format
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
end
0121