diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index c7d0da60..4b601e67 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -67,6 +67,13 @@ def __init__( ): """Discrete EBM environment. + The states are represented as 1d tensors of length `ndim` with values in + {-1, 0, 1}. s0 is empty (represented as -1), so s0=[-1, -1, ..., -1]. + An action corresponds to replacing a -1 with a 0 or a 1. + Action i in [0, ndim - 1] corresponds to replacing s[i] with 0. + Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1. + The last action is the exit action that is only available for complete states (those with no -1). + Args: ndim: dimension D of the sampling space {0, 1}^D. energy: energy function of the EBM. Defaults to None. If @@ -90,7 +97,7 @@ def __init__( n_actions = 2 * ndim + 1 # the last action is the exit action that is only available for complete states - # Action i in [0, ndim - 1] corresponds to replacing s[i] with 0 + # # Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1 if preprocessor_name == "Identity": @@ -207,7 +214,12 @@ def log_reward(self, final_states: DiscreteStates) -> torch.Tensor: return log_reward def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: - """The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3 + """Given that each state is of length ndim with values in {-1, 0, 1}, + there are 3**ndim states, which we can label from 0 to 3**ndim - 1. + + The easiest way to map each state to a unique integer is to consider the + state as a number in base 3, where each digit can be in {0, 1, 2}. + We thus need to shift each entry by 1 so that {-1, 0, 1} -> {0, 1, 2}. Args: states: DiscreteStates object representing the states. 
@@ -221,7 +233,11 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: return states_indices def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor: - """Get the indices of the terminating states in the canonical ordering from the submitted states. + """Given that each terminating state is of length ndim with values in {0, 1}, + there are 2**ndim terminating states, which we can label from 0 to 2**ndim - 1. + + The easiest way to map each state to a unique integer is to consider the + state as a number in base 2. Args: states: DiscreteStates object representing the states.