Home > . > prt_compute_weights_regre.m

prt_compute_weights_regre

PURPOSE ^

FORMAT img_name = prt_compute_weights_regre(PRT,in,model_idx,flag,ibe,flag2)

SYNOPSIS ^

function img_name = prt_compute_weights_regre(PRT,in,model_idx,flag, ibe, flag2)

DESCRIPTION ^

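 FORMAT img_name = prt_compute_weights_regre(PRT,in,model_idx,flag,ibe,flag2)

 This function calls prt_weights to compute the weight image(s) of an
 estimated regression model, fold by fold and slice by slice, and writes
 them to disk.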
 Inputs:
       PRT             - data/design/model structure (it needs to contain
                         at least one estimated model).
         in            - structure with specific information to create
                         weights
           .model_name - model name (string)
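           .img_name   - (optional) name of the file to be created
                         (string); if empty, 'weights_<model name>.img'
                         is used
           .pathdir    - directory path where to save weights (same as the
                         one for PRT.mat) (string)
           .fas_idx    - index(es) into PRT.fas of the data used by the
                         model
           .mm         - index(es) of the modality within the feature set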
         model_idx     - model index (integer)
         flag          - compute weight images for each permutation if 1
         ibe           - which beta to use for MKL and multiple modalities
         flag2         - build image of weights per region
 Output:
       img_name        - cell array with the full path(s) of the .img
                         file(s) created
       + image file(s) (.img data + .hdr header) created on disk
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function img_name = prt_compute_weights_regre(PRT,in,model_idx,flag, ibe, flag2)
% FORMAT img_name = prt_compute_weights_regre(PRT,in,model_idx,flag,ibe,flag2)
%
% Compute the weight image(s) of an estimated regression model. For every
% cross-validation fold the training data are read slice by slice from the
% file array(s), the operations used during training are re-applied, and
% prt_weights converts the fold's coefficients (output.fold(f).alpha) into
% voxel weights. Per-fold maps and their average across folds are
% normalised by their L2 norm and written to disk.
% Inputs:
%       PRT             - data/design/model structure (it needs to contain
%                         at least one estimated model).
%         in            - structure with specific information to create
%                         weights
%           .model_name - model name (string)
%           .img_name   - (optional) name of the file to be created
%                         (string); if absent or empty,
%                         'weights_<model name>.img' is used
%           .pathdir    - directory path where to save weights (same as the
%                         one for PRT.mat) (string)
%           .fas_idx    - index(es) into PRT.fas of the data used by the
%                         model (the first one provides header and masks)
%           .mm         - index(es) into PRT.fs(...).modality; only mm(1)
%                         is used to build the masks
%         model_idx     - model index (integer)
%         flag          - (optional, default 0) if 1, also compute weight
%                         images for each permutation stored in
%                         PRT.model(model_idx).output.permutation
%         ibe           - (optional, default []) which beta to use for MKL
%                         and multiple modalities; [] uses all of them
%         flag2         - (optional, default 0) build image of weights per
%                         region
% Output:
%       img_name        - cell array with the full path(s) of the .img
%                         file(s) created
%       + image file(s) (.img/.hdr) created on disk. For the unpermuted
%         model the 4th dimension holds one volume per fold followed by the
%         average across folds; permutation images (written to a
%         'perm_<name>' sub-directory) hold the average only.
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by M.J.Rosa
% $Id$

% Default optional arguments
% -------------------------------------------------------------------------
if nargin<4 || isempty(flag)
    flag = 0;
end
if nargin<5
    ibe = [];
end
if nargin<6 || isempty(flag2)
    flag2 = 0;
end

% Find machine
% -------------------------------------------------------------------------
mfunc  = PRT.model(model_idx).input.machine.function;
mname  = PRT.model(model_idx).model_name;
m.args = [];

% Weights are written as .img/.hdr pairs: according to the original
% authors, a bug somewhere causes shifts in the weight image if .nii is used.
switch mfunc
    case 'prt_machine_sMKL_reg'
        m.function = 'prt_weights_sMKL_reg';
    case 'prt_machine_RT_bin'
        error('prt_compute_weights:MachineNotSupported',...
            'Error: weights computation not supported for this machine!');
    otherwise
        m.function = 'prt_weights_bin_linkernel';
end
img_mach = {['weights_',mname,'.img']};
nimage   = numel(img_mach);

% Image name
% -------------------------------------------------------------------------
if isfield(in,'img_name') && ~isempty(in.img_name)
    if ~prt_checkAlphaNumUnder(in.img_name)
        error('prt_compute_weights:NameNotAlphaNumeric',...
            'Error: image name should contain only alpha-numeric elements!');
    end
    if nimage>1 && ~flag2
        img_name = cell(1,nimage);
        for c = 1:nimage
            img_name{c} = fullfile(in.pathdir,[in.img_name,'_',num2str(c),'.img']);
        end
    else
        img_name = {fullfile(in.pathdir,[in.img_name,'.img'])};
    end
else
    img_name = cell(1,nimage);
    for c = 1:nimage
        img_name{c} = fullfile(in.pathdir,img_mach{c});
    end
end

% Other info
% -------------------------------------------------------------------------
fs_name  = PRT.model(model_idx).input.fs(1).fs_name;
samp_idx = PRT.model(model_idx).input.samp_idx;
nfold    = length(PRT.model(model_idx).output.fold);

% Find feature set (the last one with a matching name, as before)
% -------------------------------------------------------------------------
fs_idx = find(strcmp({PRT.fs.fs_name},fs_name),1,'last');
if isempty(fs_idx)
    error('prt_compute_weights:FeatureSetNotFound',...
        'Error: feature set %s not found in PRT.fs!',fs_name);
end
ID     = PRT.fs(fs_idx).id_mat(samp_idx,:);
ID_all = PRT.fs(fs_idx).id_mat;

% Find modality (now as inputs)
% -------------------------------------------------------------------------
fas_idx = in.fas_idx;
mm      = in.mm;

% Get the indexes of the voxels which are in the first/second level mask
% -------------------------------------------------------------------------
idROI    = [];
idfeat   = PRT.fas(fas_idx(1)).idfeat_img;   % voxels of the 1st level mask
modality = PRT.fs(fs_idx).modality(mm(1));
if isempty(modality.idfeat_fas)               % no 2nd level masking
    idfeat_fas = 1:length(idfeat);
else
    idfeat_fas = modality.idfeat_fas;
end
if PRT.fs(fs_idx).multkernelROI
    % One kernel per region: keep the image voxels of each ROI
    nroi    = length(modality.idfeat_img);
    m_train = cell(nroi,1);
    for i = 1:nroi
        tmp1       = modality.idfeat_img{i};
        idROI      = [idROI;tmp1]; %#ok<AGROW>
        m_train{i} = idfeat(idfeat_fas(tmp1));
    end
    id2 = idfeat_fas(sort(idROI));
else
    id2 = idfeat_fas;
end
mask_train = idfeat(id2);                         % image voxels used in training
voxtr      = find(ismember(idfeat,mask_train));   % their columns in the file array

% Number of permutation images to build (0 = true weight image only)
% -------------------------------------------------------------------------
maxp = 0;
if flag
    if isfield(PRT.model(model_idx).output,'permutation') && ...
            ~isempty(PRT.model(model_idx).output.permutation)
        maxp = length(PRT.model(model_idx).output.permutation);
    else
        disp('No parameters saved for the permutation, building weight image only')
    end
end

% Loop-invariant quantities
% -------------------------------------------------------------------------
hdr     = PRT.fas(fas_idx(1)).hdr.private;
dat_dim = hdr.dat.dim;
if length(dat_dim)==2, dat_dim = [dat_dim 1]; end % handling case of 2D image
zdim    = dat_dim(3);
xydim   = dat_dim(1)*dat_dim(2);
ops     = PRT.model(model_idx).input.operations(PRT.model(model_idx).input.operations ~=0 );
if flag2
    m.args.flag = 1;
end

% Create image(s)
% -------------------------------------------------------------------------
pthperm = cell(nimage,1);
for p = 0:maxp
    if p>0
        % Permutation images go to a 'perm_<name>' sub-directory
        img_nam = cell(1,nimage);
        for c = 1:nimage
            [pth,nam] = fileparts(img_name{c});
            if p==1
                pthperm{c} = fullfile(pth,['perm_',nam]);
                if ~exist(pthperm{c},'dir')
                    mkdir(pth,['perm_',nam]);
                end
            end
            img_nam{c} = fullfile(pthperm{c},[nam,'_perm',num2str(p),'.img']);
        end
        fprintf('Permutation: %d of %d \n',p,maxp);
        folds_comp = 1;        % save the average across folds only
    else
        img_nam    = img_name;
        folds_comp = nfold+1;  % one volume per fold + the average
    end

    % Remove any previous image/header with the same name
    for c = 1:nimage
        if exist(img_nam{c},'file')
            delete(img_nam{c});
        end
        [pth,nam] = fileparts(img_nam{c});
        hdr_name  = fullfile(pth,[nam,'.hdr']);
        if exist(hdr_name,'file')
            delete(hdr_name);
        end
    end

    img4d = cell(nimage,1);
    for c = 1:nimage
        img4d{c} = file_array(img_nam{c},[dat_dim(1),dat_dim(2),...
            dat_dim(3),folds_comp],'float32-le',0,1,0);
    end

    % Per-slice sums of squared weights, reset for every image so that
    % values from a previous permutation cannot leak into this one.
    norm4d   = cell(1,nimage);
    norm4dav = cell(1,nimage);
    for c = 1:nimage
        norm4d{c}   = zeros(zdim,nfold);
        norm4dav{c} = zeros(zdim,1);
    end

    disp('Computing weights.......>>')

    for z = 1:zdim

        fprintf('Slice: %d of %d \n',z,zdim);

        % Training features whose image voxel lies in slice z
        in_slice = mask_train>=(xydim*(z-1)+1) & mask_train<=(xydim*z);
        feat_slc = find(in_slice);
        if ~isempty(idROI)
            % Position of each ROI's voxels within this slice's features
            slc_vox = mask_train(in_slice);
            idf     = cell(1,length(m_train));
            for ir = 1:length(m_train)
                roi_vox = m_train{ir}(m_train{ir}>=(xydim*(z-1)+1) & ...
                    m_train{ir}<=(xydim*z));
                idf{ir} = find(ismember(slc_vox,roi_vox));
            end
            m.args.idfeat_img = idf;
        else
            m.args.idfeat_img = {1:length(feat_slc)};
        end

        if isempty(feat_slc)
            % Nothing to compute: the whole slice is outside the mask
            for c = 1:nimage
                img4d{c}(:,:,z,:) = NaN(dat_dim(1),dat_dim(2),1,folds_comp);
            end
            continue
        end

        img3dav = cell(1,nimage);
        norm3d  = cell(1,nimage);
        for c = 1:nimage
            img3dav{c} = zeros(1,xydim);   % average weight map
            norm3d{c}  = zeros(1,nfold);
        end
        indi = mask_train(feat_slc)-xydim*(z-1);  % in-slice voxel positions
        outm = setdiff(1:xydim,indi);             % in-slice voxels outside the mask

        for f = 1:nfold

            train_idx = PRT.model(model_idx).input.cv_mat(:,f)==1;
            train     = samp_idx(train_idx);
            train_all = zeros(size(ID_all,1),1);
            train_all(train) = 1;
            if p>0
                d.coeffs = PRT.model(model_idx).output.permutation(p).fold(f).alpha;
            else
                d.coeffs = PRT.model(model_idx).output.fold(f).alpha;
            end

            % Load the training data of this slice from the file array(s)
            d.datamat = zeros(length(train), length(feat_slc));
            for i = 1:length(fas_idx)
                % indexes to access the file array
                indm = find(PRT.fs(fs_idx).fas.im == fas_idx(i));
                if PRT.fs(fs_idx).multkernel
                    % NOTE(review): rows are selected with fas_idx(1) for
                    % every modality i here - confirm this is intended.
                    indtr = ID(train_idx,3) == fas_idx(1);
                    indm  = indm(find(train_all));
                else
                    indtr = ID(train_idx,3) == fas_idx(i);
                    indm  = indm(find(train_all(ID_all(:,3)==fas_idx(i))));
                end
                ifa = PRT.fs(fs_idx).fas.ifa(indm);

                % index for the target data matrix
                d.datamat(indtr,:) = PRT.fas(fas_idx(i)).dat(ifa,voxtr(feat_slc));
            end

            % Apply any operations specified during training
            cvdata.train      = {d.datamat};
            cvdata.tr_id      = ID(train_idx,:);
            cvdata.use_kernel = false; % need to apply the operation to the data
            for o = 1:length(ops)
                cvdata = prt_apply_operation(PRT, cvdata, ops(o));
            end
            d.datamat = cvdata.train{:};

            if strcmpi(mfunc,'prt_machine_sMKL_reg')
                % NOTE(review): betas come from the unpermuted model even
                % when p>0 - confirm whether permuted betas should be used.
                if isempty(ibe)
                    m.args.betas = PRT.model(model_idx).output.fold(f).beta;
                else
                    m.args.betas = PRT.model(model_idx).output.fold(f).beta(ibe);
                end
            end

            % COMPUTE WEIGHTS
            wimg = prt_weights(d,m);

            for c = 1:nimage
                img3d        = zeros(1,xydim);
                img3d(indi)  = wimg{c};
                norm3d{c}(f) = sum(img3d.^2);
                img3d(outm)  = NaN;
                img3dav{c}   = img3dav{c} + img3d;
                if p==0
                    img4d{c}(:,:,z,f) = reshape(img3d,dat_dim(1),dat_dim(2),1,1);
                end
            end
        end

        % Create average fold
        %------------------------------------------------------------------
        for c = 1:nimage
            norm4d{c}(z,:)             = norm3d{c};
            img3dav{c}                 = img3dav{c}/nfold;
            img4d{c}(:,:,z,folds_comp) = reshape(img3dav{c},dat_dim(1),dat_dim(2),1,1);
            norm4dav{c}(z)             = sum(img3dav{c}(isfinite(img3dav{c})).^2);
        end
    end

    % Normalise each map by its L2 norm over the whole volume (maps whose
    % norm is zero are left unchanged instead of becoming NaN)
    disp('Normalising weights--------->>')
    for c = 1:nimage
        nrm   = sqrt(sum(norm4d{c},1));   % one value per fold
        nrmav = sqrt(sum(norm4dav{c}));   % norm of the average map
        if p==0
            for f = 1:nfold
                if nrm(f)~=0
                    img4d{c}(:,:,:,f) = img4d{c}(:,:,:,f)./nrm(f);
                end
            end
        end
        if nrmav~=0
            img4d{c}(:,:,:,folds_comp) = img4d{c}(:,:,:,folds_comp)./nrmav;
        end
    end

    % Create weigths file
    %-------------------------------------------------------------------------
    for c = 1:nimage
        fprintf('Creating image %d of %d--------->>\n',c,nimage);
        No         = hdr;              % copy header
        No.dat     = img4d{c};         % change file_array
        No.descrip = 'Pronto weigths'; % description
        create(No);                    % write header
        disp('Done.')
    end
end

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005