From de9edeaea73b346f041d684fdd131f2623f51722 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 14 Nov 2024 20:12:12 -0500 Subject: [PATCH] terminating_states are now calculated using self.all_states --- src/gfn/gym/hypergrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index ac76a8df..d206dff5 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -178,7 +178,7 @@ def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor Returns the indices of the terminating states in the canonical ordering as a tensor of shape `batch_shape`. """ - return self.get_states_indices(states) + return self.get_states_indices(self.all_states) @property def n_states(self) -> int: