Home > . > prt_compute_weights.m

prt_compute_weights

PURPOSE

FORMAT prt_compute_weights(PRT,in)

SYNOPSIS

function prt_compute_weights(PRT,in)

DESCRIPTION

 FORMAT prt_compute_weights(PRT,in)

 This function calls prt_weights to compute the weight images of an estimated model: one volume per fold, plus a final volume holding their average.
 Inputs:
       PRT             - data/design/model structure (it needs to contain
                         at least one estimated model).
         in            - structure with specific information to create
                         weights
           .model_name - model name (string)
           .img_name   - (optional) name of the file to be created
                         (string)
           .pathdir    - directory path where to save weights (same as the
                         one for PRT.mat) (string)
 Output:
       empty           - does not return anything (it creates an .img file)
__________________________________________________________________________
 Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

CROSS-REFERENCE INFORMATION

This function calls: This function is called by:

SOURCE CODE

function prt_compute_weights(PRT,in)
% FORMAT prt_compute_weights(PRT,in)
%
% This function calls prt_weights to compute the weight images of an
% estimated model. The 4D output holds one volume per fold, each scaled to
% unit Euclidean norm over the whole volume. A final volume holds the
% average of the fold weights, taken before that normalisation.
% Inputs:
%       PRT             - data/design/model structure (it needs to contain
%                         at least one estimated model).
%         in            - structure with specific information to create
%                         weights
%           .model_name - model name (string)
%           .img_name   - (optional) name of the file to be created
%                         (string, without extension; '.img' is appended).
%                         If the field is absent or empty,
%                         'weights_<model_name>.img' is used.
%           .pathdir    - directory path where to save weights (same as the
%                         one for PRT.mat) (string)
% Output:
%       empty           - does not return anything (it creates an .img file)
%
% Error identifiers:
%   prt_compute_weights:ModelNotFound       - in.model_name not in PRT.model
%   prt_compute_weights:MachineNotSupported - machine has no weight support
%   prt_compute_weights:NameNotAlphaNumeric - in.img_name is not valid
%   prt_compute_weights:FeatureSetNotFound  - model's feature set not in PRT.fs
%   prt_compute_weights:ModalityNotFound    - no PRT.fas entry matches a
%                                             modality of the feature set
%__________________________________________________________________________
% Copyright (C) 2011 Machine Learning & Neuroimaging Laboratory

% Written by M.J.Rosa
% $Id: prt_compute_weights.m 526 2012-05-16 10:40:55Z amarquan $

% Find model
% -------------------------------------------------------------------------
% If several models share the requested name, the last one is used.
nmodel    = length(PRT.model);
model_idx = 0;
for i = 1:nmodel
    if strcmp(PRT.model(i).model_name,in.model_name)
        model_idx = i;
    end
end
if model_idx == 0
    error('prt_compute_weights:ModelNotFound',...
        'Error: model not found in PRT.mat!');
end
model = PRT.model(model_idx);

% Find machine
% -------------------------------------------------------------------------
% Weights are always computed with the linear-kernel routine, from the
% per-fold coefficients stored in model.output.fold(f).alpha.
mfunc      = model.input.machine.function;
mname      = model.model_name;
m.args     = [];
m.function = 'prt_weights_bin_linkernel';

switch mfunc
    case 'prt_machine_RT_bin'
        error('prt_compute_weights:MachineNotSupported',...
            'Error: weights computation not supported for this machine!');
    case 'prt_machine_gpclap'
        error('prt_compute_weights:MachineNotSupported',...
            'Error: weights computation not supported yet for this machine!');
end

% Image name
% -------------------------------------------------------------------------
% An .img/.hdr pair is written instead of a single .nii file. The original
% authors saw shifted weight images with .nii. NOTE(review): this is
% probably because the file_array below uses a data offset of 0, which
% would overlap the single-file .nii header. Confirm before changing format.
if isfield(in,'img_name') && ~isempty(in.img_name)
    if ~prt_checkAlphaNumUnder(in.img_name)
        error('prt_compute_weights:NameNotAlphaNumeric',...
            'Error: image name should contain only alpha-numeric elements!');
    end
    img_name = fullfile(in.pathdir,[in.img_name,'.img']);
else
    img_name = fullfile(in.pathdir,['weights_',mname,'.img']);
end

% Other info
% -------------------------------------------------------------------------
fs_name  = model.input.fs(1).fs_name;
samp_idx = model.input.samp_idx;
nfold    = length(model.output.fold);

% Find feature set (last match wins, as for the model)
% -------------------------------------------------------------------------
fs_idx = 0;
for f = 1:length(PRT.fs)
    if strcmp(PRT.fs(f).fs_name,fs_name)
        fs_idx = f;
    end
end
if fs_idx == 0
    error('prt_compute_weights:FeatureSetNotFound',...
        'Error: feature set ''%s'' not found in PRT.mat!',fs_name);
end
ID     = PRT.fs(fs_idx).id_mat(samp_idx,:);
ID_all = PRT.fs(fs_idx).id_mat;

