Commit

Small fixes and docs (MBB-team#53)
* clean up demos for GLM, logisticRegression, Qlearning and Qlearning simulation

* better unit testing (avoid side effects with initial env)

* fix for issue when subjects had different numbers of observations in MFX
lionel-rigoux authored Feb 27, 2018
1 parent c32ac71 commit e0d937f
Showing 19 changed files with 839 additions and 244 deletions.
129 changes: 129 additions & 0 deletions VBA_BPA.m
@@ -0,0 +1,129 @@
function [p_BPA] = VBA_BPA(priors0,posteriors0)
% performs Bayesian Parameter Averaging (BPA)
% function [p_BPA] = VBA_BPA(priors0,posteriors0)
% IN:
% - priors0: a Kx1 cell-array of VBA prior structures (where K is the number of subjects)
% - posteriors0: a Kx1 cell-array of VBA posterior structures (where K is the number of subjects)
% OUT:
% - p_BPA: the resulting posterior structure, with the first two moments of
% the group-level probability density functions

p0 = priors0;
p1 = posteriors0;

% observation parameters
[p_BPA.muPhi,p_BPA.SigmaPhi] = get2moments(p0,p1,'Phi');

% evolution parameters
[p_BPA.muTheta,p_BPA.SigmaTheta] = get2moments(p0,p1,'Theta');

% initial conditions
[p_BPA.muX0,p_BPA.SigmaX0] = get2moments(p0,p1,'X0');

% hidden states
p_BPA.muX=[];
p_BPA.SigmaX=[];

% data precision
p_BPA.b_sigma =[];
p_BPA.a_sigma =[];

% hidden state precision
p_BPA.b_alpha = [];
p_BPA.a_alpha =[];

end

function [muG,sigmaG] = get2moments(p0,p1,paramType)

K = length(p0); % # subjects


% define group priors
mu0 = p0{1}.(['mu' paramType]);
sigma0 = p0{1}.(['Sigma' paramType]);
muG = mu0;
sigmaG = sigma0;
muSub = zeros(numel(mu0),K);
sigmaSub = cell(1,K);
a0 = ones(numel(mu0),1);
b0 = ones(numel(mu0),1);
aG = a0;
bG = b0;
ind_ffx = find(infLimit(a0,b0)==1);
ind_in = find(diag(sigma0)~=0);

% loop across subjects
for k=1:K

% subject-level posterior
mu = p1{k}.(['mu' paramType]);
sigma = p1{k}.(['Sigma' paramType]);

% % update
% tempSigma = inv( inv(sigmaG) + inv(sigma) );
% muG = tempSigma*inv(sigmaG)*muG + tempSigma*inv(sigma)*mu;
% sigmaG = tempSigma ;

% store
muSub(:,k) = mu;
sigmaSub{k} = sigma;

end

% VB-updating
[muG,sigmaG,aG,bG] = MFX_VBupdate(muG,VBA_inv(sigmaG),...
muSub,sigmaSub,...
aG,bG,...
a0,b0,...
ind_ffx,ind_in);



end

function [m,V,a,b] = MFX_VBupdate(m0,iV0,ms,Vs,a,b,a0,b0,indffx,indIn)
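% Variational update of the group-level sufficient statistics:
% - random-effects (RFX) dimensions: update of the group mean and covariance,
%   together with the Gamma hyperparameters (a,b) of the group precision,
%   based on the subject-level posterior means and variances
% - fixed-effects (FFX) dimensions: precision-weighted average of the
%   subject-level posteriors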
ns = size(ms,2);
n = size(m0,1);
sm = 0;
sv = 0;
wsm = 0;
sP = 0;
indrfx = setdiff(1:n,indffx);
indrfx = intersect(indrfx,indIn);
indffx = intersect(indffx,indIn);
iQ = diag(a(indrfx)./b(indrfx));
for i=1:ns
% RFX
sm = sm + ms(indrfx,i);
e = ms(indrfx,i)-m0(indrfx);
sv = sv + e.^2 + diag(Vs{i}(indrfx,indrfx));
% FFX
tmp = VBA_inv(Vs{i});
wsm = wsm + tmp*ms(:,i);
sP = sP + tmp;
end
% RFX
V = zeros(n,n);
m = m0;
V(indrfx,indrfx) = VBA_inv(iV0(indrfx,indrfx)+ns*iQ);
m(indrfx) = V(indrfx,indrfx)*(iV0(indrfx,indrfx)*m0(indrfx)+iQ*sm);
a(indrfx) = a0(indrfx) + 0.5*ns;
b(indrfx) = b0(indrfx) + 0.5*(sv(indrfx)+ns*diag(V(indrfx,indrfx)));
% FFX
if ~isempty(indffx)
tmp = VBA_inv(sP);
V(indffx,indffx) = tmp(indffx,indffx);
m(indffx) = V(indffx,indffx)*wsm(indffx);
end
end

function il = infLimit(a,b)
il = isinf(a).*eq(b,0);
end
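
A minimal usage sketch of the new function, assuming K subjects have already been inverted individually with VBA_NLStateSpaceModel and that each subject's priors are available from the corresponding output's options structure (the variable names posteriors and outputs are hypothetical):

K = numel(posteriors);                  % posteriors, outputs: 1xK cells from K inversions
priors0     = cell(K,1);
posteriors0 = cell(K,1);
for k = 1:K
    priors0{k}     = outputs{k}.options.priors;   % priors used for subject k
    posteriors0{k} = posteriors{k};               % posterior of subject k
end
p_BPA = VBA_BPA(priors0, posteriors0);  % group-level moments by parameter averaging
disp(p_BPA.muPhi)                       % group mean of the observation parameters
disp(p_BPA.SigmaPhi)                    % group covariance of the observation parameters

Note that the group prior is taken from the first subject's prior structure, so all subjects are expected to share the same priors and parameter dimensions.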






13 changes: 12 additions & 1 deletion VBA_MFX.m
@@ -168,6 +168,14 @@
fprintf(1,'%6.2f %%',0)
end
kernelSize0 = 0; % max lag of volterra kernel

% save here to access subject-specific trial numbers later
if numel(dim.n_t) == 1
n_t = repmat(dim.n_t,1,ns);
else
n_t = dim.n_t;
end

for i=1:ns
if opt.verbose
fprintf(1,repmat('\b',1,8))
@@ -203,6 +211,9 @@
options{i} = VBA_check_struct(options{i},'kernelSize',16);
kernelSize0 = max([kernelSize0,options{i}.kernelSize]);
options{i}.kernelSize = 0;

dim.n_t = n_t(i); % subject-specific number of trials

[p_sub{i},o_sub{i}] = VBA_NLStateSpaceModel(y{i},u{i},f_fname,g_fname,dim,options{i});
% store options for future inversions
options{i} = o_sub{i}.options;
@@ -492,6 +503,6 @@
S = 0.5*n*(1+log(2*pi)) + 0.5*VBA_logDet(V);

function il = infLimit(a,b)
-il = isinf(a).*isequal(b,0);
+il = isinf(a).*eq(b,0);
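
The hunks above let dim.n_t carry a different number of trials per subject instead of a single shared value. A hedged sketch of a group inversion with unequal trial counts, assuming the input order (y, u, f_fname, g_fname, dim, options) suggested by the code above and that the first two outputs are the subject-level posterior and output structures; the data variables y and u are hypothetical:

% y, u: 1xK cell arrays of per-subject observations and inputs, possibly of
% different lengths across subjects
K = numel(y);
dim.n       = 2;                             % hidden states (e.g. 2 Q-values)
dim.n_theta = 1;                             % evolution parameters
dim.n_phi   = 1;                             % observation parameters
dim.n_t     = cellfun(@(yk) size(yk,2), y);  % per-subject trial counts (now a vector)

options = cell(1, K);
for k = 1:K
    options{k}.binomial = 1;                 % binary choice data, as in demo_Qlearning
end

[p_sub, o_sub] = VBA_MFX(y, u, @f_Qlearning, @g_softmax, dim, options);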


6 changes: 5 additions & 1 deletion VBA_ReDisplay.m
@@ -111,7 +111,11 @@

function mySummary(hfig)

-try hfig; catch, hfig = get(gco,'parent'); end
+try
+hfig;
+catch
+hfig = get(gco,'parent');
+end
cleanPanel(hfig);

ud = get(hfig,'userdata');
5 changes: 3 additions & 2 deletions VBA_unit_tests.m
@@ -4,7 +4,8 @@


%% find demos
-[~,list]=system('find . -name demo*') ;
+vba_info = VBA_version();
+[~,list]=system(['find ' vba_info.path ' -name demo_*']) ;

demos = {};
for p = strsplit(list)
@@ -14,7 +15,7 @@
end

% setup for base
-setup = 'pause off; warning off; ' ;
+setup = 'pause off; warning off; clear all; close all; ' ;


%% run demos
102 changes: 102 additions & 0 deletions demos/behavioural/demo_Qlearning.m
@@ -0,0 +1,102 @@
function [posterior, out]=demo_Qlearning(choices, feedbacks)
% // VBA toolbox //////////////////////////////////////////////////////////
%
% [posterior, out] = demo_Qlearning([choices, feedbacks])
% Demo of Q-learning simulation and inference
%
% This is a simple example of a reinforcement learning algorithm. This demo
% inverts a Q-learning model on binary choice data, either provided as input
% arguments or simulated by demo_QlearningSimulation.
%
% Background:
% ~~~~~~~~~~~
% In psychological terms, motivation can be defined as the set of processes
% that generate goals and thus determine behaviour. A goal is nothing other
% than a 'state of affairs', to which people attribute (subjective) value.
% Empirically speaking, one can access these values by many means,
% including subjective verbal report or decision making. These measures
% have been used to demonstrate how values change as people learn a new
% operant response. This is predicted by reinforcement learning theories,
% which essentially relate behavioural response frequency to reward. In
% this context, value is expected reward, and it changes in proportion to
% the agent's prediction error, i.e. the difference between actual and
% expected reward.
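%
% In equations, a sketch of the update and choice rules that f_Qlearning and
% g_softmax are expected to implement (alpha: learning rate, beta: inverse
% temperature; below, the demo reports alpha = sigm(theta), beta = exp(phi)):
%   Q(a_t) <- Q(a_t) + alpha * (r_t - Q(a_t))             (prediction-error update)
%   P(choice = a) = exp(beta*Q(a)) / sum_b exp(beta*Q(b))  (softmax mapping)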
%
% /////////////////////////////////////////////////////////////////////////

% check inputs
% =========================================================================

switch nargin
case 0
fprintf('No inputs provided, generating simulated behavior...\n\n');
[choices, feedbacks, simulation]=demo_QlearningSimulation();
case 2
fprintf('Performing inversion of provided behaviour...\n\n');
otherwise
error('*** Wrong number of arguments.')
end

% reformat data
% =========================================================================
% observations
y = choices;
% inputs
u = [ nan, choices(1:end-1) ; % previous choice
nan, feedbacks(1:end-1) ]; % previous feedback

% specify model
% =========================================================================
f_fname = @f_Qlearning; % evolution function (Q-learning)
g_fname = @g_softmax; % observation function (softmax mapping)

% provide dimensions
dim = struct( ...
'n', 2, ... number of hidden states (2 Q-values)
'n_theta', 1, ... number of evolution parameters (1: learning rate)
'n_phi', 1 ... number of observation parameters (1: temperature)
);

% priors for the inversion
% -------------------------------------------------------------------------
% use the default priors except for the initial state
options.priors.muX0 = [0.5; 0.5];
options.priors.SigmaX0 = 0.1 * eye(2);

% options for the inversion
% -------------------------------------------------------------------------
% number of trials
n_t = numel(choices);
% fitting binary data
options.binomial = 1;
% Normally, the expected first observation is g(x1), i.e. after a first
% iteration x1 = f(x0, u0). The skipf flag prevents this evolution and thus
% sets x1 = x0.
options.skipf = [1 zeros(1,n_t-1)];

% invert model
% =========================================================================
[posterior, out] = VBA_NLStateSpaceModel(y, u, f_fname, g_fname, dim, options);

% display estimated parameters:
% -------------------------------------------------------------------------
fprintf('=============================================================\n');
fprintf('\nEstimated parameters: \n');
fprintf(' - learning rate: %3.2f\n', sigm(posterior.muTheta));
fprintf(' - inverse temp.: %3.2f\n\n', exp(posterior.muPhi));
fprintf('=============================================================\n');

% compare estimates with simulated parameters
% =========================================================================
if exist('simulation','var') % used simulated data from demo_QlearningSimulation
displayResults( ...
posterior, ...
out, ...
choices, ...
simulation.state, ...
simulation.initial, ...
simulation.evolution, ...
simulation.observation, ...
Inf, Inf ...
);
end
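
As a usage note, choices and feedbacks should be 1-by-n_t row vectors with binary choices (the model is fitted with options.binomial = 1). A quick, purely illustrative call on hypothetical random data:

n_t = 100;
choices   = double(rand(1, n_t) > 0.5);     % binary choices
feedbacks = double(rand(1, n_t) > 0.5);     % binary feedbacks
[posterior, out] = demo_Qlearning(choices, feedbacks);

% or let the demo simulate an artificial agent first, then invert:
[posterior, out] = demo_Qlearning();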
