Home > . > prt_compute_weights.m

prt_compute_weights

PURPOSE ^

FORMAT img_name = prt_compute_weights(PRT,in,flag,flag2)

SYNOPSIS ^

function img_name = prt_compute_weights(PRT,in,flag,flag2)

DESCRIPTION ^

 FORMAT img_name = prt_compute_weights(PRT,in,flag,flag2)

 This function calls prt_compute_weights_class (classification models) or
 prt_compute_weights_regre (regression models) to compute weight images
 Inputs:
       PRT             - data/design/model structure (it needs to contain
                         at least one estimated model).
       in              - structure with specific information to create
                         weights
           .model_name - model name (string)
           .img_name   - (optional) name of the file to be created
                         (string)
           .pathdir    - directory path where to save weights (same as the
                         one for PRT.mat) (string)
           .atl_name   - name of the atlas for post-hoc local averages of
                         the weights per region (string); required when
                         flag2 is set, unless the kernels were built per ROI
       flag            - set to 1 to compute the weight images for each
                         permutation (default: 0)
       flag2           - set to 1 to also build images of the weights
                         per region (ROI) according to the atlas
 Output:
       img_name        - name(s) of the image file(s) created (cell array)
       + image file(s) created on disk, and PRT.mat (in in.pathdir) updated
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 function img_name = prt_compute_weights(PRT,in,flag,flag2)
0002 % FORMAT img_name = prt_compute_weights(PRT,in,flag,flag2)
0003 %
0004 % This function calls prt_compute_weights_class (classification models)
0005 % or prt_compute_weights_regre (regression models) to compute the weight
0006 % images and, if asked for, summarizes the weights per region (ROI).
0007 % Inputs:
0008 %       PRT             - data/design/model structure (it needs to contain
0009 %                         at least one estimated model).
0010 %       in              - structure with specific information to create
0011 %                         weights
0012 %           .model_name - model name (string)
0013 %           .img_name   - (optional) name of the file to be created
0014 %                         (string)
0015 %           .pathdir    - directory path where to save weights (same as the
0016 %                         one for PRT.mat) (string)
0017 %           .atl_name   - (optional) name of the atlas used for post-hoc
0018 %                         local averages of the weights per region
0019 %                         (string). Required when flag2 is set, unless
0020 %                         the kernels were built per ROI (multkernelROI).
0021 %       flag            - set to 1 to compute the weight images for each
0022 %                         permutation (default: 0)
0023 %       flag2           - set to 1 to also build images of the weights
0024 %                         per region (ROI) according to the atlas
0025 %                         (default: 0)
0026 % Output:
0027 %       img_name        - cell array with the name(s) of the image
0028 %                         file(s) returned by the last call to
0029 %                         prt_compute_weights_class/_regre
0030 %       + image file(s) created on disk, and PRT.mat (in in.pathdir)
0031 %         updated with the weight image names and ROI/modality weights
0032 %__________________________________________________________________________
0033 % Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory
0034 
0035 % Written by M.J.Rosa
0036 % $Id$
0037 
0038 % Optional arguments: both flags default to 0, as documented above
0039 % -------------------------------------------------------------------------
0040 if nargin < 3 || isempty(flag),  flag  = 0; end
0041 if nargin < 4 || isempty(flag2), flag2 = 0; end
0042 
0043 % Find model (if several models share the name, the last one is used)
0044 % -------------------------------------------------------------------------
0045 nmodel = length(PRT.model);
0046 model_idx = 0;
0047 for i = 1:nmodel
0048     if strcmp(PRT.model(i).model_name,in.model_name)
0049         model_idx = i;
0050     end
0051 end
0052 % Check if model exists
0053 if model_idx == 0, error('prt_compute_weights:ModelNotFound',...
0054         'Error: model not found in PRT.mat!'); end
0055 
0056 mtype = PRT.model(model_idx).input.type;
0057 mname = PRT.model(model_idx).model_name;
0058 
0059 % Initialize: get feature set and modalities indexes and deal with MK
0060 % -------------------------------------------------------------------
0061 % Get index of feature set
0062 fs_name  = PRT.model(model_idx).input.fs.fs_name;
0063 nfs = length(PRT.fs);
0064 fs_idx = 0;
0065 for f = 1:nfs
0066     if strcmp(PRT.fs(f).fs_name,fs_name)
0067         fs_idx = f;
0068     end
0069 end
0070 % Check if feature set exists (fs_idx is used to index PRT.fs below)
0071 if fs_idx == 0, error('prt_compute_weights:FeatureSetNotFound',...
0072         'Error: feature set of the model not found in PRT.mat!'); end
0073 
0074 % Find modality
0075 % fas(i)  = 1 if file array i matches one of the feature set modalities
0076 % mm(i,j) = 1 if file array i matches modality j of the feature set
0077 nfas = length(PRT.fas);
0078 mods = {PRT.fs(fs_idx).modality.mod_name};
0079 fas  = zeros(1,nfas);
0080 mm   = zeros(nfas,length(mods)); % rows: file arrays, columns: modalities
0081 for i = 1:nfas
0082     for j = 1:length(mods)
0083         if strcmpi(PRT.fas(i).mod_name,mods{j})
0084             fas(i) = 1;
0085             mm(i,j)= 1;
0086         end
0087     end
0088 end
0089 fas_idx = find(fas);
0090 
0091 % Loop over the different feature sets if they were considered as separate
0092 % kernels (i.e. one or more kernel(s) per modality).
0093 % ibeta_mod{i} holds the indexes of the betas (kernel weights) of modality i
0094 ibeta_mod = cell(length(fas_idx),1);
0095 if PRT.fs(fs_idx).multkernelROI   %multiple ROI kernels in feature set
0096     mult_kern_ROI = 1;
0097     if PRT.fs(fs_idx).multkernel % Multiple modalities treated separately
0098         count = 0;
0099         % get the indexes of the betas for each modality
0100         for i=1:length(fas_idx)
0101             numk = length(PRT.fs(fs_idx).modality(i).idfeat_img);
0102             ibeta_mod{i} = (1:numk)+count;
0103             count = count + numk;
0104         end
0105     else % Multiple modalities concatenated or only one modality
0106         ibeta_mod{1} = 1:length(PRT.fs(fs_idx).modality(1).idfeat_img);
0107     end
0108     nim = length(fas_idx);
0109 else
0110     if PRT.fs(fs_idx).multkernel % Multiple modalities treated separately
0111         for i=1:length(fas_idx)
0112             ibeta_mod{i} = i;
0113         end
0114         nim = length(fas_idx);
0115     else
0116         nim = 1;
0117     end
0118     mult_kern_ROI = 0;
0119 end
0120 
0121 % We also need to know whether those multiple kernels have been added in a
0122 % non-MKL machine or if a MKL machine was used.
0123 if ~isfield(PRT.model(model_idx).output.fold(1),'beta') || ...
0124         isempty(PRT.model(model_idx).output.fold(1).beta)
0125     added = 1;
0126 else
0127     added = 0;
0128 end
0129 
0130 % Compute the total number of images to be computed to initialize the
0131 % outputs
0132 switch mtype
0133     case 'classification'
0134         nc = size(PRT.model(model_idx).output.stats.con_mat, 2);
0135     case 'regression'
0136         nc = 1;
0137     otherwise
0138         error('prt_compute_weights:UnknownModelType',...
0139             'Error: model type should be classification or regression!');
0140 end
0141 if nc > 2
0142     nim = nim*nc;
0143 end
0144 
0145 % Check inputs for weights per region: an atlas is needed unless the
0146 % kernels were already built per region (multkernelROI)
0147 if flag2 && ~mult_kern_ROI && ...
0148         (~isfield(in,'atl_name') || isempty(in.atl_name))
0149     error('prt_compute_weights:NoAtlas',...
0150         'Error: Atlas should be provided to compute weights per region')
0151 end
0152 
0153 
0154 % Build weights
0155 %--------------------------------------------------------------------------
0156 % Clear any existing per-region outputs of this model
0157 if isfield(PRT.model(model_idx).output,'weight_idfeatroi') && ...
0158         ~isempty(PRT.model(model_idx).output.weight_idfeatroi)
0159     PRT.model(model_idx).output.weight_idfeatroi =[];
0160 end
0161 
0162 if isfield(PRT.model(model_idx).output,'weight_atlas') && ...
0163         ~isempty(PRT.model(model_idx).output.weight_atlas)
0164     PRT.model(model_idx).output.weight_atlas ={};
0165 end
0166 PRT.model(model_idx).output.weight_ROI = cell(nim,1);
0167 
0168 if PRT.fs(fs_idx).multkernel && length(fas_idx)>1 && ~added % Need to loop over the modalities since multiple kernels
0169     summroi  = 0;
0170     %get/set image names by appending the modality name at the end
0171     im_name = cell(1,length(fas_idx));
0172     if isfield(in,'img_name') && ~isempty(in.img_name)
0173         if ~(prt_checkAlphaNumUnder(in.img_name))
0174             error('prt_compute_weights:NameNotAlphaNumeric',...
0175                 'Error: image name should contain only alpha-numeric elements!');
0176         end
0177         for i = 1:length(fas_idx)
0178             im_name{i} = [in.img_name,'_',PRT.fas(fas_idx(i)).mod_name];
0179         end
0180     else
0181         for i = 1:length(fas_idx)
0182             im_name{i} = ['weights_',mname,'_',PRT.fas(fas_idx(i)).mod_name];
0183         end
0184     end
0185     
0186     % Get the indexes in the feature set and ID mat for each modality
0187     ifa_all = PRT.fs(fs_idx).fas.ifa;
0188     im_all = PRT.fs(fs_idx).fas.im;
0189     name_fin = [];
0190     
0191     % Prepare outputs
0192     PRT.model(model_idx).output.weight_ROI = cell(nim,1);
0193     if flag2 && ~mult_kern_ROI
0194         PRT.model(model_idx).output.weight_idfeatroi = cell(nim,1);
0195         PRT.model(model_idx).output.weight_atlas = cell(nim,1);
0196     end
0197     
0198     imgcnt = 1;
0199     
0200     for i = 1:length(fas_idx)
0201         in.img_name = im_name{i};
0202         in.fas_idx = fas_idx(i);
0203         in.mm = find(mm(fas_idx(i),:));
0204         %Modify inputs according to file array and modality
0205         PRT.fs(fs_idx).id_mat(:,3) = in.fas_idx * ones(size(PRT.fs(fs_idx).id_mat,1),1);
0206         PRT.fs(fs_idx).fas.im = im_all(im_all == fas_idx(i));
0207         PRT.fs(fs_idx).fas.ifa = ifa_all(im_all == fas_idx(i));
0208         switch mtype
0209             case 'classification'
0210                 
0211                 % Compute image of voxel weights
0212                 img_name = prt_compute_weights_class(PRT,in,model_idx,flag,ibeta_mod{i});
0213                     
0214                 % Get the image names (multiple classes possible)
0215                 name_f = cell(length(img_name),1);
0216                 for j=1:size(name_f,1)
0217                     [du,name_f{j}] = spm_fileparts(img_name{j});
0218                 end
0219                 
0220                 % Build image of weights per region if asked for (flag2==1)
0221                 if flag2
0222                     
0223                     if mult_kern_ROI % Kernels built from an atlas directly
0224                         disp('Building image of weights per region')
0225                         if length(name_f)>1 % multiple classes
0226                             % drop the last 2 characters of the last
0227                             % class image name (presumably a per-class
0228                             % suffix - TODO confirm)
0229                             in.img_name = ['ROI_',name_f{end}(1:end-2)];
0230                         else
0231                             in.img_name = ['ROI_',name_f{1}];
0232                         end
0233                         prt_compute_weights_class(PRT,in,model_idx,flag,ibeta_mod{i},1);
0234                         
0235                     else % Need to summarize the weights per region
0236                         disp('Building image of weights per region')
0237                         in.flag = flag;
0238                         summroi  = 1;
0239                         nimage = size(name_f,1); % Multiclass?
0240                         for c = 1:nimage
0241                             if c>1
0242                                 imgcnt = imgcnt + 1;
0243                             end
0244                             [NW idfeatroi] = prt_build_region_weights(img_name(c),in.atl_name,1,in.flag);
0245                             PRT.model(model_idx).output.weight_ROI(imgcnt) = {NW};
0246                             PRT.model(model_idx).output.weight_idfeatroi(imgcnt) = {idfeatroi};
0247                             PRT.model(model_idx).output.weight_atlas{imgcnt} = in.atl_name;
0248                         end
0249                     end
0250                 end
0251             case 'regression'
0252                 % Compute image of voxel weights
0253                 img_name = prt_compute_weights_regre(PRT,in,model_idx,flag,ibeta_mod{i});
0254                     
0255                 % Get the image names
0256                 [du,name_f{1}] = spm_fileparts(img_name{1});
0257                 
0258                 % Build image of weights per region if asked for (flag2==1)
0259                 if flag2
0260                     
0261                     if mult_kern_ROI % Kernels built from an atlas directly
0262                         disp('Building image of weights per region')
0263                         in.img_name = ['ROI_',name_f{1}];
0264                         prt_compute_weights_regre(PRT,in,model_idx,flag,ibeta_mod{i},1);
0265                         
0266                     else % Need to summarize the weights per region
0267                         disp('Building image of weights per region')
0268                         in.flag = flag;
0269                         summroi = 1;
0270                         [NW idfeatroi] = prt_build_region_weights(img_name,in.atl_name,1,in.flag);
0271                         PRT.model(model_idx).output.weight_ROI(imgcnt) = {NW};
0272                         PRT.model(model_idx).output.weight_idfeatroi(imgcnt) = {idfeatroi};
0273                         PRT.model(model_idx).output.weight_atlas{imgcnt} = in.atl_name;
0274                     end
0275                 end
0276         end
0277         if ~iscell(img_name)
0278             img_name={img_name};
0279         end
0280         name_fin = [name_fin; img_name];
0281         imgcnt = imgcnt + 1;
0282     end
0283     % Reset the feature set fields modified in the loop above
0284     PRT.fs(fs_idx).fas.ifa = ifa_all;
0285     PRT.fs(fs_idx).fas.im = im_all;
0286     PRT.fs(fs_idx).id_mat(:,3) = ones(size(PRT.fs(fs_idx).id_mat,1),1);
0287     
0288     % Used for the display of the weights per modality in
0289     % prt_ui_disp_weights
0290     if PRT.fs(fs_idx).multkernel && ~summroi    %create one image per modality, from MKL learning
0291         for i=1:size(name_fin,1)
0292             [du,name_fin{i}] = spm_fileparts(name_fin{i});
0293             if ~mult_kern_ROI
0294                 idb = 1:length(fas_idx);
0295             else
0296                 idb = ibeta_mod{i};
0297             end
0298             tmp = zeros(length(idb),length(PRT.model(model_idx).output.fold));
0299             for j = 1:length(PRT.model(model_idx).output.fold)
0300                 tmp(:,j) = [PRT.model(model_idx).output.fold(j).beta(idb)]';
0301             end
0302             betas = [tmp, mean(tmp,2)];
0303             if ~flag2 && ~mult_kern_ROI
0304                 PRT.model(model_idx).output.weight_ROI(i) = {betas}; % for now, replicate the betas for each modality and fill table
0305                 PRT.model(model_idx).output.weight_MOD(i) = {betas};
0306             elseif flag2 && mult_kern_ROI
0307                 PRT.model(model_idx).output.weight_ROI(i) = {betas}; % for now, replicate the betas for each modality and fill table
0308                 PRT.model(model_idx).output.weight_MOD(i) = {sum(betas,1)}; % sum the betas across regions for each modality
0309             end
0310         end
0311     else
0312         if PRT.fs(fs_idx).multkernel && summroi
0313             for i=1:size(name_fin,1)
0314                 idb = ibeta_mod{i};
0315                 tmp = zeros(length(idb),length(PRT.model(model_idx).output.fold));
0316                 for j = 1:length(PRT.model(model_idx).output.fold)
0317                     tmp(:,j) = [PRT.model(model_idx).output.fold(j).beta(idb)]';
0318                 end
0319                 betas = [tmp, mean(tmp,2)];
0320                 PRT.model(model_idx).output.weight_MOD(i) = {betas}; %average of a multiple kernel on modalities
0321             end
0322         end
0323         for i=1:size(name_fin,1)
0324             [du,name_fin{i}] = spm_fileparts(name_fin{i}); %get rid of path
0325         end
0326     end
0327     
0328  % Only one modality or they have been concatenated
0329 else
0330     in.fas_idx=fas_idx;
0331     in.mm = [];
0332     for i=1:length(fas_idx)
0333         in.mm = [in.mm, find(mm(fas_idx(i),:))];
0334     end
0335     switch mtype
0336         case 'classification'
0337             img_name = prt_compute_weights_class(PRT,in,model_idx,flag);
0338             name_fin = cell(length(img_name),1);
0339             for i=1:length(name_fin)
0340                 [du,name_fin{i}] = spm_fileparts(img_name{i}); 
0341             end
0342             if flag2 % Build image of weights per region
0343                 disp('Building image of weights per region')
0344 
0345                 if mult_kern_ROI && ...
0346                         isfield(PRT.model(model_idx).output.fold(1),'beta') && ...
0347                         ~isempty(PRT.model(model_idx).output.fold(1).beta)
0348                     
0349                     if length(name_fin)>1 % multiple classes
0350                         % drop the last 2 characters of the last class
0351                         % image name (presumably a per-class suffix -
0352                         % TODO confirm)
0353                         in.img_name = ['ROI_',name_fin{end}(1:end-2)];
0354                     else
0355                         in.img_name = ['ROI_',name_fin{1}];
0356                     end
0357                     prt_compute_weights_class(PRT,in,model_idx,flag,[],1);
0358                     % Get the weights per region, which are the same for
0359                     % each class
0360                     tmp = [PRT.model(model_idx).output.fold(:).beta];
0361                     tmp = reshape(tmp,length(PRT.model(model_idx).output.fold(1).beta),...
0362                         length(PRT.model(model_idx).output.fold));
0363                     betas = [tmp, mean(tmp,2)];
0364                     for i = 1:size(name_fin,1)
0365                         PRT.model(model_idx).output.weight_ROI(i) = {betas};
0366                     end
0367                 else
0368                     in.flag = flag;
0369                     if mult_kern_ROI && (~isfield(in,'atl_name') || isempty(in.atl_name))
0370                         in.atl_name = PRT.fs(fs_idx).atlas_name;
0371                     end
0372                     nimage = size(name_fin,1); % Multiclass?
0373                     PRT.model(model_idx).output.weight_ROI = cell(nimage,1);
0374                     for c = 1:nimage
0375                         [NW idfeatroi] = prt_build_region_weights(img_name(c),in.atl_name,1,in.flag);
0376                         PRT.model(model_idx).output.weight_ROI(c) = {NW};
0377                     end
0378                     PRT.model(model_idx).output.weight_idfeatroi{1} = idfeatroi;
0379                     PRT.model(model_idx).output.weight_atlas{1} = in.atl_name;
0380                 end
0381             else
0382                 PRT.model(model_idx).output.weight_ROI = [];
0383             end
0384         case 'regression'
0385             img_name = prt_compute_weights_regre(PRT,in,model_idx,flag);
0386             name_fin = cell(length(img_name),1);
0387             for i=1:length(name_fin)
0388                 [du,name_fin{i}] = spm_fileparts(img_name{i}); 
0389             end
0390             if flag2 % Build image of weights per region
0391                 if mult_kern_ROI && ...
0392                         isfield(PRT.model(model_idx).output.fold(1),'beta') && ...
0393                         ~isempty(PRT.model(model_idx).output.fold(1).beta)
0394                     disp('Building image of weights per region')
0395                     in.img_name = ['ROI_',name_fin{1}];
0396                     prt_compute_weights_regre(PRT,in,model_idx,flag,[],1);
0397                     tmp = [PRT.model(model_idx).output.fold(:).beta];
0398                     tmp = reshape(tmp,length(PRT.model(model_idx).output.fold(1).beta),...
0399                         length(PRT.model(model_idx).output.fold));
0400                     betas = [tmp, mean(tmp,2)];
0401                     PRT.model(model_idx).output.weight_ROI(1) = {betas}; %only one class for now
0402                 else
0403                     disp('Building image of weights per region')
0404                     in.flag = flag;
0405                     if mult_kern_ROI && (~isfield(in,'atl_name') || isempty(in.atl_name))
0406                         in.atl_name = PRT.fs(fs_idx).atlas_name;
0407                     end
0408                     [NW idfeatroi] = prt_build_region_weights(img_name,in.atl_name,1,in.flag);
0409                     PRT.model(model_idx).output.weight_ROI(1) = {NW};
0410                     PRT.model(model_idx).output.weight_idfeatroi{1} = idfeatroi;
0411                     PRT.model(model_idx).output.weight_atlas{1} = in.atl_name;
0412                 end
0413             else
0414                 PRT.model(model_idx).output.weight_ROI = [];
0415             end
0416     end
0417 end
0418 
0419 if ~iscell(name_fin)
0420     name_fin = {name_fin};
0421 end
0422 PRT.model(model_idx).output.weight_img = name_fin;
0423 
0424 % Save the updated PRT
0425 %--------------------------------------------------------------------------
0426 outfile = fullfile(in.pathdir, 'PRT.mat');
0427 disp('Updating PRT.mat.......>>')
0428 if spm_check_version('MATLAB','7') < 0
0429     save(outfile,'-V6','PRT');
0430 else
0431     save(outfile,'PRT');
0432 end
0433 end
0434 
0435

Generated on Tue 10-Feb-2015 18:16:33 by m2html © 2005