Skip to content

Commit

Permalink
refactor BMA function, resolves #54
Browse files Browse the repository at this point in the history
  • Loading branch information
lionel-rigoux committed Apr 25, 2018
1 parent ad786ee commit 4fedc1d
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 73 deletions.
187 changes: 115 additions & 72 deletions VBA_BMA.m
Original file line number Diff line number Diff line change
@@ -1,112 +1,155 @@
function [p_BMA] = VBA_BMA(p0,F0)
function [p_BMA] = VBA_BMA (p0, F0)
% // VBA toolbox //////////////////////////////////////////////////////////
%
% [p_BMA] = VBA_BMA (p0, F0)
% performs Bayesian Model Averaging (BMA)
% function [p,o] = VBA_BMA(posterior,F)
%
% IN:
% - p0: a Kx1 cell-array of VBA posterior structures, which are
% conditional onto specific generative models (where K is the number of
% models)
% - F0: a Kx1 vector of log-model evidences
% - p0: a Kx1 array of VBA posterior structures, which are
% conditional onto specific generative models
% - F0: a Kx1 vector of the respective log-model evidences
% OUT:
% - p_BMA: the resulting posterior structure, with the first two moments of
% the marginal probability density functions
% - p_BMA: the resulting posterior structure that describe the marginal
% (over models) probability density functions
%
% /////////////////////////////////////////////////////////////////////////

K = length(p0); % # models
ps = softmax(F0);
% for retrocompatibility, accept cell array of posterior
if iscell (p0)
p0 = cell2mat (p0);
end

% shortcuts
% =========================================================================
% number of models
K = length (p0);

% posterior model probabilities
% =========================================================================
ps = softmax (F0);

% perform averaging
% =========================================================================

% observation parameters
mus = cell(K,1);
Qs = cell(K,1);
for k=1:K
mus{k} = p0{k}.muPhi;
Qs{k} = p0{k}.SigmaPhi;
% -------------------------------------------------------------------------
try
[p_BMA.muPhi, p_BMA.SigmaPhi] = averageMoments ({p0.muPhi}, {p0.SigmaPhi}, ps);
end
[p_BMA.muPhi,p_BMA.SigmaPhi] = get2moments(mus,Qs,ps);

% evolution parameters
mus = cell(K,1);
Qs = cell(K,1);
for k=1:K
mus{k} = p0{k}.muTheta;
Qs{k} = p0{k}.SigmaTheta;
% -------------------------------------------------------------------------
try
[p_BMA.muTheta, p_BMA.SigmaTheta] = averageMoments ({p0.muTheta}, {p0.SigmaTheta}, ps);
end
[p_BMA.muTheta,p_BMA.SigmaTheta] = get2moments(mus,Qs,ps);

% initial conditions
mus = cell(K,1);
Qs = cell(K,1);
for k=1:K
mus{k} = p0{k}.muX0;
Qs{k} = p0{k}.SigmaX0;
% -------------------------------------------------------------------------
try
[p_BMA.muX0, p_BMA.SigmaX0] = averageMoments ({p0.muX0}, {p0.SigmaX0}, ps);
end
[p_BMA.muX0,p_BMA.SigmaX0] = get2moments(mus,Qs,ps);


% hidden states
% -------------------------------------------------------------------------
try
T = size(p0{1}.muX,2);
for t=1:T
mus = cell(K,1);
Qs = cell(K,1);
for k=1:K
mus{k} = p0{k}.muX(:,t);
Qs{k} = p0{k}.SigmaX.current{t};
% number of timepoints
T = size (p0(1).muX, 2);
% initialisation
mus = cell (K, 1);
Qs = cell (K, 1);
% loop over timepoints
for t = 1 : T
% collect moments
for k=1:K
mus{k} = p0(k).muX(:, t);
Qs{k} = p0(k).SigmaX.current{t};
end
% compute average
[p_BMA.muX(:, t), p_BMA.SigmaX.current{t}] = averageMoments (mus, Qs, ps);
end
[p_BMA.muX(:,t),p_BMA.SigmaX.current{t}] = get2moments(mus,Qs,ps);
end
end