% Find modality
% -------------------------------------------------------------------------
% fas(i) flags every PRT.fas entry whose modality belongs to the feature set.
nfas = length(PRT.fas);
mods = {PRT.fs(fs_idx).modality.mod_name};
fas  = zeros(1,nfas);
mm   = 0;
for i = 1:nfas
    for j = 1:length(mods)
        if strcmpi(PRT.fas(i).mod_name,mods{j})
            fas(i) = 1;
            % NOTE(review): mm keeps the LAST matching modality. The mask
            % below instead comes from fas_idx(1). These agree only when a
            % single modality matches. Confirm for multi-modality feature
            % sets.
            mm = j;
        end
    end
end
fas_idx = find(fas);
if isempty(fas_idx)
    error('prt_compute_weights:ModalityNotFound',...
        'Error: no data modality matches the feature set of this model!');
end

% Get the indexes of the voxels which are in the second level mask
% -------------------------------------------------------------------------
% mask_train holds linear voxel indices into the image volume. They are
% compared against per-slice index ranges below. voxtr holds the
% corresponding column indices into PRT.fas(...).dat.
% NOTE(review): the two are assumed to be aligned element-wise.
idfeat = PRT.fas(fas_idx(1)).idfeat_img;
if ~isempty(PRT.fs(fs_idx).modality(mm).idfeat_fas)
    mask_train = idfeat(PRT.fs(fs_idx).modality(mm).idfeat_fas);
    voxtr      = find(ismember(idfeat,mask_train));
else
    mask_train = idfeat;
    voxtr      = 1:length(idfeat);
end

% Create image
% -------------------------------------------------------------------------
hdr   = PRT.fas(fas_idx(1)).hdr.private;
dimx  = hdr.dat.dim(1);
dimy  = hdr.dat.dim(2);
zdim  = hdr.dat.dim(3);
xydim = dimx*dimy;
% One volume per fold plus a last volume for the across-fold average.
img4d = file_array(img_name,[dimx,dimy,zdim,nfold+1],'float64-le',0,1,0);

% Operations specified during training; zero entries are skipped.
% This is loop-invariant, so it is computed once.
ops = model.input.operations(model.input.operations ~= 0);

% Squared L2 norm of each fold's weights, one row per slice. The array is
% preallocated, so empty slices contribute zero.
norm4d = zeros(zdim,nfold);

disp('Computing weights.......>>')

for z = 1:zdim

    fprintf('Slice: %d of %d\n',z,zdim);

    % Features whose linear voxel index falls inside slice z
    feat_slc = find(mask_train >= xydim*(z-1)+1 & mask_train <= xydim*z);

    if isempty(feat_slc)
        img4d(:,:,z,:) = zeros(dimx,dimy,1,nfold+1);
        continue
    end

    img3dav = zeros(1,xydim); % running sum of the fold weight maps

    for f = 1:nfold

        train_idx        = model.input.cv_mat(:,f) == 1;
        train            = samp_idx(train_idx);
        train_all        = zeros(size(ID_all,1),1);
        train_all(train) = 1;

        d.coeffs  = model.output.fold(f).alpha;
        d.datamat = zeros(length(train),length(feat_slc));
        for i = 1:length(fas_idx)
            % indexes to access the file array
            indm = PRT.fs(fs_idx).fas.im == fas_idx(i) & train_all;
            ifa  = PRT.fs(fs_idx).fas.ifa(indm);

            % index for the target data matrix
            % (column 3 of id_mat is compared against fas indexes here)
            indtr = ID(train_idx,3) == fas_idx(i);
            d.datamat(indtr,:) = PRT.fas(fas_idx(i)).dat(ifa,voxtr(feat_slc));
        end

        % Apply any operations specified during training
        cvdata.train      = {d.datamat};
        cvdata.tr_id      = ID(train_idx,:);
        cvdata.use_kernel = false; % need to apply the operation to the data
        for o = 1:length(ops)
            cvdata = prt_apply_operation(PRT,cvdata,ops(o));
        end
        d.datamat = cvdata.train{:};

        wimg  = prt_weights(d,m);

        img3d = zeros(1,xydim);
        img3d(mask_train(feat_slc)-xydim*(z-1)) = wimg;

        norm4d(z,f)    = sum(img3d.^2);
        img3dav        = img3dav + img3d;
        img4d(:,:,z,f) = reshape(img3d,dimx,dimy,1,1);
    end

    % Create average fold (computed from the un-normalised fold weights)
    img4d(:,:,z,nfold+1) = reshape(img3dav,dimx,dimy,1,1)/nfold;
end

% Per-fold L2 norm over the whole volume
norm4d = sqrt(sum(norm4d,1));

disp('Normalising weights--------->>')
for f = 1:nfold
    % An all-zero fold is left as is; dividing would give 0/0 = NaN.
    if norm4d(f) > 0
        img4d(:,:,:,f) = img4d(:,:,:,f)./norm4d(f);
    end
end

% Create weights file
%--------------------------------------------------------------------------
disp('Creating image--------->>')
No         = hdr;              % copy header
No.dat     = img4d;            % change file_array
No.descrip = 'Pronto weigths'; % description
create(No);                    % write header
disp('Done.')

Generated on Mon 03-Sep-2012 18:07:18 by m2html © 2005