diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 1884fa08d..d81ac1dea 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -556,7 +556,10 @@ def _compute_parents(self): """ self.parents = [] self.parents_indices = [] - indices = [] + + indices_dict = {} + indices_next = 0 + # Iterate over the trajectories to obtain the parents from the states for traj_idx, batch_indices in self.trajectories.items(): # parent is source @@ -567,12 +570,18 @@ def _compute_parents(self): # TODO: check if tensor and sort without iter self.parents.extend([self.states[idx] for idx in batch_indices[:-1]]) self.parents_indices.extend(batch_indices[:-1]) - indices.extend(batch_indices) + + # Store the indices required to reorder the parents lists in the same + # order as the states + for b_idx in batch_indices: + indices_dict[b_idx] = indices_next + indices_next += 1 + # Sort parents list in the same order as states # TODO: check if tensor and sort without iter - self.parents = [self.parents[indices.index(idx)] for idx in range(len(self))] + self.parents = [self.parents[indices_dict[idx]] for idx in range(len(self))] self.parents_indices = tlong( - [self.parents_indices[indices.index(idx)] for idx in range(len(self))], + [self.parents_indices[indices_dict[idx]] for idx in range(len(self))], device=self.device, ) self.parents_available = True diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index c499995dc..cf8dd2448 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -66,7 +66,6 @@ def test__step__returns_same_state_action_and_invalid_if_done(self, n_repeat=1): def test__sample_actions__backward__returns_eos_if_done( self, n_repeat=1, n_states=5 ): - if _get_current_method_name() in self.n_states: n_states = self.n_states[_get_current_method_name()] @@ -96,7 +95,6 @@ def test__sample_actions__backward__returns_eos_if_done( def test__get_logprobs__backward__returns_zero_if_done( self, n_repeat=1, n_states=5 ): - if _get_current_method_name() in self.n_states: n_states = self.n_states[_get_current_method_name()] @@ -161,7 +159,6 @@ def test__forward_actions_have_nonzero_backward_prob(self, n_repeat=1): def test__backward_actions_have_nonzero_forward_prob( self, n_repeat=1, n_states=100 ): - if _get_current_method_name() in self.n_states: n_states = self.n_states[_get_current_method_name()] @@ -398,7 +395,6 @@ def test__state2readable__is_reversible(self, n_repeat=1): def test__get_parents__returns_same_state_and_eos_if_done( self, n_repeat=1, n_states=10 ): - if _get_current_method_name() in self.n_states: n_states = self.n_states[_get_current_method_name()]