function [outfile]=prt_cv_model(PRT,in)
% PRT_CV_MODEL Run cross-validation for one model and save the results.
%
% FORMAT outfile = prt_cv_model(PRT, in)
%
% Inputs:
%   PRT - PRT structure. The model index is obtained from
%         prt_init_model(PRT, in); that model is trained and tested once per
%         column (fold) of PRT.model(mid).input.cv_mat.
%   in  - input structure. in.fname is the path of the PRT.mat file: the
%         part before 'PRT.mat' is the directory kernel files are read from
%         and where PRT is saved. 'in' is also forwarded to prt_init_model.
%
% Output:
%   outfile - path (without extension) of the saved PRT file.
%
% cv_mat encoding used here: 1 = training sample, 2 = test sample for that
% fold; any other value leaves the sample out of the fold.
%
% Results per fold f go to PRT.model(mid).output.fold(f): targets,
% predictions, stats and every other field returned by the machine.
% Statistics over all folds concatenated go to PRT.model(mid).output.stats.
% The updated PRT is written to disk (it is not returned).

% Directory of PRT.mat (keeps the trailing separator, if any)
prt_dir = char(regexprep(in.fname,'PRT\.mat$', ''));

mid = prt_init_model(PRT, in);

CV         = PRT.model(mid).input.cv_mat;
n_folds    = size(CV,2);
fs_list    = PRT.model(mid).input.fs;
n_Phi      = length(fs_list);
samp_idx   = PRT.model(mid).input.samp_idx;
use_kernel = PRT.model(mid).input.use_kernel;

if n_Phi == 0
    error('prt_cv_model:noFeatureSets', ...
          'Model does not reference any feature set');
end

% Targets: all scans if requested, otherwise the model's targets
if isfield(PRT.model(mid).input,'include_allscans') && ...
        PRT.model(mid).input.include_allscans
    t = PRT.model(mid).input.targ_allscans;
else
    t = PRT.model(mid).input.targets;
end

% ---------------------------------------------------------------------
% Load the data of every feature set, restricted to the selected samples
% ---------------------------------------------------------------------
disp('Loading data files.....>>');
Phi_all = cell(1,n_Phi);
for i = 1:n_Phi
    fid = prt_init_fs(PRT, fs_list(i));

    % Sample IDs are taken from the first feature set
    if i == 1
        ID = PRT.fs(fid).id_mat(samp_idx,:);
    end

    if use_kernel
        % Load into a struct so that variables stored in the kernel file
        % cannot overwrite this function's locals, and a 'Phi' left over
        % from a previous file can never be reused silently.
        kdata = load(fullfile(prt_dir, PRT.fs(fid).k_file), 'Phi');
        Phi_all{i} = kdata.Phi(samp_idx,samp_idx);
    else
        error('training with features not implemented yet');
    end
end

% Loop-invariant settings
is_krr = ~isempty(strfind(PRT.model(mid).input.machine.function,'krr'));
ops    = PRT.model(mid).input.operations(PRT.model(mid).input.operations ~= 0);