% data precision
mus = cell(K,1);
Qs = cell(K,1);
for iS=1:numel(p0{1}.a_sigma) % loop over sources
for k=1:K
mus{k} = p0{k}.a_sigma(iS)/p0{k}.b_sigma(iS);
Qs{k} = p0{k}.a_sigma(iS)/p0{k}.b_sigma(iS)^2;
% observation precision
% -------------------------------------------------------------------------
try
% number of gaussian sources
nS = numel (p0(1).a_sigma);
% initialisation
mus = cell (K, 1);
Qs = cell (K, 1);
% loop over sources
for iS = 1 : nS
% collect moments
for k = 1 : K
mus{k} = p0(k).a_sigma(iS) / p0(k).b_sigma(iS);
Qs{k} = p0(k).a_sigma(iS) / p0(k).b_sigma(iS) ^ 2;
end
% compute average
[m, v] = averageMoments (mus, Qs, ps);
% map to gamma distribution parameters
p_BMA.b_sigma(iS) = m / v;
p_BMA.a_sigma(iS) = m * p_BMA.b_sigma(iS);
end
[m,v] = get2moments(mus,Qs,ps);
p_BMA.b_sigma(iS) = m/v;
p_BMA.a_sigma(iS) = m*p_BMA.b_sigma(iS);
end

% hidden state precision
try
mus = cell(K,1);
Qs = cell(K,1);
id = zeros(K,1);

for k=1:K
mus{k} = p0{k}.a_alpha/p0{k}.b_alpha;
Qs{k} = p0{k}.a_alpha/p0{k}.b_alpha^2;
if (isempty(p0{k}.a_alpha) && isempty(p0{k}.b_alpha)) || (isinf(p0{k}.a_alpha) && p0{k}.b_alpha==0)
id(k) = 1;
% -------------------------------------------------------------------------
% initialisation
mus = cell (K, 1);
Qs = cell (K, 1);
isStochastic = nan (K, 1);
% collect moments, if any
for k = 1 : K
try
mus{k} = p0(k).a_alpha / p0(k).b_alpha;
Qs{k} = p0(k).a_alpha / p0(k).b_alpha ^ 2;
isStochastic(k) = ~ isempty(mus{k}) && ~ isinf(mus{k});
catch
isStochastic(k) = false;
end
end

if isequal(sum(id),K) % all deterministic systems
% compute average, if meaninful
% + all deterministic systems
if ~ any (isStochastic)
p_BMA.b_alpha = Inf;
p_BMA.a_alpha = 0;
elseif isequal(sum(id),0) % all stochastic systems
[m,v] = get2moments(mus,Qs,ps);
p_BMA.b_alpha = m/v;
p_BMA.a_alpha = m*p_BMA.b_alpha;
% + all stochastic systems
elseif all (isStochastic)
% average moments
[m, v] = averageMoments(mus, Qs, ps);
% map to gamma distribution parameters
p_BMA.b_alpha = m / v;
p_BMA.a_alpha = m * p_BMA.b_alpha;
else
disp('VBA_MBA: Warning: mixture of deterministic and stochastic models!')
p_BMA.b_alpha = Inf;
p_BMA.a_alpha = 0;
end
% + mixture of stochastic and deterministic systems
disp('VBA_MBA: Warning! mixture of deterministic and stochastic models!')
p_BMA.b_alpha = NaN;
p_BMA.a_alpha = NaN;
end

end



% #########################################################################
% Subfunctions
% #########################################################################

function [m,V] = get2moments(mus,Qs,ps)
% Compute averages of 1st order (mus) and 2nd order (Qs) moments of
% distributions, weigthed by ps.
function [m, V] = averageMoments (mus, Qs, ps)
% initialisation
V = zeros(size(Qs{1}));
m = zeros(size(mus{1}));
K = length(ps);
for k=1:K
m = m + ps(k).*mus{k};
% average 1st order moments
for k = 1 : K
m = m + ps(k) .* mus{k};
end
for k=1:K
% average 2nd order moments
for k = 1 : K
tmp = mus{k} - m;
V = V + ps(k).*(tmp*tmp' + Qs{k});
V = V + ps(k) .* (tmp * tmp' + Qs{k});
end
end
2 changes: 1 addition & 1 deletion stats&plots/bayesian_ttest.m
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
h = (ep>=0.95);

% posterior estimates
p = VBA_BMA({p0;p1},[F0;F1]);
p = VBA_BMA([p0;p1],[F0;F1]);
% posterior.mu = p.muPhi ;
posterior.mu = [ p.muPhi(1) , sparsify(p.muPhi(2),log(2)) ];

Expand Down

0 comments on commit 4fedc1d

Please sign in to comment.