State inference: Considerable performance difference with MATLAB implementation? #82
Hi @SamGijsen,

Thanks for your interest and the nice code snippet you shared; I'm glad to see people are using pymdp. A few issues with your code immediately stand out to me, and they are almost certainly the cause of the 'stubborn', suboptimal behaviour you're observing in your agents. I'll describe them in order of decreasing importance:
I hope this helps!
Thank you for your quick and elaborate reply @conorheins; it's very much appreciated.
from pymdp import utils
from pymdp.agent import Agent

# A, B, C, D, number_of_trials and AS (the action-selection mode) are defined elsewhere
pD = D.copy()
pD[1] = pD[1] * 0.1  # low-precision Dirichlet prior over the context factor
model = Agent(A=A, B=B, C=C, D=D, pD=pD, policy_len=1, action_selection=AS)
for trial in range(number_of_trials):
    model.reset()
    model.D = utils.norm_dist_obj_arr(model.pD)  # use the expected prior under pD
    ...
    model.update_D(qs_t0=model.qs)  # accumulate the state posterior into pD

This setup ultimately yields a pD distribution of roughly [α1, α2] ≈ [19, 2] after the first 20 trials. However, after 20 more trials in the other state, we only approach a [20, 20] distribution. Thus, it takes many trials before the expectation of pD starts to align with the true but unknown hidden state. Please let me know if I misunderstood the suggestion in any way. Alternatively, I'm guessing that some additional learning algorithm, or even simple forgetting on the pD parameters to keep their precision relatively low, should do reasonably well in combination with the setup described above.
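To make the arithmetic concrete, here is a minimal sketch of pure count accumulation. It assumes each trial adds a fully confident posterior qs = [0, 1] to the Dirichlet counts with unit learning rate (an assumption; check `update_D` for the library's exact rule), and `trials_until_flip` is a hypothetical helper, not part of pymdp. Starting from the [19, 2] counts mentioned in the comment, it reports how many trials in the new context are needed before the expected prior alpha / alpha.sum() favours that context.

```python
import numpy as np

def trials_until_flip(alpha, qs_new_context, max_trials=1000):
    """Hypothetical helper: count trials of pure Dirichlet accumulation
    (alpha += qs, unit learning rate assumed) until the expected prior
    E[D] = alpha / alpha.sum() favours the new context."""
    alpha = np.asarray(alpha, dtype=float).copy()
    qs = np.asarray(qs_new_context, dtype=float)
    target = int(np.argmax(qs))
    for t in range(1, max_trials + 1):
        alpha += qs  # no forgetting: old counts are never discounted
        expected_D = alpha / alpha.sum()
        if expected_D[target] > 0.5:
            return t, alpha
    return None, alpha

n, alpha = trials_until_flip([19.0, 2.0], [0.0, 1.0])
print(n, alpha)
```

Under these assumptions the expectation only flips after 18 trials, once the counts reach [19, 20]; with less confident posteriors early in the reversal, each trial adds less than a full count to the new context and it takes longer still.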
Off-topic: running with "stochastic" action selection throws an error for me in the current version. I believe this is because actions are sampled even when num_controls=1, in which case utils.sample() attempts to squeeze an array of size 1. Reverting to deterministic selection if
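A minimal sketch of one way to guard against this, assuming the failure comes from the squeezed size-1 array being passed on to the multinomial draw: `safe_sample` is a hypothetical helper (not pymdp's `utils.sample`, and not an official fix) that short-circuits when there is only one possible outcome and otherwise draws a single one-hot sample from the categorical distribution.

```python
import numpy as np

def safe_sample(probabilities, rng=None):
    """Hypothetical helper (not pymdp's utils.sample): draw an index from a
    categorical distribution, tolerating a distribution with a single entry."""
    rng = np.random.default_rng() if rng is None else rng
    # squeeze() turns a size-1 array into a 0-d array; atleast_1d restores a vector
    p = np.atleast_1d(np.asarray(probabilities, dtype=float).squeeze())
    if p.size == 1:
        return 0  # only one control state (e.g. num_controls == 1): nothing to sample
    onehot = rng.multinomial(1, p / p.sum())  # one draw, returned as a one-hot vector
    return int(np.flatnonzero(onehot)[0])
```

For example, `safe_sample(np.array([[1.0]]))` returns 0 without sampling, while a two-entry distribution is sampled as usual.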
Hi @SamGijsen, sure thing, happy to help out, especially if it leads to a possible improvement of pymdp.
This is an astute observation; basically, you need to "dig yourself" out of the

Side-note: I know there is this paper by @AnnaCSales et al. where they implement a sort of 'decay' rate on learning the Dirichlet hyperparameters (consult equation for updating
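The exact update equation from that paper isn't reproduced in this thread, so the following is a swapped-in, generic alternative rather than Sales et al.'s rule: exponential forgetting, alpha <- omega * alpha + qs, where `omega` (a hypothetical parameter) shrinks the old counts on every trial. The sketch assumes a fully confident posterior qs, 20 trials in context 0 followed by 20 in context 1, and flat initial counts [1, 1]; `run_reversal` and `first_flip` are hypothetical helpers.

```python
import numpy as np

def run_reversal(omega, n_before=20, n_after=20, alpha0=(1.0, 1.0)):
    """Dirichlet accumulation with exponential forgetting:
        alpha <- omega * alpha + qs
    using a fully confident posterior qs (context 0 for n_before trials,
    context 1 afterwards). Returns the expected prior E[D] after each trial."""
    alpha = np.array(alpha0, dtype=float)
    history = []
    for t in range(n_before + n_after):
        qs = np.array([1.0, 0.0]) if t < n_before else np.array([0.0, 1.0])
        alpha = omega * alpha + qs  # omega = 1.0 recovers pure accumulation
        history.append(alpha / alpha.sum())
    return np.array(history)

def first_flip(history, n_before=20):
    """Number of post-switch trials until E[D] favours the new context (None if never)."""
    for k, expected in enumerate(history[n_before:], start=1):
        if expected[1] > 0.5:
            return k
    return None

for omega in (1.0, 0.9, 0.5):
    print(omega, first_flip(run_reversal(omega)))
```

In this toy setting, omega = 1.0 (pure accumulation) has still not flipped after 20 post-switch trials (the counts only reach [21, 21]), whereas omega = 0.9 flips after 6 and omega = 0.5 after 1. The price of forgetting is that the prior never becomes very precise, since the total count is bounded by roughly 1 / (1 - omega).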
Using MMP (marginal message passing) isn't really appropriate here, unless I'm making incorrect assumptions about the temporal structure of each trial in your task. If each trial lasts more than 1 timestep (
Thanks a lot for noticing this! Can you open a new issue describing this bug (you can reference your comment in this thread as well)?
No worries, that's very understandable, I'm thankful for the comments. :)
Ah, this might be the point of discrepancy, as you suggest. My understanding is that modeling the current task in the MATLAB toolbox requires T=2, because the policy needs to transition from one state into another, which requires a "start state" and an "ending state" (modeled as A[1] in the opening post). The policy depth would indeed be one. This should then give rise to pre- and postdictive beliefs. My approach should be similar to this repo from this paper: basically no learning, only state inference (I'm using "learning" here in the sense of accruing Dirichlet parameters). I'll share some example MATLAB code illustrating my understanding below. However, after taking a look at the paper you linked (thank you for this!), I must say its method of flexibly adapting the 'd' parameters seems a much more appropriate way to deal with reversal learning (rather than pure state inference), and it is what I was hinting at previously. I'm going to have a look at how such a model behaves.
Sure, will do!

MATLAB code (trial structure with a for-loop):
mdp = {};
num_trials = 40;
z = [1:20];
pstate = [];
true_state = [];
for i = 1:num_trials
MDP = simple_model(0.9, 0.9);
if i>1
MDP.D{2} = X;
end
if sum(z==i) == 1 % trials 1-20: hidden context is state 2
MDP.s = [1 2]';
else % trials 21-40: switch hidden context to state 1
MDP.s = [1 1]';
end
MDP_p = spm_MDP_VB_X(MDP);
mdp{i} = MDP_p;
X = MDP_p.X{2}(:,end);
true_state = [ true_state MDP.s(2) ];
pstate = [ pstate mdp{i}.X{2}(2,2) ];
end
plot(pstate, 'k'); hold on
plot(true_state-1, 'r')

SPM model:
function mdp = simple_model(r1, r2)
D{1} = [1 0 0]'; % Choice State {start, choose B1, choose B2}
D{2} = [1 1]'; % Context (B1 better, B2 better)
% P(o|s)
Nf = numel(D);
for f = 1:Nf
Ns(f) = numel(D{f});
end
No = [3 3]; % 3 rewards [Pos, Neg, Neutral], 3 choice observations
Ng = numel(No);
for g = 1:Ng
A{g} = zeros([No(g),Ns]);
end
% columns: start (S), b1, b2
A{1}(:,:,1) = [0, 1-r1, r2; % Pos reward
0, r1, 1-r2; % neg reward
1, 0, 0];% Neutral reward
A{1}(:,:,2) = [0, r1, 1-r2; % Pos reward
0, 1-r1, r2; % neg reward
1, 0, 0];% Neutral reward
A{2}(:,:,1) = eye(3);
A{2}(:,:,2) = eye(3);
a{1} = A{1};
a{1}(1:2,:,:) = a{1}(1:2,:,:)*2;
a{1}(3,:,:) = a{1}(3,:,:)*10000;
a{2} = A{2}*10000;
% B{state factor}(state at time tau+1, state at time tau, action number)
for f = 1:Nf
B{f} = zeros(Ns(f));
end
% action 1: move from start to button 1
% a2: start to b2
for c = 1:2
B{1}(1,:,c) = ones(1,3);
end
B{1}(:,1,1) = [0 1 0]'; % start -> b1
B{1}(:,1,2) = [0 0 1]'; % start -> b2
B{2} = eye(2);
% Rewards
% columns: t = 1, t = 2
C{1} = [0 2.5; % pos
0 -2.5; % neg
0 0 ];% neu
C{2} = zeros(3, 2); % no preferences over choice observations (outcomes x time steps)
% press b1 or b2
U(:,:,1) = [1 2];
U(:,:,2) = [1 1];
mdp.A = A;
%mdp.a = a; % likelihood learning on/off
mdp.B = B;
mdp.C = C;
mdp.D = D;
mdp.U = U;
mdp.T = 2;
mdp.Aname = ["RewardsA", "ChoicesA"];
mdp.Bname = ["ChoicesB", "Context"];
mdp.label.modality = {"Choice State"};
end
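For comparison with the MATLAB loop, where the context posterior from one trial (`X = MDP_p.X{2}(:,end)`) becomes the next trial's `MDP.D{2}`, here is a minimal Python sketch of the same carry-over idea. It is an exact Bayesian update over the two contexts with no volatility (mirroring `B{2} = eye(2)`), using the reward probabilities from `A{1}` in `simple_model`; it does not reproduce SPM's variational scheme, and `reward_likelihood` and `carry_over_context_belief` are hypothetical helper names.

```python
import numpy as np

def reward_likelihood(r1=0.9, r2=0.9):
    """P(outcome | context, button), mirroring A{1} in simple_model.
    Indexed as like[context, button, outcome] with contexts ordered as
    A{1}(:,:,1) and A{1}(:,:,2), buttons (b1, b2), outcomes (pos, neg)."""
    like = np.zeros((2, 2, 2))
    like[0, 0] = [1 - r1, r1]  # context 1, button b1
    like[0, 1] = [r2, 1 - r2]  # context 1, button b2
    like[1, 0] = [r1, 1 - r1]  # context 2, button b1
    like[1, 1] = [1 - r2, r2]  # context 2, button b2
    return like

def carry_over_context_belief(choices, outcomes, prior=(0.5, 0.5), r1=0.9, r2=0.9):
    """Bayes rule over context on every trial; each posterior is reused as the
    next trial's prior (cf. MDP.D{2} = X). Contexts never switch under this model."""
    like = reward_likelihood(r1, r2)
    belief = np.asarray(prior, dtype=float)
    trace = []
    for button, outcome in zip(choices, outcomes):
        belief = belief * like[:, button, outcome]
        belief = belief / belief.sum()
        trace.append(belief.copy())
    return np.array(trace)

# b1 pressed six times: three rewards followed by three losses
trace = carry_over_context_belief([0] * 6, [0, 0, 0, 1, 1, 1])
```

Because the context is assumed never to switch, evidence for one context has to be cancelled by a comparable amount of evidence for the other: in the usage line, three rewarded presses of b1 followed by three unrewarded ones return the belief to [0.5, 0.5].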
Hi all,
First off, many thanks for your work; it looks very promising!
I was comparing some simple agents between the current implementation and the MATLAB implementation and noticed that, in terms of reversal learning, the pymdp version appears to adapt considerably more slowly. I've played around with a variety of setups and hyperparameters, but the difference remains quite significant.
Example setup: slot machine task without hints
2 actions: 'button 0' and 'button 1'
A 'button 0 is better' context and a 'button 1 is better' context.
40 trials, with the hidden context switching after 20 trials. Here I plot the state posterior (black) of 100 agents starting with flat context priors, compared with the true but unknown state/context (red). Below I'll include the pymdp code. I assume I'm using the package incorrectly and would love to learn what I've misunderstood.
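A minimal sketch of the task's generative process as described, assuming the better button pays out with probability 0.9 (the value used in the MATLAB `simple_model(0.9, 0.9)` call; the probability used in the pymdp runs isn't stated here). `context_schedule` and `press` are hypothetical helper names, and the usage lines drive them with a random, non-learning policy just to show the interface.

```python
import numpy as np

def context_schedule(n_trials=40, switch_after=20, first_context=0):
    """Hidden context per trial (0 = 'button 0 is better', 1 = 'button 1 is better');
    the context flips once, after `switch_after` trials."""
    contexts = np.full(n_trials, first_context, dtype=int)
    contexts[switch_after:] = 1 - first_context
    return contexts

def press(button, context, p_better=0.9, rng=None):
    """Return 1 (reward) or 0 (no reward) for pressing `button` in `context`.
    The better button pays out with probability p_better, the other with 1 - p_better."""
    rng = np.random.default_rng() if rng is None else rng
    p_reward = p_better if button == context else 1.0 - p_better
    return int(rng.random() < p_reward)

# Usage: one simulated run with a random (non-learning) policy
rng = np.random.default_rng(0)
contexts = context_schedule()
rewards = [press(int(rng.integers(2)), c, rng=rng) for c in contexts]
```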