diff --git a/atari/mingpt/model_atari.py b/atari/mingpt/model_atari.py index 8428811c..68219e2d 100644 --- a/atari/mingpt/model_atari.py +++ b/atari/mingpt/model_atari.py @@ -265,7 +265,7 @@ def forward(self, states, actions, targets=None, rtgs=None, timesteps=None): if actions is not None and self.model_type == 'reward_conditioned': logits = logits[:, 1::3, :] # only keep predictions from state_embeddings elif actions is None and self.model_type == 'reward_conditioned': - logits = logits[:, 1:, :] + logits = logits[:, 1::2, :] elif actions is not None and self.model_type == 'naive': logits = logits[:, ::2, :] # only keep predictions from state_embeddings elif actions is None and self.model_type == 'naive':