function img_name = prt_compute_weights_class(PRT,in,model_idx,flag, ibe, flag2)
% PRT_COMPUTE_WEIGHTS_CLASS Build voxel-wise weight images for a classifier.
%
% FORMAT img_name = prt_compute_weights_class(PRT,in,model_idx,flag,ibe,flag2)
%
% For every cross-validation fold, the training samples are re-read slice
% by slice from PRT.fas, passed through the model's data operations
% (prt_apply_operation) and combined with the fold's stored coefficients
% (output.fold(f).alpha) by prt_weights. The per-fold weight images and
% their average over folds are written to a 4-D image (nfold+1 volumes,
% the last one being the average). Each volume is divided by its Euclidean
% norm computed over all slices (skipped when that norm is zero). Any
% existing image/header with the same name is deleted first.
%
% Inputs:
%   PRT       - PRoNTo structure; reads PRT.model(model_idx), PRT.fs, PRT.fas
%   in        - structure with fields
%                 .pathdir   directory in which the images are written
%                 .img_name  base file name (must pass
%                            prt_checkAlphaNumUnder); if empty, the names
%                            'weights_<model_name>[_<c>].img' are used
%                 .fas_idx   indices into PRT.fas providing the data
%                 .mm        modality indices into PRT.fs(...).modality
%   model_idx - index of the model in PRT.model
%   flag      - (optional) if true and output.permutation is non-empty,
%               one fold-averaged image per permutation is also written in
%               a 'perm_<name>' sub-directory. Default: 0
%   ibe       - (optional) indices selecting entries of the fold beta
%               vector; only used for 'prt_machine_sMKL_cla'. Default: []
%   flag2     - (optional) if true, m.args.flag = 1 is passed to
%               prt_weights and a single name is built from in.img_name.
%               Default: 0
%
% Output:
%   img_name  - cell array with the full paths of the (non-permutation)
%               weight images
%
% Errors:
%   prt_compute_weights:MachineNotSupported  for 'prt_machine_RT_bin'
%   prt_compute_weights:NameNotAlphaNumeric  for an invalid in.img_name
%   prt_compute_weights:FeatureSetNotFound   if the model's feature set is
%                                            missing from PRT.fs

if nargin < 4 || isempty(flag)
    flag = 0;
end
if nargin < 5
    ibe = [];
end
if nargin < 6
    flag2 = 0;
end

% -------------------------------------------------------------------------
% Weight function and default image names for the model's machine
% -------------------------------------------------------------------------
mfunc  = PRT.model(model_idx).input.machine.function;
mname  = PRT.model(model_idx).model_name;
nclass = length(PRT.model(model_idx).input.class);
% Any model with more than two classes goes through the multiclass routine,
% whatever machine function is recorded in the model.
if nclass > 2, mfunc = 'multiclass_machine'; end
m.args = [];

switch mfunc
    case 'multiclass_machine'
        % one weight image per class
        m.function = 'prt_weights_gpclap';
        img_mach = cell(1,nclass);
        for c = 1:nclass
            img_mach{c} = ['weights_',mname,'_',num2str(c),'.img'];
        end
    case 'prt_machine_sMKL_cla'
        m.function = 'prt_weights_sMKL_cla';
        img_mach = {['weights_',mname,'.img']};
    case 'prt_machine_RT_bin'
        error('prt_compute_weights:MachineNotSupported',...
            'Error: weights computation not supported for this machine!');
    otherwise
        m.function = 'prt_weights_bin_linkernel';
        img_mach = {['weights_',mname,'.img']};
end
nimage = length(img_mach);

% -------------------------------------------------------------------------
% Output file names
% -------------------------------------------------------------------------
if ~isempty(in.img_name)
    if ~(prt_checkAlphaNumUnder(in.img_name))
        error('prt_compute_weights:NameNotAlphaNumeric',...
            'Error: image name should contain only alpha-numeric elements!');
    end
    if nimage > 1 && ~flag2
        img_name = cell(1,nimage);
        for c = 1:nimage
            img_name{c} = fullfile(in.pathdir,[in.img_name,'_',num2str(c),'.img']);
        end
    else
        % NOTE(review): with flag2 set and nimage>1 only one name is built
        % here while nimage images are still allocated below; confirm that
        % flag2 is never combined with multiclass models.
        img_name = {fullfile(in.pathdir,[in.img_name,'.img'])};
    end
else
    img_name = cell(1,nimage);
    for c = 1:nimage
        img_name{c} = fullfile(in.pathdir,img_mach{c});
    end
end

% -------------------------------------------------------------------------
% Feature set, samples and folds used by the model
% -------------------------------------------------------------------------
fs_name  = PRT.model(model_idx).input.fs(1).fs_name;
samp_idx = PRT.model(model_idx).input.samp_idx;
nfold    = length(PRT.model(model_idx).output.fold);

