-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ad786ee
commit 4fedc1d
Showing
2 changed files
with
116 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,112 +1,155 @@ | ||
function [p_BMA] = VBA_BMA(p0,F0) | ||
function [p_BMA] = VBA_BMA (p0, F0) | ||
% // VBA toolbox ////////////////////////////////////////////////////////// | ||
% | ||
% [p_BMA] = VBA_BMA (p0, F0) | ||
% performs Bayesian Model Averaging (BMA) | ||
% function [p,o] = VBA_BMA(posterior,F) | ||
% | ||
% IN: | ||
% - p0: a Kx1 cell-array of VBA posterior structures, which are | ||
% conditional onto specific generative models (where K is the number of | ||
% models) | ||
% - F0: a Kx1 vector of log-model evidences | ||
% - p0: a Kx1 array of VBA posterior structures, which are | ||
% conditional onto specific generative models | ||
% - F0: a Kx1 vector of the respective log-model evidences | ||
% OUT: | ||
% - p_BMA: the resulting posterior structure, with the first two moments of | ||
% the marginal probability density functions | ||
% - p_BMA: the resulting posterior structure that describe the marginal | ||
% (over models) probability density functions | ||
% | ||
% ///////////////////////////////////////////////////////////////////////// | ||
|
||
K = length(p0); % # models | ||
ps = softmax(F0); | ||
% for retrocompatibility, accept cell array of posterior | ||
if iscell (p0) | ||
p0 = cell2mat (p0); | ||
end | ||
|
||
% shortcuts | ||
% ========================================================================= | ||
% number of models | ||
K = length (p0); | ||
|
||
% posterior model probabilities | ||
% ========================================================================= | ||
ps = softmax (F0); | ||
|
||
% perform averaging | ||
% ========================================================================= | ||
|
||
% observation parameters | ||
mus = cell(K,1); | ||
Qs = cell(K,1); | ||
for k=1:K | ||
mus{k} = p0{k}.muPhi; | ||
Qs{k} = p0{k}.SigmaPhi; | ||
% ------------------------------------------------------------------------- | ||
try | ||
[p_BMA.muPhi, p_BMA.SigmaPhi] = averageMoments ({p0.muPhi}, {p0.SigmaPhi}, ps); | ||
end | ||
[p_BMA.muPhi,p_BMA.SigmaPhi] = get2moments(mus,Qs,ps); | ||
|
||
% evolution parameters | ||
mus = cell(K,1); | ||
Qs = cell(K,1); | ||
for k=1:K | ||
mus{k} = p0{k}.muTheta; | ||
Qs{k} = p0{k}.SigmaTheta; | ||
% ------------------------------------------------------------------------- | ||
try | ||
[p_BMA.muTheta, p_BMA.SigmaTheta] = averageMoments ({p0.muTheta}, {p0.SigmaTheta}, ps); | ||
end | ||
[p_BMA.muTheta,p_BMA.SigmaTheta] = get2moments(mus,Qs,ps); | ||
|
||
% initial conditions | ||
mus = cell(K,1); | ||
Qs = cell(K,1); | ||
for k=1:K | ||
mus{k} = p0{k}.muX0; | ||
Qs{k} = p0{k}.SigmaX0; | ||
% ------------------------------------------------------------------------- | ||
try | ||
[p_BMA.muX0, p_BMA.SigmaX0] = averageMoments ({p0.muX0}, {p0.SigmaX0}, ps); | ||
end | ||
[p_BMA.muX0,p_BMA.SigmaX0] = get2moments(mus,Qs,ps); | ||
|
||
|
||
% hidden states | ||
% ------------------------------------------------------------------------- | ||
try | ||
T = size(p0{1}.muX,2); | ||
for t=1:T | ||
mus = cell(K,1); | ||
Qs = cell(K,1); | ||
for k=1:K | ||
mus{k} = p0{k}.muX(:,t); | ||
Qs{k} = p0{k}.SigmaX.current{t}; | ||
% number of timepoints | ||
T = size (p0(1).muX, 2); | ||
% initialisation | ||
mus = cell (K, 1); | ||
Qs = cell (K, 1); | ||
% loop over timepoints | ||
for t = 1 : T | ||
% collect moments | ||
for k=1:K | ||
mus{k} = p0(k).muX(:, t); | ||
Qs{k} = p0(k).SigmaX.current{t}; | ||
end | ||
% compute average | ||
[p_BMA.muX(:, t), p_BMA.SigmaX.current{t}] = averageMoments (mus, Qs, ps); | ||
end | ||
[p_BMA.muX(:,t),p_BMA.SigmaX.current{t}] = get2moments(mus,Qs,ps); | ||
end | ||
end | ||
|
||
% data precision | ||
mus = cell(K,1); | ||
Qs = cell(K,1); | ||
for iS=1:numel(p0{1}.a_sigma) % loop over sources | ||
for k=1:K | ||
mus{k} = p0{k}.a_sigma(iS)/p0{k}.b_sigma(iS); | ||
Qs{k} = p0{k}.a_sigma(iS)/p0{k}.b_sigma(iS)^2; | ||
% observation precision | ||
% ------------------------------------------------------------------------- | ||
try | ||
% number of gaussian sources | ||
nS = numel (p0(1).a_sigma); | ||
% initialisation | ||
mus = cell (K, 1); | ||
Qs = cell (K, 1); | ||
% loop over sources | ||
for iS = 1 : nS | ||
% collect moments | ||
for k = 1 : K | ||
mus{k} = p0(k).a_sigma(iS) / p0(k).b_sigma(iS); | ||
Qs{k} = p0(k).a_sigma(iS) / p0(k).b_sigma(iS) ^ 2; | ||
end | ||
% compute average | ||
[m, v] = averageMoments (mus, Qs, ps); | ||
% map to gamma distribution parameters | ||
p_BMA.b_sigma(iS) = m / v; | ||
p_BMA.a_sigma(iS) = m * p_BMA.b_sigma(iS); | ||
end | ||
[m,v] = get2moments(mus,Qs,ps); | ||
p_BMA.b_sigma(iS) = m/v; | ||
p_BMA.a_sigma(iS) = m*p_BMA.b_sigma(iS); | ||
end | ||
|
||
% hidden state precision | ||
try | ||
mus = cell(K,1); | ||
Qs = cell(K,1); | ||
id = zeros(K,1); | ||
|
||
for k=1:K | ||
mus{k} = p0{k}.a_alpha/p0{k}.b_alpha; | ||
Qs{k} = p0{k}.a_alpha/p0{k}.b_alpha^2; | ||
if (isempty(p0{k}.a_alpha) && isempty(p0{k}.b_alpha)) || (isinf(p0{k}.a_alpha) && p0{k}.b_alpha==0) | ||
id(k) = 1; | ||
% ------------------------------------------------------------------------- | ||
% initialisation | ||
mus = cell (K, 1); | ||
Qs = cell (K, 1); | ||
isStochastic = nan (K, 1); | ||
% collect moments, if any | ||
for k = 1 : K | ||
try | ||
mus{k} = p0(k).a_alpha / p0(k).b_alpha; | ||
Qs{k} = p0(k).a_alpha / p0(k).b_alpha ^ 2; | ||
isStochastic(k) = ~ isempty(mus{k}) && ~ isinf(mus{k}); | ||
catch | ||
isStochastic(k) = false; | ||
end | ||
end | ||
|
||
if isequal(sum(id),K) % all deterministic systems | ||
% compute average, if meaninful | ||
% + all deterministic systems | ||
if ~ any (isStochastic) | ||
p_BMA.b_alpha = Inf; | ||
p_BMA.a_alpha = 0; | ||
elseif isequal(sum(id),0) % all stochastic systems | ||
[m,v] = get2moments(mus,Qs,ps); | ||
p_BMA.b_alpha = m/v; | ||
p_BMA.a_alpha = m*p_BMA.b_alpha; | ||
% + all stochastic systems | ||
elseif all (isStochastic) | ||
% average moments | ||
[m, v] = averageMoments(mus, Qs, ps); | ||
% map to gamma distribution parameters | ||
p_BMA.b_alpha = m / v; | ||
p_BMA.a_alpha = m * p_BMA.b_alpha; | ||
else | ||
disp('VBA_MBA: Warning: mixture of deterministic and stochastic models!') | ||
p_BMA.b_alpha = Inf; | ||
p_BMA.a_alpha = 0; | ||
end | ||
% + mixture of stochastic and deterministic systems | ||
disp('VBA_MBA: Warning! mixture of deterministic and stochastic models!') | ||
p_BMA.b_alpha = NaN; | ||
p_BMA.a_alpha = NaN; | ||
end | ||
|
||
end | ||
|
||
|
||
|
||
% ######################################################################### | ||
% Subfunctions | ||
% ######################################################################### | ||
|
||
function [m,V] = get2moments(mus,Qs,ps) | ||
% Compute averages of 1st order (mus) and 2nd order (Qs) moments of | ||
% distributions, weigthed by ps. | ||
function [m, V] = averageMoments (mus, Qs, ps) | ||
% initialisation | ||
V = zeros(size(Qs{1})); | ||
m = zeros(size(mus{1})); | ||
K = length(ps); | ||
for k=1:K | ||
m = m + ps(k).*mus{k}; | ||
% average 1st order moments | ||
for k = 1 : K | ||
m = m + ps(k) .* mus{k}; | ||
end | ||
for k=1:K | ||
% average 2nd order moments | ||
for k = 1 : K | ||
tmp = mus{k} - m; | ||
V = V + ps(k).*(tmp*tmp' + Qs{k}); | ||
V = V + ps(k) .* (tmp * tmp' + Qs{k}); | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters