function out = prt_apply_operation(PRT, in, opid)
% PRT_APPLY_OPERATION  Apply one data operation to every modality of a data structure.
%
% FORMAT out = prt_apply_operation(PRT, in, opid)
%
% Inputs:
%   PRT  - not referenced by this function; kept for interface compatibility.
%   in   - data structure. Fields read:
%            .train      - cell array, one entry per modality (looped over)
%            .use_kernel - true if train/test hold kernel matrices
%                          (train x train in .train, test x train in .test,
%                          test x test in .testcov); false if they hold one
%                          sample per row
%            .tr_id      - training sample id matrix (used by opid 1/2)
%          and, when present: .test, .testcov, .te_id, .tr_targets,
%          .te_targets, .pred_type.
%   opid - operation to apply:
%            1 - average samples over runs of consecutive rows whose id
%                columns 1:5 match (matrix from compute_tc_mat)
%            2 - within each run of rows whose id columns 1:2 match, average
%                the samples sharing the same id column 4 value
%                (matrix from compute_sa_mat)
%            3 - mean-centre using the training mean (prt_centre_kernel for
%                kernels)
%            4 - scale each sample to unit Euclidean norm
%                (prt_normalise_kernel for kernels)
%            5 - not implemented (raises an error)
%
% Output:
%   out  - copy of 'in' whose train/test (and testcov for kernels) entries
%          hold the transformed data. For opid 1/2 the ids are averaged and
%          rounded, and targets are averaged (rounded as well when pred_type
%          is 'classification').
%
% Errors:
%   'prt_apply_operation:UnknownOperationSpecified' for an unknown opid.
%   The dispatch happens inside the per-modality loop, so no error is raised
%   when in.train is empty.

% Fields not touched by an operation (ids, targets, ...) are carried over here.
out = in;

for d = 1:length(in.train)
    switch opid
        case 1
            out = average_samples(in, out, d, @compute_tc_mat);
        case 2
            out = average_samples(in, out, d, @compute_sa_mat);
        case 3
            out = centre_modality(in, out, d);
        case 4
            out = normalise_modality(in, out, d);
        case 5
            error('GLM not implemented yet');
        otherwise
            error('prt_apply_operation:UnknownOperationSpecified',...
                'Unknown operation requested');
    end
end
end

function out = average_samples(in, out, d, make_mat)
% Average the samples of modality d with the matrix make_mat(id) builds.
% Ids and targets are averaged with the same matrix. They are recomputed
% from 'in' for every modality, so repeated calls give the same result.
Ptr = make_mat(in.tr_id);
if in.use_kernel
    % Kernel: average both the rows and the columns of the train x train matrix.
    out.train{d} = Ptr*in.train{d}*Ptr';
else
    out.train{d} = Ptr*in.train{d};
end
out.tr_id = round(Ptr*in.tr_id);
if isfield(in,'tr_targets')
    out.tr_targets = average_targets(Ptr, in.tr_targets, in.pred_type);
end

if isfield(in,'test')
    Pte = make_mat(in.te_id);
    if in.use_kernel
        % test x train block: rows follow test samples, columns follow training samples.
        out.test{d}    = Pte*in.test{d}*Ptr';
        out.testcov{d} = Pte*in.testcov{d}*Pte';
    else
        out.test{d} = Pte*in.test{d};
    end
    out.te_id = round(Pte*in.te_id);
    if isfield(in,'te_targets')
        out.te_targets = average_targets(Pte, in.te_targets, in.pred_type);
    end
end
end

function t = average_targets(P, targets, pred_type)
% Average targets with P; class labels are rounded back to integers.
t = P*targets;
if strcmpi(pred_type,'classification')
    t = round(t);
end
end

function out = centre_modality(in, out, d)
% Mean-centre modality d using statistics from the training data only.
has_test = isfield(in,'test');
if in.use_kernel
    if has_test
        [out.train{d}, out.test{d}, out.testcov{d}] = ...
            prt_centre_kernel(in.train{d},in.test{d},in.testcov{d});
    else
        out.train{d} = prt_centre_kernel(in.train{d});
    end
