function img_name = prt_compute_weights_regre(PRT,in,model_idx,flag, ibe, flag2)
% Compute and write the weight image(s) of a regression model.
%
% FORMAT img_name = prt_compute_weights_regre(PRT,in,model_idx,flag,ibe,flag2)
%
% Inputs:
%   PRT        - PRT structure; the fields .model, .fs and .fas are read.
%   in         - structure with the fields used here:
%       .img_name - base name (no extension) of the output image; when
%                   empty, 'weights_<model_name>.img' is used instead.
%                   Must pass prt_checkAlphaNumUnder when non-empty.
%       .pathdir  - directory in which the image(s) are written
%       .fas_idx  - index/indices into PRT.fas
%       .mm       - index/indices into PRT.fs(fs_idx).modality
%   model_idx  - index of the model in PRT.model
%   flag       - when true, one additional image is built for each entry
%                of PRT.model(model_idx).output.permutation (if any)
%   ibe        - (optional, default []) index into the per-fold 'beta'
%                vector; only used for 'prt_machine_sMKL_reg'. Empty means
%                the full beta vector is passed on.
%   flag2      - (optional, default 0) when true, m.args.flag = 1 is passed
%                to prt_weights and a single output file name is used.
%
% Output:
%   img_name   - cell array with the full path of the weight image(s) of
%                the un-permuted model.
%
% For the un-permuted model (p = 0) the image is 4D: one volume per fold
% followed by one volume holding the average over folds. For permutations
% (p > 0) only the average volume is written, in a 'perm_<name>' folder
% next to the main image. Each volume is divided by its L2 norm.
%
% Raises:
%   prt_compute_weights:MachineNotSupported  - machine 'prt_machine_RT_bin'
%   prt_compute_weights:NameNotAlphaNumeric  - invalid in.img_name
%   prt_compute_weights:FeatureSetNotFound   - the model's feature set name
%                                              is not present in PRT.fs

mfunc  = PRT.model(model_idx).input.machine.function;
mname  = PRT.model(model_idx).model_name;
m.args = [];

% Optional arguments
if nargin<5
    ibe = [];
end
if nargin<6
    flag2 = 0;
end

% Select the weight function matching the machine
switch mfunc
    case 'prt_machine_sMKL_reg'
        m.function  = 'prt_weights_sMKL_reg';
        img_mach{1} = ['weights_',mname,'.img'];
    case 'prt_machine_RT_bin'
        error('prt_compute_weights:MachineNotSupported',...
            'Error: weights computation not supported for this machine!');
    otherwise
        m.function  = 'prt_weights_bin_linkernel';
        img_mach{1} = ['weights_',mname,'.img'];
end

nimage = length(img_mach);

% Build the output file name(s)
if ~isempty(in.img_name)
    if ~(prt_checkAlphaNumUnder(in.img_name))
        error('prt_compute_weights:NameNotAlphaNumeric',...
            'Error: image name should contain only alpha-numeric elements!');
    end
    if nimage>1 && ~flag2
        for c = 1:nimage
            in.img_name_c = [in.img_name,'_',num2str(c),'.img'];
            img_name{c}   = fullfile(in.pathdir,in.img_name_c);
        end
    else
        img_name{1} = fullfile(in.pathdir,[in.img_name,'.img']);
    end
else
    for c = 1:nimage
        img_name{c} = fullfile(in.pathdir,img_mach{c});
    end
end

% Model inputs
fs_name  = PRT.model(model_idx).input.fs(1).fs_name;
samp_idx = PRT.model(model_idx).input.samp_idx;
nfold    = length(PRT.model(model_idx).output.fold);

% Locate the feature set used by the model (the last match is kept)
fs_idx = [];
nfs    = length(PRT.fs);
for f = 1:nfs
    if strcmp(PRT.fs(f).fs_name,fs_name)
        fs_idx = f;
    end
end
if isempty(fs_idx)
    % Previously this surfaced as an "undefined variable" error further down
    error('prt_compute_weights:FeatureSetNotFound',...
        'Error: feature set ''%s'' not found in PRT.fs!',fs_name);
end
ID     = PRT.fs(fs_idx).id_mat(PRT.model(model_idx).input.samp_idx,:);
ID_all = PRT.fs(fs_idx).id_mat;

fas_idx = in.fas_idx;
mm      = in.mm;

% Indices of the features used for training, in image space
idROI  = [];
idfeat = PRT.fas(fas_idx(1)).idfeat_img;
if isempty(PRT.fs(fs_idx).modality(mm(1)).idfeat_fas)
    idfeat_fas = 1:length(idfeat);
