document state_indexing for discrete_ebm
Salem Lahlou committed Nov 26, 2024
1 parent 6f132a8 commit b9f777f
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/gfn/gym/discrete_ebm.py
@@ -67,6 +67,13 @@ def __init__(
):
"""Discrete EBM environment.
+ The states are represented as 1d tensors of length `ndim` with values in
+ {-1, 0, 1}. s0 is empty (represented as -1), so s0=[-1, -1, ..., -1],
+ An action corresponds to replacing a -1 with a 0 or a 1.
+ Action i in [0, ndim - 1] corresponds to replacing s[i] with 0
+ Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1
+ The last action is the exit action that is only available for complete states (those with no -1)
Args:
ndim: dimension D of the sampling space {0, 1}^D.
energy: energy function of the EBM. Defaults to None. If
@@ -90,7 +97,7 @@ def __init__(

n_actions = 2 * ndim + 1
# the last action is the exit action that is only available for complete states
# Action i in [0, ndim - 1] corresponds to replacing s[i] with 0
#
# Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1

if preprocessor_name == "Identity":
@@ -207,7 +214,12 @@ def log_reward(self, final_states: DiscreteStates) -> torch.Tensor:
return log_reward

def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
- """The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3
+ """Given that each state is of length ndim with values in {-1, 0, 1},
+ there are 3**ndim states, which we can label from 0 to 3**ndim - 1.
+ The easiest way to map each state to a unique integer is to consider the
+ state as a number in base 3, where each digit can be in {0, 1, 2}.
+ We thus need to shift this number by 1 so that {-1, 0, 1} -> {0, 1, 2}.
Args:
states: DiscreteStates object representing the states.
@@ -221,7 +233,11 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
return states_indices

def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
- """Get the indices of the terminating states in the canonical ordering from the submitted states.
+ """Given that each terminating state is of length ndim with values in {0, 1},
+ there are 2**ndim terminating states, which we can label from 0 to 2**ndim - 1.
+ The easiest way to map each state to a unique integer is to consider the
+ state as a number in base 2.
Args:
states: DiscreteStates object representing the states.
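
The two indexing docstrings above can be illustrated with a minimal pure-Python sketch. It is not the code from `discrete_ebm.py`: the function names `state_index` and `terminating_state_index` are hypothetical, states are plain lists rather than `DiscreteStates`, and treating the first entry as the most significant digit is an assumption, since the docstrings do not fix a digit order.

```python
from itertools import product


def state_index(state):
    """Map a state with entries in {-1, 0, 1} to an integer in [0, 3**ndim - 1].

    Each entry is shifted by 1 so that {-1, 0, 1} -> {0, 1, 2}, and the
    shifted entries are read as base-3 digits.
    Assumption: the first entry is the most significant digit.
    """
    index = 0
    for value in state:
        if value not in (-1, 0, 1):
            raise ValueError(f"invalid entry {value!r}, expected -1, 0 or 1")
        index = index * 3 + (value + 1)
    return index


def terminating_state_index(state):
    """Map a complete state with entries in {0, 1} to an integer in [0, 2**ndim - 1].

    The entries are read directly as base-2 digits (same digit-order assumption).
    """
    index = 0
    for value in state:
        if value not in (0, 1):
            raise ValueError(f"invalid entry {value!r}, expected 0 or 1")
        index = index * 2 + value
    return index


# Sanity check for a small ndim: every state gets a distinct label.
ndim = 3
all_indices = sorted(state_index(s) for s in product((-1, 0, 1), repeat=ndim))
all_terminating = sorted(
    terminating_state_index(s) for s in product((0, 1), repeat=ndim)
)
```

With `ndim = 3`, the empty state `[-1, -1, -1]` maps to 0 and `[1, 1, 1]` maps to `3**3 - 1`, so the labels cover the full range with no gaps.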
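
The action encoding described in the `__init__` docstring (`n_actions = 2 * ndim + 1`) can be sketched in the same way. The helpers `decode_action`, `valid_actions` and `step` are hypothetical names used only for illustration, states are plain lists, and what happens after the exit action is left out because the diff does not show it.

```python
EMPTY = -1  # marker for a coordinate that has not been set yet


def decode_action(action, ndim):
    """Return (position, value) for a fill action, or None for the exit action."""
    if not 0 <= action <= 2 * ndim:
        raise ValueError(f"action {action} is outside [0, {2 * ndim}]")
    if action == 2 * ndim:
        return None  # the last action is the exit action
    if action < ndim:
        return action, 0  # action i in [0, ndim - 1] sets s[i] to 0
    return action - ndim, 1  # action i in [ndim, 2 * ndim - 1] sets s[i - ndim] to 1


def valid_actions(state):
    """List the actions allowed in `state` under the rules in the docstring."""
    ndim = len(state)
    empty = [i for i, value in enumerate(state) if value == EMPTY]
    if not empty:
        return [2 * ndim]  # complete state: only the exit action remains
    return empty + [i + ndim for i in empty]


def step(state, action):
    """Apply a fill action and return the new state as a fresh list."""
    decoded = decode_action(action, len(state))
    if decoded is None:
        raise ValueError("the exit action is not handled in this sketch")
    position, value = decoded
    if state[position] != EMPTY:
        raise ValueError(f"position {position} is already set")
    new_state = list(state)
    new_state[position] = value
    return new_state
```

For example, with `ndim = 3`, action 5 decodes to position 2 and value 1, so applying it to `s0 = [-1, -1, -1]` gives `[-1, -1, 1]`.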
