function out = prt_apply_operation(PRT, in, opid)
% PRT_APPLY_OPERATION  Apply a data operation to every modality in IN.
%
% FORMAT out = prt_apply_operation(PRT, in, opid)
%
% Inputs:
%   PRT  - not referenced by this function (kept for interface compatibility)
%   in   - structure with at least the fields
%            .train      - cell array, one entry per modality: a kernel
%                          matrix when .use_kernel is true, otherwise a
%                          samples x features data matrix
%            .tr_id      - ID matrix with one row per training sample
%            .use_kernel - logical flag selecting the kernel code paths
%          and optionally .test, .testcov (kernel case), .te_id,
%          .tr_targets, .te_targets, .pred_type, .tr_cov, .te_cov
%   opid - operation to apply:
%            1 - average consecutive samples sharing ID(:,1:5)
%            2 - within each consecutive run of equal ID(:,1:2), average
%                the samples sharing the same ID(:,4)
%            3 - mean-centre (prt_centre_kernel for kernels; for features,
%                subtract the training mean from training and test data)
%            4 - normalise (prt_normalise_kernel on the joint train/test
%                kernel; for features, divide each row by its 2-norm)
%            5 - regress out covariates with prt_remove_confounds
%                (kernel methods only)
%
% Output:
%   out  - copy of IN with the data replaced by the transformed data; for
%          operations 1 and 2 the ids and targets are averaged as well
%
% Errors:
%   prt_apply_operation:UnknownOperationSpecified - OPID is not 1..5
%   prt_apply_operation:NoCovariates              - OPID 5 without .tr_cov
%   prt_apply_operation:GLMnonKernel              - OPID 5 on feature data

% Validate up front so an invalid id is reported even if .train is empty
if ~isscalar(opid) || ~ismember(opid, 1:5)
    error('prt_apply_operation:UnknownOperationSpecified',...
        'Unknown operation requested');
end

% Fields not touched below (ids, targets, ...) are inherited unchanged
out = in;