else
    idfeat_fas = PRT.fs(fs_idx).modality(mm(1)).idfeat_fas;
end
if PRT.fs(fs_idx).multkernelROI
    % One set of features per entry of modality.idfeat_img
    nroi    = length(PRT.fs(fs_idx).modality(mm(1)).idfeat_img);
    m_train = cell(nroi,1);
    for i = 1:nroi
        tmp1       = PRT.fs(fs_idx).modality(mm(1)).idfeat_img{i};
        idROI      = [idROI;tmp1];
        tmp        = idfeat_fas(tmp1);
        m_train{i} = idfeat(tmp);
    end
    id2 = idfeat_fas(sort(idROI));
else
    id2 = idfeat_fas;
end
mask_train = idfeat(id2);
voxtr      = find(ismember(idfeat,mask_train));

% Number of permutations to process. Defaults to 0 so that the loop below
% still builds the main image when 'flag' is set but nothing was saved.
maxp = 0;
if flag
    if isfield(PRT.model(model_idx).output,'permutation') && ...
            ~isempty(PRT.model(model_idx).output.permutation)
        maxp = length(PRT.model(model_idx).output.permutation);
    else
        disp('No parameters saved for the permutation, building weight image only')
    end
end

pthperm = cell(nimage,1);
for p = 0:maxp
    % p = 0 is the un-permuted model, p > 0 the p-th permutation
    if p>0
        for c = 1:nimage
            [pth,nam] = fileparts(img_name{c});
            if p==1
                % Create the folder holding the permutation images once
                pthperm{c} = fullfile(pth,['perm_',nam]);
                if ~exist(pthperm{c},'dir')
                    mkdir(pth,['perm_',nam]);
                end
            end
            img_nam{c} = fullfile(pthperm{c},[nam,'_perm',num2str(p),'.img']);
        end
        fprintf('Permutation: %d of %d \n',p, ...
            length(PRT.model(model_idx).output.permutation));
    else
        img_nam = img_name;
    end

    % Remove image/header pairs left over from a previous run
    if exist(img_nam{1},'file')
        for c = 1:nimage
            delete(img_nam{c});
            [pth,nam] = fileparts(img_nam{c});
            hdr_name  = [pth,filesep,nam,'.hdr'];
            if exist(hdr_name,'file')
                delete(hdr_name)
            end
        end
    end

    hdr     = PRT.fas(fas_idx(1)).hdr.private;
    dat_dim = hdr.dat.dim;

    % Treat a 2D image as a single-slice volume
    if length(dat_dim)==2, dat_dim = [dat_dim 1]; end

    % The main image holds nfold volumes plus their average; a permutation
    % image only holds the average.
    if p==0
        folds_comp = nfold+1;
    else
        folds_comp = 1;
    end
    img4d = cell(nimage,1);
    for c = 1:nimage
        img4d{c} = file_array(img_nam{c},[dat_dim(1),dat_dim(2),...
            dat_dim(3),folds_comp],'float32-le',0,1,0);
    end

    zdim  = dat_dim(3);
    xydim = dat_dim(1)*dat_dim(2);

    % Per-slice accumulators of squared norms. Slices without features
    % contribute zero, so the sums below are defined even if every slice
    % is empty.
    norm3d   = cell(1,nimage);
    norm4d   = cell(1,nimage);
    norm4dav = cell(1,nimage);
    for c = 1:nimage
        norm3d{c}   = zeros(1,nfold);
        norm4d{c}   = zeros(zdim,nfold);
        norm4dav{c} = zeros(zdim,1);
    end

    disp('Computing weights.......>>')

    for z = 1:zdim

        fprintf('Slice: %d of %d \n',z,zdim);

        img3dav = cell(1,nimage);
        for c = 1:nimage
            img3dav{c} = zeros(1,xydim);
        end

        % Features of the training mask falling inside slice z
        in_slice = mask_train>=(xydim*(z-1)+1) & mask_train<=(xydim*z);
        if ~isempty(idROI)
            % Position, within the slice features, of each ROI's features
            feat_val = mask_train(in_slice);
            for ir = 1:length(m_train)
                tmp = m_train{ir}(m_train{ir}>=(xydim*(z-1)+1) & ...
                    m_train{ir}<=(xydim*z));
                m.args.idfeat_img{ir} = find(ismember(feat_val,tmp));
            end
            feat_slc = find(in_slice);
        else
            feat_slc = find(in_slice);
            m.args.idfeat_img = {1:length(feat_slc)};
        end

        if isempty(feat_slc)

            % Nothing to compute in this slice: fill it with NaN
            for c = 1:nimage
                img4d{c}(:,:,z,:) = NaN(dat_dim(1),dat_dim(2),1,folds_comp);
            end

        else

            for f = 1:nfold

                % Training samples of fold f
                train_idx = PRT.model(model_idx).input.cv_mat(:,f)==1;
                train     = samp_idx(train_idx);
                train_all = zeros(size(ID_all,1),1); train_all(train) = 1;
                if p>0
                    d.coeffs = PRT.model(model_idx).output.permutation(p).fold(f).alpha;
                else
                    d.coeffs = PRT.model(model_idx).output.fold(f).alpha;
                end

                % Gather the training data of this slice from the file arrays
                d.datamat = zeros(length(train), length(feat_slc));
                for i = 1:length(fas_idx)

                    indm = find(PRT.fs(fs_idx).fas.im == fas_idx(i));
                    if PRT.fs(fs_idx).multkernel
                        indtr = ID(train_idx,3) == fas_idx(1);
                        indm  = indm(find(train_all));
                    else
                        indtr = ID(train_idx,3) == fas_idx(i);
                        indm  = indm(find(train_all(ID_all(:,3)==fas_idx(i))));
                    end
                    ifa = PRT.fs(fs_idx).fas.ifa(indm);

                    d.datamat(indtr,:) = PRT.fas(fas_idx(i)).dat(ifa,voxtr(feat_slc));
                end

                % Apply the model's (non-zero) operations to the data
                ops = PRT.model(model_idx).input.operations(PRT.model(model_idx).input.operations ~=0 );
                cvdata.train      = {d.datamat};
                cvdata.tr_id      = ID(train_idx,:);
                cvdata.use_kernel = false;
                for o = 1:length(ops)
                    cvdata = prt_apply_operation(PRT, cvdata, ops(o));
                end
                d.datamat = cvdata.train{:};

                if strcmpi(mfunc,'prt_machine_sMKL_reg')
                    if isempty(ibe)
                        m.args.betas = PRT.model(model_idx).output.fold(f).beta;
                    else
                        m.args.betas = PRT.model(model_idx).output.fold(f).beta(ibe);
                    end
                end

                if flag2
                    m.args.flag = 1;
                end

                wimg = prt_weights(d,m);

                for c = 1:nimage
                    img3d = zeros(1,xydim);
                    indi  = mask_train(feat_slc)-xydim*(z-1);  % in-slice positions
                    indm  = setdiff(1:xydim,indi);             % out-of-mask positions
                    img3d(indi)  = wimg{c};
                    % Squared norm is taken before out-of-mask voxels become NaN
                    norm3d{c}(f) = sum(img3d.^2);
                    img3d(indm)  = NaN;
                    img3dav{c}   = img3dav{c} + img3d;
                    if p==0
                        img4d{c}(:,:,z,f) = reshape(img3d,dat_dim(1),dat_dim(2),1,1);
                    end
                end

            end

            % Average over folds, stored in the last volume
            for c = 1:nimage
                norm4d{c}(z,:) = norm3d{c};
                img3dav{c}     = img3dav{c}/nfold;
                img4d{c}(:,:,z,folds_comp) = reshape(img3dav{c},dat_dim(1),dat_dim(2),1,1);
                norm4dav{c}(z,:) = sum(img3dav{c}(isfinite(img3dav{c})).^2);
            end
        end

    end

    % L2 norm of each volume, accumulated over slices
    for c = 1:nimage
        norm4d{c}   = sqrt(sum(norm4d{c},1));
        norm4dav{c} = sqrt(sum(norm4dav{c},1));
    end

    disp('Normalising weights--------->>')
    if p==0
        for f = 1:nfold
            for c = 1:nimage
                % A zero norm leaves the volume untouched
                if norm4d{c}(1,f)~=0
                    img4d{c}(:,:,:,f) = img4d{c}(:,:,:,f)./norm4d{c}(1,f);
                end
            end
        end
    end

    for c = 1:nimage
        % Same zero-norm guard as for the fold volumes, so that an all-zero
        % average is not turned into NaN by a 0/0 division
        if norm4dav{c}~=0
            img4d{c}(:,:,:,folds_comp) = img4d{c}(:,:,:,folds_comp)./norm4dav{c};
        end
    end

    % Write the header(s)
    clear No
    for c = 1:nimage
        fprintf('Creating image %d of %d--------->>\n',c,nimage);
        No         = hdr;
        No.dat     = img4d{c};
        No.descrip = 'Pronto weigths';  % spelling kept: stored in the image header
        create(No);
        disp('Done.')
    end
end
0341