forked from MBB-team/VBA-toolbox
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* cleanup demos for GLM, logisticRegression, Qlearning and Qlearning simulation * better unit testing (avoid side effects with initial env) * fix for issue when subjects had different number of observations in MFX
- Loading branch information
1 parent
c32ac71
commit e0d937f
Showing
19 changed files
with
839 additions
and
244 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
function [p_BPA] = VBA_BPA(priors0,posteriors0)
% performs Bayesian Parameters Averaging (BPA)
% function [p_BPA] = VBA_BPA(priors0,posteriors0)
% IN:
%   - priors0: a Kx1 cell-array of VBA prior structures (where K is the
%   number of subjects)
%   - posteriors0: a Kx1 cell-array of VBA posterior structures (where K is
%   the number of subjects)
% OUT:
%   - p_BPA: the resulting posterior structure. The first two moments
%   (mu*/Sigma*) are filled in for the observation parameters (Phi), the
%   evolution parameters (Theta) and the initial conditions (X0); the
%   hidden-state and precision fields are returned empty.

% group-level moments, one parameter family at a time:
% observation parameters, evolution parameters, initial conditions
families = {'Phi','Theta','X0'};
for iFam = 1:numel(families)
    tag = families{iFam};
    [p_BPA.(['mu' tag]),p_BPA.(['Sigma' tag])] = ...
        get2moments(priors0,posteriors0,tag);
end

% hidden states, data precision and hidden state precision are not
% averaged: the corresponding fields are left empty
emptyFields = {'muX','SigmaX','b_sigma','a_sigma','b_alpha','a_alpha'};
for iField = 1:numel(emptyFields)
    p_BPA.(emptyFields{iField}) = [];
end

end
|
||
function [muG,sigmaG] = get2moments(p0,p1,paramType)
% derives the group-level first two moments for one family of parameters
% IN:
%   - p0: cell-array of prior structures (its length sets the number of
%   subjects; only the first entry is read to define the group prior)
%   - p1: cell-array of subject-level posterior structures
%   - paramType: suffix of the moment fields to read, i.e. the function
%   accesses the fields ['mu' paramType] and ['Sigma' paramType]
% OUT:
%   - muG, sigmaG: mean and covariance returned by MFX_VBupdate

meanField = ['mu' paramType];
covField = ['Sigma' paramType];

nSubjects = length(p0); % # subjects

% group priors are taken from the first subject's priors
priorMean = p0{1}.(meanField);
priorCov = p0{1}.(covField);
nParams = numel(priorMean);

% hyperparameters a and b are all initialised to one
a0 = ones(nParams,1);
b0 = ones(nParams,1);

% parameters flagged by infLimit are handled as fixed effects; parameters
% with zero prior variance are left out of the update
ind_ffx = find(infLimit(a0,b0)==1);
ind_in = find(diag(priorCov)~=0);

% gather subject-level posterior moments
muSub = zeros(nParams,nSubjects);
sigmaSub = cell(1,nSubjects);
for iSub = 1:nSubjects
    muSub(:,iSub) = p1{iSub}.(meanField);
    sigmaSub{iSub} = p1{iSub}.(covField);
end

% VB-updating (the updated hyperparameters are not needed here)
[muG,sigmaG] = MFX_VBupdate(priorMean,VBA_inv(priorCov),...
    muSub,sigmaSub,...
    a0,b0,...
    a0,b0,...
    ind_ffx,ind_in);

end
|
||
function [m,V,a,b] = MFX_VBupdate(m0,iV0,ms,Vs,a,b,a0,b0,indffx,indIn)
% updates the group-level moments from the subject-level moments
% IN:
%   - m0: nx1 prior mean
%   - iV0: nxn matrix combined with m0 as a prior precision (the caller in
%   this file passes VBA_inv of the prior covariance)
%   - ms: n x ns matrix whose i-th column is the mean of subject i
%   - Vs: cell-array, Vs{i} being the covariance matrix of subject i
%   - a, b: nx1 current hyperparameters; a./b forms the diagonal of iQ
%   - a0, b0: nx1 prior hyperparameters
%   - indffx: indices of parameters treated as fixed effects
%   - indIn: indices of parameters included in the update
% OUT:
%   - m, V: updated mean and covariance (entries outside indIn keep m0 and
%   a zero covariance)
%   - a, b: updated hyperparameters (only the random-effect entries change)
ns = size(ms,2);
n = size(m0,1);
sm = 0;
sv = 0;
wsm = 0;
sP = 0;
% random effects = included parameters that are not fixed effects
indrfx = setdiff(1:n,indffx);
indrfx = intersect(indrfx,indIn);
indffx = intersect(indffx,indIn);
iQ = diag(a(indrfx)./b(indrfx));
for i=1:ns
    % RFX: sm and sv are compact vectors, with one entry per element of
    % indrfx (not one entry per parameter)
    sm = sm + ms(indrfx,i);
    e = ms(indrfx,i)-m0(indrfx);
    sv = sv + e.^2 + diag(Vs{i}(indrfx,indrfx));
    % FFX: precision-weighted sums over all parameters
    tmp = VBA_inv(Vs{i});
    wsm = wsm + tmp*ms(:,i);
    sP = sP + tmp;