% last matching feature set wins (same as a forward scan keeping the last hit)
fs_idx = find(strcmp({PRT.fs.fs_name},fs_name),1,'last');
if isempty(fs_idx)
    error('prt_compute_weights:FeatureSetNotFound',...
        'Error: feature set "%s" not found in PRT.fs!',fs_name);
end
ID     = PRT.fs(fs_idx).id_mat(samp_idx,:);
ID_all = PRT.fs(fs_idx).id_mat;

fas_idx = in.fas_idx;
mm      = in.mm;

% -------------------------------------------------------------------------
% Voxel mask of the features used for training (optionally per ROI)
% -------------------------------------------------------------------------
idROI  = [];
idfeat = PRT.fas(fas_idx(1)).idfeat_img;
modal  = PRT.fs(fs_idx).modality(mm(1));
if isempty(modal.idfeat_fas)
    idfeat_fas = 1:length(idfeat);
else
    idfeat_fas = modal.idfeat_fas;
end
if PRT.fs(fs_idx).multkernelROI
    nroi    = length(modal.idfeat_img);
    m_train = cell(nroi,1);
    for i = 1:nroi
        roi_feat   = modal.idfeat_img{i};
        idROI      = [idROI; roi_feat]; %#ok<AGROW>
        m_train{i} = idfeat(idfeat_fas(roi_feat));
    end
    id2 = idfeat_fas(sort(idROI));
else
    id2 = idfeat_fas;
end
mask_train = idfeat(id2);
voxtr      = find(ismember(idfeat,mask_train));

% -------------------------------------------------------------------------
% Number of permutations to process (0 = true labels only)
% -------------------------------------------------------------------------
maxp = 0;
if flag
    if isfield(PRT.model(model_idx).output,'permutation') && ...
            ~isempty(PRT.model(model_idx).output.permutation)
        maxp = length(PRT.model(model_idx).output.permutation);
    else
        disp('No parameters saved for the permutation, building weight image only')
    end
end

% Quantities that do not change across permutations, slices or folds
ops     = PRT.model(model_idx).input.operations(PRT.model(model_idx).input.operations ~= 0);
hdr     = PRT.fas(fas_idx(1)).hdr.private;
dat_dim = hdr.dat.dim;
if length(dat_dim) == 2, dat_dim = [dat_dim 1]; end
zdim  = dat_dim(3);
xydim = dat_dim(1)*dat_dim(2);

