Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into graph-states
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Jan 12, 2025
2 parents 27d192a + 7ae5839 commit 5d99739
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 9 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
name: Python package

on: [push]
on:
pull_request:
push:
branches: [master]

jobs:
build:
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
# This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file.
---
name: pre-commit
on: [push]
on:
pull_request:
push:
branches: [master]

permissions:
contents: read
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/python-package-conda.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
name: Python Package using Conda

on: [push]
on:
pull_request:
push:
branches: [master]

jobs:
build-linux:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ git clone https://github.com/GFNOrg/torchgfn.git
conda create -n gfn python=3.10
conda activate gfn
cd torchgfn
pip install .
pip install -e ".[all]"
```


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ all = [
"wandb",
]

[project.urls]
[tool.poetry.urls]
"Homepage" = "https://gfn.readthedocs.io/en/latest/"
"Bug Tracker" = "https://github.com/saleml/gfn/issues"

Expand Down
22 changes: 18 additions & 4 deletions src/gfn/gym/discrete_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def __init__(
):
"""Discrete EBM environment.
The states are represented as 1d tensors of length `ndim` with values in
{-1, 0, 1}. s0 is empty (represented as -1), so s0=[-1, -1, ..., -1],
An action corresponds to replacing a -1 with a 0 or a 1.
Action i in [0, ndim - 1] corresponds to replacing s[i] with 0
Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1
The last action is the exit action that is only available for complete states (those with no -1)
Args:
ndim: dimension D of the sampling space {0, 1}^D.
energy: energy function of the EBM. Defaults to None. If
Expand All @@ -90,8 +97,6 @@ def __init__(

n_actions = 2 * ndim + 1
# the last action is the exit action that is only available for complete states
# Action i in [0, ndim - 1] corresponds to replacing s[i] with 0
# Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1

if preprocessor_name == "Identity":
preprocessor = IdentityPreprocessor(output_dim=ndim)
Expand Down Expand Up @@ -207,7 +212,12 @@ def log_reward(self, final_states: DiscreteStates) -> torch.Tensor:
return log_reward

def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3
"""Given that each state is of length ndim with values in {-1, 0, 1},
there are 3**ndim states, which we can label from 0 to 3**ndim - 1.
The easiest way to map each state to a unique integer is to consider the
state as a number in base 3, where each digit can be in {0, 1, 2}.
We thus need to shift this number by 1 so that {-1, 0, 1} -> {0, 1, 2}.
Args:
states: DiscreteStates object representing the states.
Expand All @@ -221,7 +231,11 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
return states_indices

def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""Get the indices of the terminating states in the canonical ordering from the submitted states.
"""Given that each terminating state is of length ndim with values in {0, 1},
there are 2**ndim terminating states, which we can label from 0 to 2**ndim - 1.
The easiest way to map each state to a unique integer is to consider the
state as a number in base 2.
Args:
states: DiscreteStates object representing the states.
Expand Down

0 comments on commit 5d99739

Please sign in to comment.