From a12d1280b29d2de2679a10e0d991c9aaeebad13a Mon Sep 17 00:00:00 2001 From: saleml Date: Sat, 11 Jan 2025 21:47:31 +0400 Subject: [PATCH 1/3] Fixpyproject (#225) * fix pyproject * make all the default install * make default the editable mode --------- Co-authored-by: Salem Lahlou --- README.md | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 80687343..a00ec643 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ git clone https://github.com/GFNOrg/torchgfn.git conda create -n gfn python=3.10 conda activate gfn cd torchgfn -pip install . +pip install -e ".[all]" ``` diff --git a/pyproject.toml b/pyproject.toml index 0523821a..55583eca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,7 @@ all = [ "wandb", ] -[project.urls] +[tool.poetry.urls] "Homepage" = "https://gfn.readthedocs.io/en/latest/" "Bug Tracker" = "https://github.com/saleml/gfn/issues" From 65e5f47121d77927db4c35c61d10dc750938e059 Mon Sep 17 00:00:00 2001 From: saleml Date: Sat, 11 Jan 2025 21:50:24 +0400 Subject: [PATCH 2/3] document state_indexing for discrete_ebm (#222) * document state_indexing for discrete_ebm * remove useless comment --------- Co-authored-by: Salem Lahlou --- src/gfn/gym/discrete_ebm.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index c7d0da60..9180e8b3 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -67,6 +67,13 @@ def __init__( ): """Discrete EBM environment. + The states are represented as 1d tensors of length `ndim` with values in + {-1, 0, 1}. s0 is empty (represented as -1), so s0=[-1, -1, ..., -1], + An action corresponds to replacing a -1 with a 0 or a 1. + Action i in [0, ndim - 1] corresponds to replacing s[i] with 0 + Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1 + The last action is the exit action that is only available for complete states (those with no -1) + Args: ndim: dimension D of the sampling space {0, 1}^D. energy: energy function of the EBM. Defaults to None. If @@ -90,8 +97,6 @@ def __init__( n_actions = 2 * ndim + 1 # the last action is the exit action that is only available for complete states - # Action i in [0, ndim - 1] corresponds to replacing s[i] with 0 - # Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1 if preprocessor_name == "Identity": preprocessor = IdentityPreprocessor(output_dim=ndim) @@ -207,7 +212,12 @@ def log_reward(self, final_states: DiscreteStates) -> torch.Tensor: return log_reward def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: - """The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3 + """Given that each state is of length ndim with values in {-1, 0, 1}, + there are 3**ndim states, which we can label from 0 to 3**ndim - 1. + + The easiest way to map each state to a unique integer is to consider the + state as a number in base 3, where each digit can be in {0, 1, 2}. + We thus need to shift this number by 1 so that {-1, 0, 1} -> {0, 1, 2}. Args: states: DiscreteStates object representing the states. @@ -221,7 +231,11 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: return states_indices def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor: - """Get the indices of the terminating states in the canonical ordering from the submitted states. + """Given that each terminating state is of length ndim with values in {0, 1}, + there are 2**ndim terminating states, which we can label from 0 to 2**ndim - 1. + + The easiest way to map each state to a unique integer is to consider the + state as a number in base 2. Args: states: DiscreteStates object representing the states. From 9cd94a83c86ea894c1c299fe6e23a1294bc14f23 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 12 Jan 2025 00:39:53 +0100 Subject: [PATCH 3/3] make CI running in PRs --- .github/workflows/ci.yml | 5 ++++- .github/workflows/pre-commit.yml | 5 ++++- .github/workflows/python-package-conda.yml | 5 ++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 74279369..61d8e9f7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,9 @@ name: Python package -on: [push] +on: + pull_request: + push: + branches: [master] jobs: build: diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 6b552a0d..3f9863b2 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -2,7 +2,10 @@ # This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file. --- name: pre-commit -on: [push] +on: + pull_request: + push: + branches: [master] permissions: contents: read diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index 9d45abb2..aec0afe1 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -1,6 +1,9 @@ name: Python Package using Conda -on: [push] +on: + pull_request: + push: + branches: [master] jobs: build-linux: