Home > . > prt_compute_weights_class.m

prt_compute_weights_class

PURPOSE ^

FORMAT img_name = prt_compute_weights_class(PRT,in,model_idx,flag,ibe,flag2)

SYNOPSIS ^

function img_name = prt_compute_weights_class(PRT,in,model_idx,flag, ibe, flag2)

DESCRIPTION ^

 FORMAT img_name = prt_compute_weights_class(PRT,in,model_idx,flag,ibe,flag2)

 This function calls prt_weights to compute the weights of a classification
 model and writes them to disk as .img/.hdr image files.
 Inputs:
       PRT             - data/design/model structure (it needs to contain
                         at least one estimated model).
         in            - structure with specific information to create
                         weights
           .model_name - model name (string)
           .img_name   - (optional) name of the file to be created
                         (string)
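           .pathdir    - directory path where to save weights (same as the
                         one for PRT.mat) (string)
           .fas_idx    - index(es) into PRT.fas of the data to read
           .mm         - modality index(es) into PRT.fs(...).modality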
         model_idx     - model index (integer)
         flag          - compute weight images for each permutation if 1
         ibe           - (optional) which beta to use for MKL and multiple
                         modalities; [] (default) uses all betas
         flag2         - (optional, default 0) build image of weights per
                         region
 Output:
       img_name        - cell array with the full path of each .img file
                         created
       + image file created on disk
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function img_name = prt_compute_weights_class(PRT,in,model_idx,flag, ibe, flag2)
% FORMAT img_name = prt_compute_weights_class(PRT,in,model_idx,flag,ibe,flag2)
%
% This function calls prt_weights to compute the weights of a trained
% classification model and writes them to disk as 4D .img/.hdr files: one
% volume per cross-validation fold plus a last volume with the average
% across folds (permutation images only contain the average).
% Inputs:
%       PRT             - data/design/model structure (it needs to contain
%                         at least one estimated model).
%         in            - structure with specific information to create
%                         weights
%           .model_name - model name (string)
%           .img_name   - (optional) name of the file to be created
%                         (string, alpha-numeric/underscore characters)
%           .pathdir    - directory path where to save weights (same as the
%                         one for PRT.mat) (string)
%           .fas_idx    - index(es) into PRT.fas of the data to read
%           .mm         - modality index(es) into PRT.fs(...).modality
%         model_idx     - model index (integer)
%         flag          - (optional, default 0) if 1, also compute a weight
%                         image for each saved permutation
%         ibe           - (optional, default []) which beta to use for MKL
%                         and multiple modalities; [] uses all betas
%         flag2         - (optional, default 0) build image of weights per
%                         region
% Output:
%       img_name        - cell array with the full path of each .img file
%                         created
%       + image file(s) created on disk; permutation images are written to
%         a 'perm_<name>' sub-directory next to the main image
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by M.J.Rosa
% $Id$

% Optional arguments
% -------------------------------------------------------------------------
if nargin<4 || isempty(flag)
    flag = 0;    % documented 3-argument call must not fail on 'if flag'
end
if nargin<5
    ibe = [];
end
if nargin<6
    flag2 = 0;
end

% Find machine
% -------------------------------------------------------------------------
mfunc  = PRT.model(model_idx).input.machine.function;
mname  = PRT.model(model_idx).model_name;
nclass = length(PRT.model(model_idx).input.class);
if nclass > 2, mfunc = 'multiclass_machine'; end
m.args = [];

% unfortunately a bug somewhere causes shifts in weight image if
% .nii is used...

switch mfunc
    case 'multiclass_machine'
        % one weight image per class
        m.function = 'prt_weights_gpclap';
        img_mach   = cell(1,nclass);
        for c = 1:nclass
            img_mach{c} = ['weights_',mname,'_',num2str(c),'.img'];
        end
    case 'prt_machine_sMKL_cla'
        m.function = 'prt_weights_sMKL_cla';
        img_mach   = {['weights_',mname,'.img']};
    case 'prt_machine_RT_bin'
        error('prt_compute_weights:MachineNotSupported',...
            'Error: weights computation not supported for this machine!');
    otherwise
        m.function = 'prt_weights_bin_linkernel';
        img_mach   = {['weights_',mname,'.img']};
end
nimage = numel(img_mach);

% Image name
% -------------------------------------------------------------------------
if isfield(in,'img_name') && ~isempty(in.img_name)
    if ~prt_checkAlphaNumUnder(in.img_name)
        error('prt_compute_weights:NameNotAlphaNumeric',...
            'Error: image name should contain only alpha-numeric elements!');
    end
    if nimage>1 && ~flag2
        img_name = cell(1,nimage);
        for c = 1:nimage
            img_name{c} = fullfile(in.pathdir,[in.img_name,'_',num2str(c),'.img']);
        end
    else
        % NOTE(review): with flag2 set and nimage>1 only one name is built
        % here, while the loops below run over nimage images - confirm
        % that this combination is not expected to occur.
        img_name = {fullfile(in.pathdir,[in.img_name,'.img'])};
    end
else
    img_name = cell(1,nimage);
    for c = 1:nimage
        img_name{c} = fullfile(in.pathdir,img_mach{c});
    end
end

% Other info
% -------------------------------------------------------------------------
fs_name  = PRT.model(model_idx).input.fs(1).fs_name;
samp_idx = PRT.model(model_idx).input.samp_idx;
nfold    = length(PRT.model(model_idx).output.fold);

% Find feature set (the last feature set with a matching name is used)
% -------------------------------------------------------------------------
fs_idx = find(strcmp({PRT.fs.fs_name},fs_name),1,'last');
if isempty(fs_idx)
    error('prt_compute_weights:FeatureSetNotFound',...
        'Error: feature set %s not found in PRT!',fs_name);
end
ID     = PRT.fs(fs_idx).id_mat(samp_idx,:);
ID_all = PRT.fs(fs_idx).id_mat;

% Find modality (now as inputs)
% -------------------------------------------------------------------------
fas_idx = in.fas_idx;
mm      = in.mm;

% Get the indexes of the voxels which are in the first/second level mask
% -------------------------------------------------------------------------
idROI  = [];
idfeat = PRT.fas(fas_idx(1)).idfeat_img;
fsmod  = PRT.fs(fs_idx).modality(mm(1));
if isempty(fsmod.idfeat_fas) % get the 2nd level masking
    idfeat_fas = 1:length(idfeat);
else
    idfeat_fas = fsmod.idfeat_fas;
end
if PRT.fs(fs_idx).multkernelROI
    % keep the image voxels of each ROI separately
    nroi    = length(fsmod.idfeat_img);
    m_train = cell(nroi,1);
    for i = 1:nroi
        tmp1       = fsmod.idfeat_img{i};
        idROI      = [idROI;tmp1]; %#ok<AGROW>
        m_train{i} = idfeat(idfeat_fas(tmp1));
    end
    id2 = idfeat_fas(sort(idROI));
else
    id2 = idfeat_fas;
end
mask_train = idfeat(id2);
voxtr      = find(ismember(idfeat,mask_train));

% Number of permutation images to build
% -------------------------------------------------------------------------
maxp = 0;
if flag
    if isfield(PRT.model(model_idx).output,'permutation') && ...
            ~isempty(PRT.model(model_idx).output.permutation)
        maxp = length(PRT.model(model_idx).output.permutation);
    else
        disp('No parameters saved for the permutation, building weight image only')
    end
end

% Quantities that do not change across permutations/slices/folds
% -------------------------------------------------------------------------
hdr     = PRT.fas(fas_idx(1)).hdr.private;
dat_dim = hdr.dat.dim;
if length(dat_dim)==2, dat_dim = [dat_dim 1]; end % handling case of 2D image
zdim    = dat_dim(3);
xydim   = dat_dim(1)*dat_dim(2);
ops     = PRT.model(model_idx).input.operations(PRT.model(model_idx).input.operations ~= 0);
if flag2
    m.args.flag = 1;
end

% Create image
% -------------------------------------------------------------------------
pthperm = cell(nimage,1);
for p = 0:maxp
    
    % File names and number of volumes for this pass
    if p > 0
        img_nam = cell(1,nimage);
        for c = 1:nimage
            [pth,nam] = fileparts(img_name{c});
            if p == 1
                pthperm{c} = fullfile(pth,['perm_',nam]);
                if ~exist(pthperm{c},'dir')
                    mkdir(pth,['perm_',nam]);
                end
            end
            img_nam{c} = fullfile(pthperm{c},[nam,'_perm',num2str(p),'.img']);
        end
        fprintf('Permutation: %d of %d \n',p,maxp);
        folds_comp = 1;          % save the average across folds only
    else
        img_nam    = img_name;
        folds_comp = nfold+1;    % save every fold plus the average
    end
    
    % Remove any previous image (and its header) with the same name
    for c = 1:nimage
        if exist(img_nam{c},'file')
            delete(img_nam{c});
        end
        [pth,nam] = fileparts(img_nam{c});
        hdr_name  = fullfile(pth,[nam,'.hdr']);
        if exist(hdr_name,'file')
            delete(hdr_name);
        end
    end
    
    img4d = cell(nimage,1);
    for c = 1:nimage
        img4d{c} = file_array(img_nam{c},[dat_dim(1),dat_dim(2),...
            dat_dim(3),folds_comp],'float32-le',0,1,0);
    end
    
    % Squared norms per slice; reset for every pass so values from a
    % previous permutation cannot leak into slices without features
    norm4d   = repmat({zeros(zdim,nfold)},nimage,1);
    norm4dav = repmat({zeros(zdim,1)},nimage,1);
    
    disp('Computing weights.......>>')
    
    for z = 1:zdim
        
        fprintf('Slice: %d of %d \n',z,zdim);
        
        % Features of the training mask lying in slice z
        vox_lo   = xydim*(z-1)+1;
        vox_hi   = xydim*z;
        in_slc   = mask_train>=vox_lo & mask_train<=vox_hi;
        feat_slc = find(in_slc);
        if ~isempty(idROI) %get indexes in each slice for each ROI
            slc_vox = mask_train(in_slc);
            for ir = 1:length(m_train)
                roi_vox = m_train{ir}(m_train{ir}>=vox_lo & m_train{ir}<=vox_hi);
                m.args.idfeat_img{ir} = find(ismember(slc_vox,roi_vox));
            end
        else
            m.args.idfeat_img = {1:length(feat_slc)};
        end
        
        if isempty(feat_slc)
            % nothing to compute in this slice
            for c = 1:nimage
                img4d{c}(:,:,z,:) = NaN(dat_dim(1),dat_dim(2),1,folds_comp);
            end
            continue
        end
        
        img3dav = repmat({zeros(1,xydim)},1,nimage); % average weight map
        norm3d  = zeros(nimage,nfold);               % squared norm per fold
        indi    = mask_train(feat_slc)-xydim*(z-1);  % in-slice mask voxels
        indout  = setdiff(1:xydim,indi);             % in-slice voxels outside mask
        
        for f = 1:nfold
            
            train_idx        = PRT.model(model_idx).input.cv_mat(:,f)==1;
            train            = samp_idx(train_idx);
            train_all        = zeros(size(ID_all,1),1);
            train_all(train) = 1;
            if p > 0
                d.coeffs = PRT.model(model_idx).output.permutation(p).fold(f).alpha;
            else
                d.coeffs = PRT.model(model_idx).output.fold(f).alpha;
            end
            
            % Load the training data of this slice
            d.datamat = zeros(length(train), length(feat_slc));
            for i = 1:length(fas_idx)
                % indexes to access the file array
                indm = find(PRT.fs(fs_idx).fas.im == fas_idx(i));
                if PRT.fs(fs_idx).multkernel
                    indtr = ID(train_idx,3) == fas_idx(1);
                    indm  = indm(find(train_all)); %#ok<FNDSB>
                else
                    indtr = ID(train_idx,3) == fas_idx(i);
                    indm  = indm(find(train_all(ID_all(:,3)==fas_idx(i)))); %#ok<FNDSB>
                end
                ifa = PRT.fs(fs_idx).fas.ifa(indm);
                
                % index for the target data matrix
                d.datamat(indtr,:) = PRT.fas(fas_idx(i)).dat(ifa,voxtr(feat_slc));
            end
            
            % Apply any operations specified during training
            cvdata.train      = {d.datamat};
            cvdata.tr_id      = ID(train_idx,:);
            cvdata.use_kernel = false; % need to apply the operation to the data
            for o = 1:length(ops)
                cvdata = prt_apply_operation(PRT, cvdata, ops(o));
            end
            d.datamat = cvdata.train{:};
            
            if strcmpi(mfunc,'prt_machine_sMKL_cla')
                if isempty(ibe)
                    m.args.betas = PRT.model(model_idx).output.fold(f).beta;
                else
                    m.args.betas = PRT.model(model_idx).output.fold(f).beta(ibe);
                end
            end
            
            % COMPUTE WEIGHTS
            wimg = prt_weights(d,m);
            
            for c = 1:nimage
                img3d         = zeros(1,xydim);
                img3d(indi)   = wimg{c};
                norm3d(c,f)   = sum(img3d.^2);
                img3d(indout) = NaN;
                img3dav{c}    = img3dav{c} + img3d;
                if p == 0
                    img4d{c}(:,:,z,f) = reshape(img3d,dat_dim(1),dat_dim(2),1,1);
                end
            end
            
        end
        
        % Create average fold
        %------------------------------------------------------------------
        for c = 1:nimage
            norm4d{c}(z,:)             = norm3d(c,:);
            img3dav{c}                 = img3dav{c}/nfold;
            img4d{c}(:,:,z,folds_comp) = reshape(img3dav{c},dat_dim(1),dat_dim(2),1,1);
            norm4dav{c}(z)             = sum(img3dav{c}(isfinite(img3dav{c})).^2);
        end
        
    end
    
    % Euclidean norms over the whole volume
    for c = 1:nimage
        norm4d{c}   = sqrt(sum(norm4d{c},1));   % one value per fold
        norm4dav{c} = sqrt(sum(norm4dav{c},1)); % average image
    end
    
    disp('Normalising weights--------->>')
    for c = 1:nimage
        if p == 0
            for f = 1:nfold
                if norm4d{c}(f) ~= 0
                    img4d{c}(:,:,:,f) = img4d{c}(:,:,:,f)./norm4d{c}(f);
                end
            end
        end
        % guard against an all-zero average weight map (division by zero)
        if norm4dav{c} ~= 0
            img4d{c}(:,:,:,folds_comp) = img4d{c}(:,:,:,folds_comp)./norm4dav{c};
        end
    end
    
    % Create weights file
    %-------------------------------------------------------------------------
    for c = 1:nimage
        fprintf('Creating image %d of %d--------->>\n',c,nimage);
        No         = hdr;              % copy header
        No.dat     = img4d{c};         % change file_array
        No.descrip = 'Pronto weigths'; % description
        create(No);                    % write header
        disp('Done.')
    end
end

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005