end
% RFX
V = zeros(n,n);
m = m0;
V(indrfx,indrfx) = VBA_inv(iV0(indrfx,indrfx)+ns*iQ);
m(indrfx) = V(indrfx,indrfx)*(iV0(indrfx,indrfx)*m0(indrfx)+iQ*sm);
a(indrfx) = a0(indrfx) + 0.5*ns;
% sv is already restricted to indrfx (see loop above): indexing it again
% with indrfx would read the wrong entries, or exceed its bounds, as soon
% as indrfx differs from 1:numel(indrfx)
b(indrfx) = b0(indrfx) + 0.5*(sv+ns*diag(V(indrfx,indrfx)));
% FFX
if ~isempty(indffx)
    tmp = VBA_inv(sP);
    V(indffx,indffx) = tmp(indffx,indffx);
    m(indffx) = V(indffx,indffx)*wsm(indffx);
end
end
|
||
function il = infLimit(a,b)
% flags, element by element, the entries where a is infinite and b is zero
% IN:
%   - a, b: arrays of compatible sizes
% OUT:
%   - il: numeric array, 1 where isinf(a) and b==0 both hold, 0 elsewhere
% NB: b has to be compared element-wise. isequal(b,0) tests the whole array
% against the scalar 0, and is therefore false for any non-scalar b.
il = isinf(a).*(b==0);
end
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
function [posterior, out]=demo_Qlearning(choices, feedbacks)
% // VBA toolbox //////////////////////////////////////////////////////////
%
% [posterior, out] = demo_Qlearning([choices, feedbacks])
% Demo of Q-learning simulation and inference
%
% This is a simple example of reinforcement learning algorithm.
% This demo inverts a Q-learning model (Q-learning evolution function and
% softmax observation function) given a sequence of binary choices and the
% corresponding feedbacks. If no input is provided, the data are first
% generated by demo_QlearningSimulation.
%
% IN:
%   - choices: 1 x n_t vector of binary choices (optional)
%   - feedbacks: 1 x n_t vector of feedbacks (optional)
%   Either both or none of the inputs must be provided.
% OUT:
%   - posterior, out: outputs of VBA_NLStateSpaceModel
%
% Background:
% ~~~~~~~~~~~
% In psychological terms, motivation can be defined as the set of processes
% that generate goals and thus determine behaviour. A goal is nothing else
% than a 'state of affairs', to which people attribute (subjective) value.
% Empirically speaking, one can access these values by many means,
% including subjective verbal report or decision making. These measures
% have been used to demonstrate how value change as people learn a new
% operant response. This is predicted by reinforcement learning theories,
% which essentially relate behavioural response frequency to reward. In
% this context, value is expected reward, and it changes in proportion to
% the agent's prediction error, i.e. the difference between actual and
% expected reward.
%
% /////////////////////////////////////////////////////////////////////////

% check inputs
% =========================================================================

switch nargin
    case 0
        fprintf('No inputs provided, generating simulated behavior...\n\n');
        [choices, feedbacks, simulation]=demo_QlearningSimulation();
    case 2
        fprintf('Performing inversion of provided behaviour...\n\n');
    otherwise
        error('*** Wrong number of arguments.')
end

% reformat data
% =========================================================================
% observations
y = choices;
% inputs: each trial receives the choice and feedback of the previous
% trial (there is no previous trial for the first one, hence the nan)
u = [ nan, choices(1:end-1) ;    % previous choice
      nan, feedbacks(1:end-1) ]; % previous feedback

% specify model
% =========================================================================
f_fname = @f_Qlearning; % evolution function (Q-learning)
g_fname = @g_softmax;   % observation function (softmax mapping)

% provide dimensions
dim = struct( ...
    'n', 2, ... number of hidden states (2 Q-values)
    'n_theta', 1, ... number of evolution parameters (1: learning rate)
    'n_phi', 1 ... number of observation parameters (1: temperature)
    );

% priors
% -------------------------------------------------------------------------
% use the default priors except for the initial state
options.priors.muX0 = [0.5; 0.5];
% NB: the field has to be spelled 'SigmaX0' (it was 'SimaX0'), otherwise
% this prior covariance is not the one used for the initial state
options.priors.SigmaX0 = 0.1 * eye(2);

% options for the inversion
% -------------------------------------------------------------------------
% number of trials
n_t = numel(choices);
% fitting binary data
options.binomial = 1;
% Normally, the expected first observation is g(x1), ie. after
% a first iteration x1 = f(x0, u0). The skipf flag will prevent this
% evolution and thus set x1 = x0
options.skipf = [1 zeros(1,n_t-1)];

% invert model
% =========================================================================
[posterior, out] = VBA_NLStateSpaceModel(y, u, f_fname, g_fname, dim, options);

% display estimated parameters:
% -------------------------------------------------------------------------
fprintf('=============================================================\n');
fprintf('\nEstimated parameters: \n');
fprintf('  - learning rate: %3.2f\n', sigm(posterior.muTheta));
fprintf('  - inverse temp.: %3.2f\n\n', exp(posterior.muPhi));
fprintf('=============================================================\n');

% compare estimates with the simulation
% =========================================================================
% only possible when the data come from demo_QlearningSimulation
if exist('simulation','var')
    displayResults( ...
        posterior, ...
        out, ...
        choices, ...
        simulation.state, ...
        simulation.initial, ...
        simulation.evolution, ...
        simulation.observation, ...
        Inf, Inf ...
        );
end
|
Oops, something went wrong.