% ---------------------------------------------------------------------
% Cross-validation loop
% ---------------------------------------------------------------------
PRT.model(mid).output.fold = struct();
for f = 1:n_folds
    disp ([' > running CV fold: ',num2str(f),' of ',num2str(n_folds),' ...'])

    tr_idx = CV(:,f) == 1;
    te_idx = CV(:,f) == 2;

    [Phi_tr, Phi_te, Phi_tt] = split_data(Phi_all, tr_idx, te_idx, use_kernel);

    % Fresh struct every fold so fields added by operations (or by the
    % previous fold) cannot leak into this one.
    cvdata = struct();
    cvdata.train = Phi_tr;
    cvdata.test  = Phi_te;
    if use_kernel
        cvdata.testcov = Phi_tt;
    end

    cvdata.tr_targets = t(tr_idx,:);
    if is_krr
        % KRR is trained on mean-centred targets; the training mean is
        % added back to the outputs after training (see below).
        mm = mean(t(tr_idx,:));
        cvdata.tr_targets = cvdata.tr_targets - mm;
    end
    cvdata.te_targets = t(te_idx,:);
    cvdata.tr_id      = ID(tr_idx,:);
    cvdata.te_id      = ID(te_idx,:);
    cvdata.use_kernel = use_kernel;
    cvdata.pred_type  = PRT.model(mid).input.type;

    % Parameters for the training and test samples of this fold
    cvdata.tr_param = prt_cv_opt_param(PRT, ID(tr_idx,:), mid);
    cvdata.te_param = prt_cv_opt_param(PRT, ID(te_idx,:), mid);

    % Apply the configured (non-zero) data operations in order
    for o = 1:length(ops)
        cvdata = prt_apply_operation(PRT, cvdata, ops(o));
    end

    % Train and test the machine
    model = prt_machine(cvdata, PRT.model(mid).input.machine);

    if ~isfield(model,'predictions')
        error('prt_cv_model:machineDoesNotGivePredictions', ...
              'Machine did not produce a predictions field');
    end

    % Prefer targets returned by the machine over the ones we passed in
    if isfield(model,'te_targets')
        true_te_targets = model.te_targets(:);
    else
        true_te_targets = cvdata.te_targets(:);
    end
    if isfield(model,'tr_targets')
        tr_targets = model.tr_targets(:);
    elseif is_krr
        % undo the centring applied above
        tr_targets = cvdata.tr_targets(:) + mm;
    else
        tr_targets = cvdata.tr_targets(:);
    end

    if is_krr
        % Put KRR outputs back on the original (uncentred) target scale
        model.predictions = model.predictions + mm;
        if isfield(model,'func_val')
            model.func_val = model.func_val + mm;
        end
    end

    stats = prt_stats(model, true_te_targets, tr_targets);

    PRT.model(mid).output.fold(f).targets     = true_te_targets;
    PRT.model(mid).output.fold(f).predictions = model.predictions(:);
    PRT.model(mid).output.fold(f).stats       = stats;

    % Copy every other field returned by the machine (including func_val
    % when present); predictions were already stored as a column above.
    flds = fieldnames(model);
    for k = 1:length(flds)
        fldnm = flds{k};
        if ~strcmpi(fldnm,'predictions')
            PRT.model(mid).output.fold(f).(fldnm) = model.(fldnm);
        end
    end
end

% ---------------------------------------------------------------------
% Statistics over all folds concatenated
% ---------------------------------------------------------------------
all_targets = vertcat(PRT.model(mid).output.fold(:).targets);
% NOTE(review): relies on the machine returning a 'type' field, which the
% field-copy loop above stores in each fold.
m.type        = PRT.model(mid).output.fold(1).type;
m.predictions = vertcat(PRT.model(mid).output.fold(:).predictions);

stats = prt_stats(m, all_targets(:), all_targets(:));

PRT.model(mid).output.stats = stats;

% ---------------------------------------------------------------------
% Save PRT back to disk
% ---------------------------------------------------------------------
outfile = [prt_dir, 'PRT'];
disp('Updating PRT.mat.......>>')
if spm_matlab_version_chk('7') >= 0
    save(outfile,'-V6','PRT');
else
    save(outfile,'PRT');
end
end
0184
0185
0186
0187
0188
function [Phi_tr, Phi_te, Phi_tt] = split_data(Phi_all, tr_idx, te_idx, usebf)
% SPLIT_DATA Split every data matrix into the parts needed for one CV fold.
%
% FORMAT [Phi_tr, Phi_te, Phi_tt] = split_data(Phi_all, tr_idx, te_idx, usebf)
%
% Inputs:
%   Phi_all - cell array of data matrices; rows are indexed by sample
%   tr_idx  - logical/index vector selecting the training samples
%   te_idx  - logical/index vector selecting the test samples
%   usebf   - true when the matrices are kernels (both rows and columns
%             indexed by sample); false when only rows are indexed by
%             sample and all columns are kept
%
% Outputs (cell arrays with one entry per element of Phi_all):
%   Phi_tr - kernel: train x train block;  otherwise: training rows
%   Phi_te - kernel: test x train block;   otherwise: test rows
%   Phi_tt - kernel: test x test block;    otherwise: []

n_mat  = numel(Phi_all);
Phi_tr = cell(1,n_mat);
Phi_te = cell(1,n_mat);
Phi_tt = cell(1,n_mat);

for i = 1:n_mat
    if usebf
        Phi_tr{i} = Phi_all{i}(tr_idx, tr_idx);
        Phi_te{i} = Phi_all{i}(te_idx, tr_idx);
        Phi_tt{i} = Phi_all{i}(te_idx, te_idx);
    else
        % Keep every column of this matrix (each matrix may have its own width)
        Phi_tr{i} = Phi_all{i}(tr_idx, :);
        Phi_te{i} = Phi_all{i}(te_idx, :);
        Phi_tt{i} = [];
    end
end
end
0226