function prt_compute_weights(PRT,in)
% PRT_COMPUTE_WEIGHTS  Compute per-fold and average weight images of a model.
%
%   prt_compute_weights(PRT, in)
%
%   Inputs:
%     PRT - PRT structure (as stored in PRT.mat). It must contain the model
%           named in.model_name, the feature set it was trained on
%           (PRT.fs) and the corresponding feature arrays (PRT.fas).
%     in  - structure with fields
%             .model_name - name of the model in PRT.model
%             .img_name   - base name (no extension) of the output image;
%                           must pass prt_checkAlphaNumUnder. If empty,
%                           'weights_<model_name>.img' is used instead.
%             .pathdir    - directory in which the image is written
%
%   Output (written to disk, nothing is returned):
%     A 4-D '.img' file (float64, little endian) with nfold+1 volumes.
%     Volume f (f = 1..nfold) is the weight map of fold f divided by its
%     L2 norm over all voxels (left as-is if that norm is zero).
%     Volume nfold+1 is the mean of the un-normalised fold maps.
%
%   Weights are computed slice by slice: for each fold the training data
%   of the slice is assembled from the feature arrays, the model's data
%   operations are re-applied with prt_apply_operation, and prt_weights is
%   called with 'prt_weights_bin_linkernel'.
%
%   Errors (identifiers):
%     prt_compute_weights:ModelNotFound        - no model named in.model_name
%     prt_compute_weights:MachineNotSupported  - machine cannot give weights
%     prt_compute_weights:NameNotAlphaNumeric  - invalid in.img_name
%     prt_compute_weights:FeatureSetNotFound   - model's feature set missing
%     prt_compute_weights:NoFeatureArray       - no feature array matches the
%                                                feature set's modalities

% -------------------------------------------------------------------------
% Locate the model (the last entry with a matching name wins)
% -------------------------------------------------------------------------
model_idx = 0;
for i = 1:length(PRT.model)
    if strcmp(PRT.model(i).model_name,in.model_name)
        model_idx = i;
    end
end
if model_idx == 0
    error('prt_compute_weights:ModelNotFound',...
        'Error: model not found in PRT.mat!');
end
model = PRT.model(model_idx);

% -------------------------------------------------------------------------
% Reject machines for which weights cannot be computed
% -------------------------------------------------------------------------
mfunc = model.input.machine.function;
mname = model.model_name;
switch mfunc
    case 'prt_machine_RT_bin'
        error('prt_compute_weights:MachineNotSupported',...
            'Error: weights computation not supported for this machine!');
    case 'prt_machine_gpclap'
        error('prt_compute_weights:MachineNotSupported',...
            'Error: weights computation not supported yet for this machine!');
end

% Weights are obtained with the binary linear-kernel weight routine
m.args     = [];
m.function = 'prt_weights_bin_linkernel';

% -------------------------------------------------------------------------
% Output file name
% -------------------------------------------------------------------------
if ~isempty(in.img_name)
    if ~(prt_checkAlphaNumUnder(in.img_name))
        error('prt_compute_weights:NameNotAlphaNumeric',...
            'Error: image name should contain only alpha-numeric elements!');
    end
    img_name = fullfile(in.pathdir,[in.img_name,'.img']);
else
    img_name = fullfile(in.pathdir,['weights_',mname,'.img']);
end

% -------------------------------------------------------------------------
% Feature set used by the model
% -------------------------------------------------------------------------
fs_name  = model.input.fs(1).fs_name;
samp_idx = model.input.samp_idx;
nfold    = length(model.output.fold);

fs_idx = 0;
for f = 1:length(PRT.fs)
    if strcmp(PRT.fs(f).fs_name,fs_name)
        fs_idx = f;
    end
end
if fs_idx == 0
    error('prt_compute_weights:FeatureSetNotFound',...
        'Error: feature set of the model not found in PRT.mat!');
end
ID        = PRT.fs(fs_idx).id_mat(samp_idx,:);   % id rows of the model's samples
nsamp_all = size(PRT.fs(fs_idx).id_mat,1);

% -------------------------------------------------------------------------
% Feature arrays whose modality belongs to the feature set
% -------------------------------------------------------------------------
mods = {PRT.fs(fs_idx).modality.mod_name};
nfas = length(PRT.fas);
fas  = false(1,nfas);
for i = 1:nfas
    fas(i) = any(strcmpi(PRT.fas(i).mod_name,mods));
end
fas_idx = find(fas);
if isempty(fas_idx)
    error('prt_compute_weights:NoFeatureArray',...
        'Error: no feature array found for the modalities of the model!');
end
nfa = length(fas_idx);

% The mask and header below come from the first feature array, so the
% modality used for the feature mask must be the one of that same array
% (rather than whichever modality happens to match last).
mm = find(strcmpi(PRT.fas(fas_idx(1)).mod_name,mods),1,'last');

% -------------------------------------------------------------------------
% Training mask: mask_train holds voxel indices, voxtr the matching
% columns of the feature arrays' data.
% -------------------------------------------------------------------------
idfeat     = PRT.fas(fas_idx(1)).idfeat_img;
idfeat_fas = PRT.fs(fs_idx).modality(mm).idfeat_fas;
if ~isempty(idfeat_fas)
    mask_train = idfeat(idfeat_fas);
    voxtr      = find(ismember(idfeat,mask_train));
else
    mask_train = idfeat;
    voxtr      = 1:length(idfeat);
end

% -------------------------------------------------------------------------
% Per-fold quantities that do not depend on the slice (computed once)
% -------------------------------------------------------------------------
ops     = model.input.operations(model.input.operations ~= 0);
tr_idx  = cell(1,nfold);    % logical index of training samples in samp_idx
ntr     = zeros(1,nfold);   % number of training samples per fold
ifa_f   = cell(nfold,nfa);  % rows of PRT.fas(fas_idx(i)).dat to read
indtr_f = cell(nfold,nfa);  % rows of the training matrix they fill
for f = 1:nfold
    tr_idx{f} = model.input.cv_mat(:,f) == 1;
    ntr(f)    = nnz(tr_idx{f});
    train_all = zeros(nsamp_all,1);
    train_all(samp_idx(tr_idx{f})) = 1;
    for i = 1:nfa
        indm          = PRT.fs(fs_idx).fas.im == fas_idx(i) & train_all;
        ifa_f{f,i}    = PRT.fs(fs_idx).fas.ifa(indm);
        indtr_f{f,i}  = ID(tr_idx{f},3) == fas_idx(i);
    end
end

% -------------------------------------------------------------------------
% Output image: nfold fold maps + 1 average map
% -------------------------------------------------------------------------
hdr   = PRT.fas(fas_idx(1)).hdr.private;
dim   = hdr.dat.dim;
img4d = file_array(img_name,[dim(1),dim(2),dim(3),nfold+1],...
    'float64-le',0,1,0);

zdim   = dim(3);
xydim  = dim(1)*dim(2);
norm4d = zeros(zdim,nfold);   % squared L2 norm per slice (rows) and fold

disp('Computing weights.......>>')

for z = 1:zdim

    fprintf('Slice: %d of %d\n',z,zdim);

    % Features (positions in mask_train) whose voxel lies in slice z
    vox_offset = xydim*(z-1);
    feat_slc   = find(mask_train >= (vox_offset+1) & mask_train <= (xydim*z));

    if isempty(feat_slc)
        img4d(:,:,z,:) = zeros(dim(1),dim(2),1,nfold+1);
        continue
    end

    dat_cols = voxtr(feat_slc);                     % feature-array columns
    pix_slc  = mask_train(feat_slc) - vox_offset;   % in-slice linear indices
    img3dav  = zeros(1,xydim);

    for f = 1:nfold

        % Training data of this slice for fold f
        d.coeffs  = model.output.fold(f).alpha;
        d.datamat = zeros(ntr(f),length(feat_slc));
        for i = 1:nfa
            d.datamat(indtr_f{f,i},:) = ...
                PRT.fas(fas_idx(i)).dat(ifa_f{f,i},dat_cols);
        end

        % Re-apply the model's data operations to the training data
        cvdata.train      = {d.datamat};
        cvdata.tr_id      = ID(tr_idx{f},:);
        cvdata.use_kernel = false;
        for o = 1:length(ops)
            cvdata = prt_apply_operation(PRT,cvdata,ops(o));
        end
        d.datamat = cvdata.train{:};

        wimg = prt_weights(d,m);

        img3d          = zeros(1,xydim);
        img3d(pix_slc) = wimg;
        norm4d(z,f)    = sum(img3d.^2);
        img3dav        = img3dav + img3d;

        img4d(:,:,z,f) = reshape(img3d,dim(1),dim(2),1,1);
    end

    % Mean of the un-normalised fold maps
    img4d(:,:,z,nfold+1) = reshape(img3dav,dim(1),dim(2),1,1)/nfold;
end

% -------------------------------------------------------------------------
% Normalise each fold map by its L2 norm over the whole volume
% -------------------------------------------------------------------------
fold_norm = sqrt(sum(norm4d,1));

disp('Normalising weights--------->>')
for f = 1:nfold
    % An all-zero map is left untouched instead of becoming NaN (0/0)
    if fold_norm(f) > 0
        img4d(:,:,:,f) = img4d(:,:,:,f)./fold_norm(f);
    end
end

% -------------------------------------------------------------------------
% Write the image
% -------------------------------------------------------------------------
disp('Creating image--------->>')
No         = hdr;
No.dat     = img4d;
No.descrip = 'Pronto weights';
create(No);
disp('Done.')