function h = plot_schemaball(R, labels, networks)
% PLOT_SCHEMABALL Plot a weighted connectivity/correlation matrix as a schemaball.
%
%   PLOT_SCHEMABALL(R, LABELS, NETWORKS) places the size(R,1) nodes evenly
%   on a circle, counterclockwise starting from the bottom point, and joins
%   every node pair with a nonzero entry in the strictly lower triangle of
%   R by a quadratic Bezier curve whose control point is the origin.
%   Positive weights are drawn in blue, negative weights in red. The line
%   opacity is |R(i,j)| divided by the largest positive (respectively the
%   most negative) lower-triangular weight. Nodes are drawn with GSCATTER,
%   grouped by NETWORKS, and every node is annotated with its entry of
%   LABELS.
%
%   Inputs:
%     R        - square numeric matrix of weights. Only the strictly lower
%                triangle is used; NaN entries are treated as 0.
%     labels   - one label per node (e.g. a cell array of char). When more
%                than 200 labels are given they are replaced by the node
%                indices.
%     networks - grouping variable with one entry per node, passed to
%                GSCATTER.
%
%   H = PLOT_SCHEMABALL(...) returns a struct of graphics handles:
%     H.l - one line per drawn edge (empty when no edge is drawn)
%     H.s - the output of GSCATTER for the nodes
%     H.t - one text object per node label
%
% Additional references:
% - FEX schemaball page
% - Origin: question on Stackoverflow.com
% - Schemaball by Gunther Struyf
%
% See also: CORR, CORRPLOT, GSCATTER
% Author: Oleg Komarov (oleg.komarov@hotmail.it)
% Tested on R2013a Win7 64 and Vista 32
% 15 jun 2013 - Created
% NOTE: this version uses legend dot-notation (lgd.FontSize), 'XColor'
% 'none' and 4-element (RGBA) line colors, which need the R2014b+
% graphics system.

%% Parameters
% Tweak these only
% Points in [0, 1] for the Bezier curves: leave space at the extremes to
% detach the curves a bit from the nodes. A smaller step uses more points
% per curve.
t = (0.025: 0.05 :1)';
% Radius of the circle the nodes are placed on
radius = 2;
% Text color
tcolor = [.3 .3 .3];
% RGB colors of positive and negative edges
poscolor = [0 89/256 179/256];
negcolor = [204/256 0 0];
% Above this many labels, node indices are shown instead
maxLabels = 200;

%% Weight preprocessing
% Missing weights are treated as "no connection"
R(isnan(R)) = 0;
nNodes = size(R, 1);
% Only the strictly lower triangle (no main diagonal) is used
tf = tril(true(size(R)), -1);
minR = min(R(tf));
maxR = max(R(tf));
R(~tf) = 0;

%% Engine
% Create the figure hidden; it is shown once it has been resized.
% OpenGL is required for the transparent (RGBA) edge colors.
fig = figure('Renderer','opengl','Visible','off');
ax = axes('Parent',fig,'NextPlot','add');
% Use tau http://tauday.com/tau-manifesto
tau = 2*pi;
% Node angles, counterclockwise from (0,-radius). Built from integer
% indices so there are always exactly nNodes of them.
theta = -.25*tau + (0:nNodes-1) * (tau/nNodes);
% Cartesian x-y coordinates of the nodes
x = radius * cos(theta);
y = radius * sin(theta);

% PLOT BEZIER CURVES
% Quadratic Bezier curve with control point P1 = (0,0):
%   B(t) = (1-t)^2*P0 + 2(1-t)t*P1 + t^2*P2 = (1-t)^2*P0 + t^2*P2
% where P0 and P2 are the (x,y) coordinates of the node pair.
t2 = [1-t, t].^2;
% Nonzero entries of the (now strictly lower-triangular) weight matrix
[row, col] = find(R);
s.l = gobjects(numel(row), 1);
for ci = 1:numel(row)
    v = R(row(ci), col(ci));
    Bx = t2 * [x(col(ci)); x(row(ci))];
    By = t2 * [y(col(ci)); y(row(ci))];
    % Opacity is relative to the strongest edge of the same sign
    if v > 0
        color = [poscolor, v/maxR];
    else
        color = [negcolor, abs(v)/abs(minR)];
    end
    s.l(ci) = plot(ax, Bx, By, 'Color', color, 'LineWidth', 3);
end

% PLOT NODES
% One hsv color per distinct network.
% NOTE(review): the 4th column is meant as marker alpha; confirm that
% gscatter honors it.
cmap = hsv(numel(unique(networks)));
cmap(:, 4) = 0.5;
s.s = gscatter(x, y, networks, cmap);
lgd = legend();
lgd.FontSize = 14;
legend boxoff

% PLACE TEXT LABELS so that they always read 'left to right'
if numel(labels) > maxLabels
    labels = arrayfun(@num2str, 1:numel(labels), 'UniformOutput', false);
end
% Labels on the right half are rotated by the node angle. Labels on the
% left half are flipped by 180 degrees and right-aligned.
ipos = x > 0;
s.t = gobjects(nNodes, 1);
s.t(ipos) = text(x(ipos)*1.05, y(ipos)*1.05, labels(ipos), ...
    'FontSize',6, 'Color',tcolor);
set(s.t(ipos), {'Rotation'}, num2cell(theta(ipos)'/tau*360))
s.t(~ipos) = text(x(~ipos)*1.05, y(~ipos)*1.05, labels(~ipos), ...
    'FontSize',6, 'Color',tcolor);
set(s.t(~ipos), {'Rotation'}, num2cell(theta(~ipos)'/tau*360-180), ...
    'HorizontalAlignment','right')

% ADJUST FIGURE height/width to fit the text labels
% get(h,{'Prop'}) always returns a cell, even for a single object
xtn   = cell2mat(get(s.t, {'Extent'}));
post  = cell2mat(get(s.t, {'Position'}));
sg    = sign(post(:,2));
posfa = cell2mat(get([fig ax], {'Position'}));
% Axis limits in data units: label anchor plus its extent along y (x).
% NOTE(review): the x extension is signed with sign(y), not sign(x);
% confirm this is intended.
ylims = post(:,2) + xtn(:,4).*sg;
ylims = [min(ylims), max(ylims)];
xlims = post(:,1) + xtn(:,3).*sg;
xlims = [min(xlims), max(xlims)];
% Stretch the figure; the formula treats a data span of 2 as the baseline
posfa(1,3) = ((diff(xlims)/2 - 1)*posfa(2,3) + 1) * posfa(1,3);
posfa(1,4) = ((diff(ylims)/2 - 1)*posfa(2,4) + 1) * posfa(1,4);
% Position it a bit lower (movegui is slow)
posfa(1,2) = 100;
% Axis settings
set(ax, 'XLim',xlims, 'YLim',ylims, 'XColor','none', 'YColor','none')
% CLim must be strictly increasing: skip it when all weights are equal
% or there are no lower-triangular weights at all
if ~isempty(minR) && minR < maxR
    set(ax, 'CLim', [minR, maxR])
end
set(fig, 'Position',posfa(1,:), 'Visible','on')
axis(ax, 'equal')
if nargout > 0
    h = s;
end
end