Skip to content

Commit

Permalink
fixed typecasting error
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Feb 17, 2024
1 parent 61c08be commit f41a31f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,9 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool):
trajectory - if so, it should be set to True.
"""
if allow_exit:
exit_idx = torch.zeros(self.batch_shape + (1,))
exit_idx = torch.zeros(self.batch_shape + (1,)).to(cond.device)
else:
exit_idx = torch.ones(self.batch_shape + (1,))
exit_idx = torch.ones(self.batch_shape + (1,)).to(cond.device)
self.forward_masks[torch.cat([cond, exit_idx], dim=-1).bool()] = False

def set_exit_masks(self, batch_idx):
Expand Down

0 comments on commit f41a31f

Please sign in to comment.