diff --git a/VBA_BMA.m b/VBA_BMA.m index 4aa92f26..540cd463 100644 --- a/VBA_BMA.m +++ b/VBA_BMA.m @@ -1,112 +1,155 @@ -function [p_BMA] = VBA_BMA(p0,F0) +function [p_BMA] = VBA_BMA (p0, F0) +% // VBA toolbox ////////////////////////////////////////////////////////// +% +% [p_BMA] = VBA_BMA (p0, F0) % performs Bayesian Model Averaging (BMA) -% function [p,o] = VBA_BMA(posterior,F) +% % IN: -% - p0: a Kx1 cell-array of VBA posterior structures, which are -% conditional onto specific generative models (where K is the number of -% models) -% - F0: a Kx1 vector of log-model evidences +% - p0: a Kx1 array of VBA posterior structures, which are +% conditional onto specific generative models +% - F0: a Kx1 vector of the respective log-model evidences % OUT: -% - p_BMA: the resulting posterior structure, with the first two moments of -% the marginal probability density functions +% - p_BMA: the resulting posterior structure that describe the marginal +% (over models) probability density functions +% +% ///////////////////////////////////////////////////////////////////////// -K = length(p0); % # models -ps = softmax(F0); +% for retrocompatibility, accept cell array of posterior +if iscell (p0) + p0 = cell2mat (p0); +end + +% shortcuts +% ========================================================================= +% number of models +K = length (p0); + +% posterior model probabilities +% ========================================================================= +ps = softmax (F0); + +% perform averaging +% ========================================================================= % observation parameters -mus = cell(K,1); -Qs = cell(K,1); -for k=1:K - mus{k} = p0{k}.muPhi; - Qs{k} = p0{k}.SigmaPhi; +% ------------------------------------------------------------------------- +try + [p_BMA.muPhi, p_BMA.SigmaPhi] = averageMoments ({p0.muPhi}, {p0.SigmaPhi}, ps); end -[p_BMA.muPhi,p_BMA.SigmaPhi] = get2moments(mus,Qs,ps); % evolution parameters -mus = cell(K,1); -Qs = cell(K,1); -for k=1:K - mus{k} = p0{k}.muTheta; - Qs{k} = p0{k}.SigmaTheta; +% ------------------------------------------------------------------------- +try + [p_BMA.muTheta, p_BMA.SigmaTheta] = averageMoments ({p0.muTheta}, {p0.SigmaTheta}, ps); end -[p_BMA.muTheta,p_BMA.SigmaTheta] = get2moments(mus,Qs,ps); % initial conditions -mus = cell(K,1); -Qs = cell(K,1); -for k=1:K - mus{k} = p0{k}.muX0; - Qs{k} = p0{k}.SigmaX0; +% ------------------------------------------------------------------------- +try + [p_BMA.muX0, p_BMA.SigmaX0] = averageMoments ({p0.muX0}, {p0.SigmaX0}, ps); end -[p_BMA.muX0,p_BMA.SigmaX0] = get2moments(mus,Qs,ps); + % hidden states +% ------------------------------------------------------------------------- try -T = size(p0{1}.muX,2); -for t=1:T - mus = cell(K,1); - Qs = cell(K,1); - for k=1:K - mus{k} = p0{k}.muX(:,t); - Qs{k} = p0{k}.SigmaX.current{t}; + % number of timepoints + T = size (p0(1).muX, 2); + % initialisation + mus = cell (K, 1); + Qs = cell (K, 1); + % loop over timepoints + for t = 1 : T + % collect moments + for k=1:K + mus{k} = p0(k).muX(:, t); + Qs{k} = p0(k).SigmaX.current{t}; + end + % compute average + [p_BMA.muX(:, t), p_BMA.SigmaX.current{t}] = averageMoments (mus, Qs, ps); end - [p_BMA.muX(:,t),p_BMA.SigmaX.current{t}] = get2moments(mus,Qs,ps); -end end -% data precision -mus = cell(K,1); -Qs = cell(K,1); -for iS=1:numel(p0{1}.a_sigma) % loop over sources - for k=1:K - mus{k} = p0{k}.a_sigma(iS)/p0{k}.b_sigma(iS); - Qs{k} = p0{k}.a_sigma(iS)/p0{k}.b_sigma(iS)^2; +% observation precision +% ------------------------------------------------------------------------- +try + % number of gaussian sources + nS = numel (p0(1).a_sigma); + % initialisation + mus = cell (K, 1); + Qs = cell (K, 1); + % loop over sources + for iS = 1 : nS + % collect moments + for k = 1 : K + mus{k} = p0(k).a_sigma(iS) / p0(k).b_sigma(iS); + Qs{k} = p0(k).a_sigma(iS) / p0(k).b_sigma(iS) ^ 2; + end + % compute average + [m, v] = averageMoments (mus, Qs, ps); + % map to gamma distribution parameters + p_BMA.b_sigma(iS) = m / v; + p_BMA.a_sigma(iS) = m * p_BMA.b_sigma(iS); end - [m,v] = get2moments(mus,Qs,ps); - p_BMA.b_sigma(iS) = m/v; - p_BMA.a_sigma(iS) = m*p_BMA.b_sigma(iS); end % hidden state precision -try -mus = cell(K,1); -Qs = cell(K,1); -id = zeros(K,1); - -for k=1:K - mus{k} = p0{k}.a_alpha/p0{k}.b_alpha; - Qs{k} = p0{k}.a_alpha/p0{k}.b_alpha^2; - if (isempty(p0{k}.a_alpha) && isempty(p0{k}.b_alpha)) || (isinf(p0{k}.a_alpha) && p0{k}.b_alpha==0) - id(k) = 1; +% ------------------------------------------------------------------------- +% initialisation +mus = cell (K, 1); +Qs = cell (K, 1); +isStochastic = nan (K, 1); +% collect moments, if any +for k = 1 : K + try + mus{k} = p0(k).a_alpha / p0(k).b_alpha; + Qs{k} = p0(k).a_alpha / p0(k).b_alpha ^ 2; + isStochastic(k) = ~ isempty(mus{k}) && ~ isinf(mus{k}); + catch + isStochastic(k) = false; end end - -if isequal(sum(id),K) % all deterministic systems +% compute average, if meaninful +% + all deterministic systems +if ~ any (isStochastic) p_BMA.b_alpha = Inf; p_BMA.a_alpha = 0; -elseif isequal(sum(id),0) % all stochastic systems - [m,v] = get2moments(mus,Qs,ps); - p_BMA.b_alpha = m/v; - p_BMA.a_alpha = m*p_BMA.b_alpha; +% + all stochastic systems +elseif all (isStochastic) + % average moments + [m, v] = averageMoments(mus, Qs, ps); + % map to gamma distribution parameters + p_BMA.b_alpha = m / v; + p_BMA.a_alpha = m * p_BMA.b_alpha; else - disp('VBA_MBA: Warning: mixture of deterministic and stochastic models!') - p_BMA.b_alpha = Inf; - p_BMA.a_alpha = 0; -end +% + mixture of stochastic and deterministic systems + disp('VBA_MBA: Warning! mixture of deterministic and stochastic models!') + p_BMA.b_alpha = NaN; + p_BMA.a_alpha = NaN; end +end +% ######################################################################### +% Subfunctions +% ######################################################################### -function [m,V] = get2moments(mus,Qs,ps) +% Compute averages of 1st order (mus) and 2nd order (Qs) moments of +% distributions, weigthed by ps. +function [m, V] = averageMoments (mus, Qs, ps) +% initialisation V = zeros(size(Qs{1})); m = zeros(size(mus{1})); K = length(ps); -for k=1:K - m = m + ps(k).*mus{k}; +% average 1st order moments +for k = 1 : K + m = m + ps(k) .* mus{k}; end -for k=1:K +% average 2nd order moments +for k = 1 : K tmp = mus{k} - m; - V = V + ps(k).*(tmp*tmp' + Qs{k}); + V = V + ps(k) .* (tmp * tmp' + Qs{k}); +end end \ No newline at end of file diff --git a/stats&plots/bayesian_ttest.m b/stats&plots/bayesian_ttest.m index c804e4b4..014279d5 100644 --- a/stats&plots/bayesian_ttest.m +++ b/stats&plots/bayesian_ttest.m @@ -128,7 +128,7 @@ h = (ep>=0.95); % posterior estimates -p = VBA_BMA({p0;p1},[F0;F1]); +p = VBA_BMA([p0;p1],[F0;F1]); % posterior.mu = p.muPhi ; posterior.mu = [ p.muPhi(1) , sparsify(p.muPhi(2),log(2)) ];