pthperm = cell(nimage,1);
for p = 0:maxp

    % --- file names for this pass ---------------------------------------
    if p > 0
        for c = 1:nimage
            [pth,nam] = fileparts(img_name{c});
            if p == 1
                pthperm{c} = fullfile(pth,['perm_',nam]);
                if ~exist(pthperm{c},'dir')
                    mkdir(pth,['perm_',nam]);
                end
            end
            img_nam{c} = fullfile(pthperm{c},[nam,'_perm',num2str(p),'.img']);
        end
        fprintf('Permutation: %d of %d \n',p,maxp);
    else
        img_nam = img_name;
    end

    % remove stale image/header pairs, file by file
    for c = 1:nimage
        if exist(img_nam{c},'file')
            delete(img_nam{c});
        end
        [pth,nam] = fileparts(img_nam{c});
        hdr_name = fullfile(pth,[nam,'.hdr']);
        if exist(hdr_name,'file')
            delete(hdr_name);
        end
    end

    % true labels: one volume per fold plus the average; permutations: average only
    if p == 0
        folds_comp = nfold + 1;
    else
        folds_comp = 1;
    end
    img4d = cell(nimage,1);
    for c = 1:nimage
        img4d{c} = file_array(img_nam{c},[dat_dim(1),dat_dim(2),...
            dat_dim(3),folds_comp],'float32-le',0,1,0);
    end

    % squared norms accumulated per slice; reset for every pass so no value
    % from a previous permutation can leak into this one
    norm4d   = repmat({zeros(zdim,nfold)},nimage,1);
    norm4dav = repmat({zeros(zdim,1)},nimage,1);

    disp('Computing weights.......>>')

    for z = 1:zdim

        fprintf('Slice: %d of %d \n',z,zdim);

        % features of mask_train whose voxel index falls in slice z
        lo       = xydim*(z-1) + 1;
        hi       = xydim*z;
        in_slice = mask_train >= lo & mask_train <= hi;
        feat_slc = find(in_slice);
        if ~isempty(idROI)
            slice_vox = mask_train(in_slice);
            for ir = 1:length(m_train)
                roi_vox = m_train{ir}(m_train{ir} >= lo & m_train{ir} <= hi);
                m.args.idfeat_img{ir} = find(ismember(slice_vox,roi_vox));
            end
        else
            m.args.idfeat_img = {1:length(feat_slc)};
        end

        if isempty(feat_slc)
            for c = 1:nimage
                img4d{c}(:,:,z,:) = NaN(dat_dim(1),dat_dim(2),1,folds_comp);
            end
            continue
        end

        % in-plane positions of the slice features and of everything else
        indi    = mask_train(feat_slc) - xydim*(z-1);
        outside = setdiff(1:xydim,indi);

        img3dav = repmat({zeros(1,xydim)},1,nimage);
        norm3d  = repmat({zeros(1,nfold)},1,nimage);

        for f = 1:nfold

            train_idx = PRT.model(model_idx).input.cv_mat(:,f) == 1;
            train     = samp_idx(train_idx);
            train_all = zeros(size(ID_all,1),1);
            train_all(train) = 1;
            if p > 0
                d.coeffs = PRT.model(model_idx).output.permutation(p).fold(f).alpha;
            else
                d.coeffs = PRT.model(model_idx).output.fold(f).alpha;
            end

            % gather the training data of this slice from every data source
            d.datamat = zeros(length(train),length(feat_slc));
            for i = 1:length(fas_idx)
                indm = find(PRT.fs(fs_idx).fas.im == fas_idx(i));
                if PRT.fs(fs_idx).multkernel
                    indtr = ID(train_idx,3) == fas_idx(1);
                    indm  = indm(find(train_all)); %#ok<FNDSB>
                else
                    indtr = ID(train_idx,3) == fas_idx(i);
                    indm  = indm(find(train_all(ID_all(:,3)==fas_idx(i)))); %#ok<FNDSB>
                end
                ifa = PRT.fs(fs_idx).fas.ifa(indm);
                d.datamat(indtr,:) = PRT.fas(fas_idx(i)).dat(ifa,voxtr(feat_slc));
            end

            % apply the model's data operations to the training data
            cvdata.train      = {d.datamat};
            cvdata.tr_id      = ID(train_idx,:);
            cvdata.use_kernel = false;
            for o = 1:length(ops)
                cvdata = prt_apply_operation(PRT, cvdata, ops(o));
            end
            d.datamat = cvdata.train{:};

            if strcmpi(mfunc,'prt_machine_sMKL_cla')
                if isempty(ibe)
                    m.args.betas = PRT.model(model_idx).output.fold(f).beta;
                else
                    m.args.betas = PRT.model(model_idx).output.fold(f).beta(ibe);
                end
            end

            if flag2
                m.args.flag = 1;
            end

            wimg = prt_weights(d,m);

            for c = 1:nimage
                img3d = zeros(1,xydim);
                img3d(indi) = wimg{c};
                norm3d{c}(f) = sum(img3d.^2);   % before masking with NaN
                img3d(outside) = NaN;
                img3dav{c} = img3dav{c} + img3d;
                if p == 0
                    img4d{c}(:,:,z,f) = reshape(img3d,dat_dim(1),dat_dim(2),1,1);
                end
            end
        end

        % fold-averaged weights for this slice
        for c = 1:nimage
            norm4d{c}(z,:) = norm3d{c};
            img3dav{c} = img3dav{c}/nfold;
            img4d{c}(:,:,z,folds_comp) = reshape(img3dav{c},dat_dim(1),dat_dim(2),1,1);
            norm4dav{c}(z) = sum(img3dav{c}(isfinite(img3dav{c})).^2);
        end
    end

    % -------------------------------------------------------------------
    % Normalise each volume by its Euclidean norm (skip zero norms)
    % -------------------------------------------------------------------
    for c = 1:nimage
        norm4d{c}   = sqrt(sum(norm4d{c},1));
        norm4dav{c} = sqrt(sum(norm4dav{c},1));
    end

    disp('Normalising weights--------->>')
    if p == 0
        for f = 1:nfold
            for c = 1:nimage
                if norm4d{c}(f) ~= 0
                    img4d{c}(:,:,:,f) = img4d{c}(:,:,:,f)./norm4d{c}(f);
                end
            end
        end
    end

    for c = 1:nimage
        if norm4dav{c} ~= 0
            img4d{c}(:,:,:,folds_comp) = img4d{c}(:,:,:,folds_comp)./norm4dav{c};
        end
    end

    % -------------------------------------------------------------------
    % Write the images, reusing the header of the first data source
    % -------------------------------------------------------------------
    for c = 1:nimage
        fprintf('Creating image %d of %d--------->>\n',c,nimage);
        No = hdr;
        No.dat = img4d{c};
        No.descrip = 'Pronto weigths';
        create(No);
        disp('Done.')
    end
end
0350