else
    % Column mean along dim 1: with a single training row, mean(X) without a
    % dimension would average across features and return a scalar instead.
    m = mean(in.train{d}, 1);
    out.train{d} = bsxfun(@minus, in.train{d}, m);
    if has_test
        out.test{d} = bsxfun(@minus, in.test{d}, m);
    end
end
end

function out = normalise_modality(in, out, d)
% Scale every sample of modality d to unit Euclidean norm.
has_test = isfield(in,'test');
if in.use_kernel
    ntr = size(in.train{d},1);
    tr  = 1:ntr;
    if has_test
        % Normalise the joint [train test] kernel so train/test entries share scaling.
        nte = size(in.test{d},1);
        Phi = [in.train{d}, in.test{d}'; in.test{d}, in.testcov{d}];
        Phi = prt_normalise_kernel(Phi);
        te  = ntr + (1:nte);   % test rows/cols follow the training block
        out.train{d}   = Phi(tr,tr);
        out.test{d}    = Phi(te,tr);
        out.testcov{d} = Phi(te,te);
    else
        Phi = prt_normalise_kernel(in.train{d});
        out.train{d} = Phi(tr,tr);
    end
else
    out.train{d} = unit_rows(in.train{d});
    if has_test
        out.test{d} = unit_rows(in.test{d});
    end
end
end

function Y = unit_rows(X)
% Divide each row of X by its Euclidean norm (an all-zero row yields NaNs).
Y = zeros(size(X));
for r = 1:size(X,1)
    Y(r,:) = X(r,:) / norm(X(r,:));
end
end
0223
0224
0225
0226
0227
function P = compute_tc_mat(ID)
% COMPUTE_TC_MAT  Averaging matrix over runs of consecutive rows with equal ID(:,1:5).
%
% A row starts a new run when any of its first five id columns differs from
% the previous row (the first row is compared against zeros(1,5)). Row g of
% the returned matrix holds 1/n_g in the n_g columns of run g and 0
% elsewhere, so P*X averages the rows of X within each run.
nrows   = size(ID,1);
runIdx  = zeros(nrows,1);
prevKey = zeros(1,5);
runNo   = 0;
for row = 1:nrows
    key = ID(row,1:5);
    % any(...) is logical; adding it bumps the counter only on a change.
    runNo = runNo + any(prevKey ~= key);
    runIdx(row) = runNo;
    prevKey = key;
end

% Run labels are non-decreasing, so the blocks are contiguous and in row order.
runLengths = histc(runIdx, unique(runIdx));
blocks = arrayfun(@(n) ones(1,n) / n, runLengths, 'UniformOutput', false);
P = blkdiag(blocks{:});
end
0255
function P = compute_sa_mat(ID)
% COMPUTE_SA_MAT  Averaging matrix over (run of ID(:,1:2), value of ID(:,4)) groups.
%
% Rows are first split into runs of consecutive rows whose first two id
% columns match. Within each run, rows sharing the same ID(:,4) value form
% one group. Output rows are ordered by run, then by ascending ID(:,4)
% value. Each output row holds 1/n in the n columns of its group and 0
% elsewhere, so P*X averages the rows of X within each group.

% Label runs 1,2,3,... The first row always opens run 1, regardless of its
% id values; previously an all-zero ID(1,1:2) produced label 0, which broke
% the run/label correspondence below.
nrows = size(ID,1);
run   = zeros(nrows,1);
nrun  = 0;
for s = 1:nrows
    if s == 1 || any(ID(s,1:2) ~= ID(s-1,1:2))
        nrun = nrun + 1;
    end
    run(s) = nrun;
end

P = [];
for s = 1:nrun
    in_run = (run == s);
    conds  = unique(ID(in_run,4));
    for c = 1:length(conds)
        % Restrict to the current run; the mask is built from the same labels
        % as in_run so the two can never disagree.
        p = (in_run & ID(:,4) == conds(c))';
        P = [P; double(p) / sum(p)]; %#ok<AGROW>
    end
end
end
0286