for d = 1:length(in.train)
    switch opid
        case 1
            out = apply_averaging(in, out, d, @compute_tc_mat);

        case 2
            out = apply_averaging(in, out, d, @compute_sa_mat);

        case 3
            if in.use_kernel
                if isfield(in,'test')
                    [out.train{d}, out.test{d}, out.testcov{d}] = ...
                        prt_centre_kernel(in.train{d},in.test{d},in.testcov{d});
                else
                    out.train{d} = prt_centre_kernel(in.train{d});
                end
            else
                % Average along dimension 1 explicitly: for a single
                % training sample mean() would otherwise average across
                % features instead of across samples.
                m = mean(in.train{d},1);
                out.train{d} = bsxfun(@minus, in.train{d}, m);
                if isfield(in,'test')
                    % Test data is centred with the TRAINING mean
                    out.test{d} = bsxfun(@minus, in.test{d}, m);
                end
            end

        case 4
            if in.use_kernel
                if isfield(in,'test')
                    % Normalise the joint [train, test'; test, testcov]
                    % kernel, then split it back into its blocks
                    Phi = prt_normalise_kernel([in.train{d}, in.test{d}'; ...
                                                in.test{d}, in.testcov{d}]);
                    [out.train{d}, out.test{d}, out.testcov{d}] = ...
                        split_joint_kernel(Phi, size(in.train{d},1), ...
                                           size(in.test{d},1));
                else
                    Phi = prt_normalise_kernel(in.train{d});
                    tr = 1:size(in.train{d},1);
                    out.train{d} = Phi(tr,tr);
                end
            else
                out.train{d} = normalise_rows(in.train{d});
                if isfield(in,'test')
                    out.test{d} = normalise_rows(in.test{d});
                end
            end

        case 5
            if ~isfield(in,'tr_cov')
                error('prt_apply_operation:NoCovariates',...
                    'No covariates found to perform requested GLM');
            end
            % Check before building the joint kernel so feature data gets
            % the intended error instead of a concatenation failure
            if ~in.use_kernel
                error('prt_apply_operation:GLMnonKernel',...
                    'GLM not implemented for non-kernel methods');
            end
            if isfield(in,'test')
                Phi = [in.train{d}, in.test{d}'; in.test{d}, in.testcov{d}];
                Phi = prt_remove_confounds(Phi,[in.tr_cov;in.te_cov]);
                [out.train{d}, out.test{d}, out.testcov{d}] = ...
                    split_joint_kernel(Phi, size(in.train{d},1), ...
                                       size(in.test{d},1));
            else
                out.train{d} = prt_remove_confounds(in.train{d},in.tr_cov);
            end
    end
end

end

function out = apply_averaging(in, out, d, matfun)
% Average modality D of the training (and, if present, test) set, together
% with the corresponding ids and targets, using the averaging matrices that
% MATFUN builds from the ID matrices (rows of the result = averaged samples).
Ptr = matfun(in.tr_id);
if in.use_kernel
    out.train{d} = Ptr*in.train{d}*Ptr';
else
    out.train{d} = Ptr*in.train{d};
end
out.tr_id = round(Ptr*in.tr_id);
if isfield(in,'tr_targets')
    out.tr_targets = average_targets(Ptr, in.tr_targets, in.pred_type);
end

if isfield(in,'test')
    Pte = matfun(in.te_id);
    if in.use_kernel
        % Test-vs-train kernel: rows follow the test samples, columns the
        % training samples, so each side uses its own averaging matrix
        out.test{d} = Pte*in.test{d}*Ptr';
        out.testcov{d} = Pte*in.testcov{d}*Pte';
    else
        out.test{d} = Pte*in.test{d};
    end
    out.te_id = round(Pte*in.te_id);
    if isfield(in,'te_targets')
        out.te_targets = average_targets(Pte, in.te_targets, in.pred_type);
    end
end
end

function t = average_targets(P, targets, pred_type)
% Average TARGETS with P; for classification the averaged labels are
% rounded back to integer values.
t = P*targets;
if strcmpi(pred_type,'classification')
    t = round(t);
end
end

function Y = normalise_rows(X)
% Divide each row of X by its Euclidean norm (an all-zero row gives NaNs).
Y = zeros(size(X));
for r = 1:size(X,1)
    Y(r,:) = X(r,:) / norm(X(r,:));
end
end

function [Ktr, Kte, Kcov] = split_joint_kernel(Phi, ntr, nte)
% Split a joint [train, test'; test, testcov] kernel into its three blocks.
tr = 1:ntr;
te = ntr + (1:nte);
Ktr  = Phi(tr,tr);
Kte  = Phi(te,tr);
Kcov = Phi(te,te);
end
0257
0258
0259
0260
0261
function P = compute_tc_mat(ID)
% COMPUTE_TC_MAT  Averaging matrix over consecutive runs of identical IDs.
%
% P = compute_tc_mat(ID) returns an nRuns x size(ID,1) matrix. Row k holds
% 1/len_k in the columns belonging to the k-th run of consecutive rows of
% ID that share the same values in columns 1:5, and 0 elsewhere, so P*X
% averages the rows of X within each run. Runs are defined by adjacency
% only: a value combination that reappears later starts a new run.
% An empty ID yields a 0x0 matrix.
n = size(ID,1);
if n == 0
    P = zeros(0,0);
    return
end

% A row starts a new run when any of its first five ID columns differs
% from the previous row (NaN never compares equal, so it always splits)
starts = [true; any(ID(2:end,1:5) ~= ID(1:end-1,1:5), 2)];
run = cumsum(starts);
counts = accumarray(run, 1);

P = zeros(numel(counts), n);
P(sub2ind(size(P), run.', 1:n)) = 1 ./ counts(run).';
end
0289
function P = compute_sa_mat(ID)
% COMPUTE_SA_MAT  Averaging matrix over ID(:,4) values within runs of ID(:,1:2).
%
% Consecutive rows of ID that share ID(:,1:2) form a run (a combination
% that reappears later starts a new run). Within each run, P gets one row
% per distinct value of ID(:,4), in ascending order. That row holds 1/count
% in the columns of the run's samples with that value and 0 elsewhere, so
% P*X averages them. An empty ID yields [].
n = size(ID,1);
if n == 0
    P = [];
    return
end

% A row starts a new run when ID(:,1:2) differs from the previous row
starts = [true; any(ID(2:end,1:2) ~= ID(1:end-1,1:2), 2)];
run = cumsum(starts);

blocks = cell(run(end),1);
for s = 1:run(end)
    sidx = (run == s);
    conds = unique(ID(sidx,4));
    B = zeros(numel(conds), n);
    for c = 1:numel(conds)
        % Restrict to the CURRENT run's samples (sidx), not merely to
        % samples whose run label happens to equal the loop index
        p = sidx & ID(:,4) == conds(c);
        B(c,:) = double(p') ./ sum(p);
    end
    blocks{s} = B;
end
P = vertcat(blocks{:});
end
0320