Easier environment definition #143

Merged: 30 commits, Nov 25, 2023

Commits
dc774e9
fix __repr__ of modules - no more env
saleml Sep 1, 2023
ec4b415
Merge pull request #135 from saleml/fixrepr
josephdviviano Sep 4, 2023
6c618f4
bump up version
saleml Sep 4, 2023
298405b
convience functions for common mask operations in DiscreteStates
josephdviviano Oct 20, 2023
38ab3dc
change name of enforce_exit_masks method for consistency
josephdviviano Oct 20, 2023
6ecb60d
added default log reward clipping. Also, reward() is by default not i…
josephdviviano Oct 20, 2023
b77ccad
black formatting
josephdviviano Oct 20, 2023
4b8cf79
updated scripts to be consistent with new mask and reward defintion p…
josephdviviano Oct 20, 2023
a357475
log reward clipping is now -100 (much smaller)
josephdviviano Oct 20, 2023
d3a72bb
I'm actually trying to trigger actions...
josephdviviano Oct 20, 2023
4f6e965
added flake8 back into deps
josephdviviano Oct 20, 2023
e2c74b6
missing quote
josephdviviano Oct 20, 2023
d3bcdf7
letting notebooks back in
josephdviviano Oct 24, 2023
382f396
added smiley FM & TB tutorial
josephdviviano Oct 24, 2023
4a2861a
bugfix: unrelated, but logz was not being learned during the task any…
josephdviviano Oct 24, 2023
77abec6
black and autoflake reformatting
saleml Oct 25, 2023
0400d04
specify flake8 and black versions
saleml Oct 25, 2023
03f6ca6
ignore one error from flake8 and revert version selection
saleml Oct 25, 2023
ad12af3
add wandb for pytest ci, and add quotations around python versions
saleml Oct 25, 2023
e0d750b
remove explicit wandb, and use pip install.[all]
saleml Oct 25, 2023
8678f9c
line environment v1 done
josephdviviano Nov 7, 2023
b1f1040
Merge branch 'easier_environment_definition' of github.com:saleml/tor…
josephdviviano Nov 7, 2023
f4f69d6
smiley tweaks
josephdviviano Nov 7, 2023
0735ceb
isort
josephdviviano Nov 23, 2023
4161dba
typo fixed
josephdviviano Nov 23, 2023
4d8a145
documentation nits and using convienience builtins
josephdviviano Nov 23, 2023
459203c
using set_nonexit_action_masks properly, with a comment for clarity
josephdviviano Nov 23, 2023
fb7e819
Improved method name and documentaiton
josephdviviano Nov 23, 2023
226185f
changed test logic (will run on github instead of this terrible compu…
josephdviviano Nov 23, 2023
84bb169
new targets
josephdviviano Nov 24, 2023
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11"]
python-version: ['3.10', '3.11']

steps:
- uses: actions/checkout@v3
@@ -20,15 +20,15 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev]
pip install .[all]
- name: Black Formatting
run: |
black .

- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --ignore=E203,E266,E501,W503,F403,F401 --show-source --statistics
flake8 . --count --select=E9,F63,F7,F82 --ignore=E203,E266,E501,W503,F403,F401,F821 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=18 --max-line-length=89 --statistics

6 changes: 3 additions & 3 deletions .github/workflows/python-package-conda.yml
@@ -13,14 +13,15 @@ jobs:
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: 3.10
python-version: '3.10'
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
- name: Install dependencies
run: |
conda env update --file environment.yml --name base
python -m pip install --upgrade pip
pip install .[all]
- name: Lint with flake8
run: |
conda install flake8
@@ -30,5 +31,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
conda install pytest
pytest
3 changes: 1 addition & 2 deletions .gitignore
@@ -182,6 +182,5 @@ wandb/

scripts.py

*.ipynb
models/
*.DS_Store
*.DS_Store
4 changes: 1 addition & 3 deletions README.md
@@ -21,9 +21,7 @@

## Installing the package

The codebase requires python >= 3.10

To install the latest stable version:
The codebase requires python >= 3.10. To install the latest stable version:

```bash
pip install torchgfn
```
23 changes: 13 additions & 10 deletions pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "torchgfn"
packages = [{include = "gfn", from = "src"}]
version = "1.1"
version = "1.1.1"
description = "A torch implementation of GFlowNets"
authors = ["Salem Lahou <[email protected]>", "Joseph Viviano <[email protected]>", "Victor Schmidt <[email protected]>"]
license = "MIT"
@@ -26,16 +26,17 @@ torch = ">=1.9.0"
torchtyping = ">=0.1.4"

# dev dependencies.
black = { version = "22.3.0", optional = true }
black = { version = "*", optional = true }
flake8 = { version = "*", optional = true }
gitmopy = { version = "*", optional = true }
myst-parser = { version = "*", optional = true }
pre-commit = { version = "*", optional = true }
pytest = { version = "*", optional = true }
renku-sphinx-theme = { version = "*", optional = true }
sphinx = { version = "*", optional = true }
sphinx_rtd_theme = { version = "*", optional = true }
sphinx-autoapi = { version = "*", optional = true }
sphinx-math-dollar = { version = "*", optional = true }
sphinx_rtd_theme = { version = "*", optional = true }
tox = { version = "*", optional = true }

# scripts dependencies.
@@ -52,30 +53,32 @@ dev = [
"pre-commit",
"pytest",
"renku-sphinx-theme",
"sphinx",
"sphinx_rtd_theme",
"sphinx-autoapi",
"sphinx-math-dollar",
"sphinx_rtd_theme",
"tox"
"sphinx",
"tox",
"flake8",
]

scripts = ["tqdm", "wandb", "scikit-learn", "scipy"]

all = [
"black",
"flake8",
"myst-parser",
"pre-commit",
"pytest",
"renku-sphinx-theme",
"scikit-learn",
"scipy",
"sphinx_rtd_theme",
"sphinx-autoapi",
"sphinx-math-dollar",
"sphinx",
"tox",
"black",
"myst-parser",
"tqdm",
"wandb",
"scikit-learn",
"scipy"
]

[project.urls]
44 changes: 30 additions & 14 deletions src/gfn/env.py
@@ -23,16 +23,21 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,
Collaborator comment: good idea!

):
"""Initializes an environment.

Args:
s0: Representation of the initial state. All individual states would be of the same shape.
sf (optional): Representation of the final state. Only used for a human readable representation of
the states or trajectories.
device_str (Optional[str], optional): 'cpu' or 'cuda'. Defaults to None, in which case the device is inferred from s0.
preprocessor (Optional[Preprocessor], optional): a Preprocessor object that converts raw states to a tensor that can be fed
into a neural network. Defaults to None, in which case the IdentityPreprocessor is used.
s0: Representation of the initial state. All individual states would be of
the same shape.
sf: Representation of the final state. Only used for a human
readable representation of the states or trajectories.
device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is
inferred from s0.
preprocessor: a Preprocessor object that converts raw states to a tensor
that can be fed into a neural network. Defaults to None, in which case
the IdentityPreprocessor is used.
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
"""
self.device = torch.device(device_str) if device_str is not None else s0.device

@@ -53,6 +58,7 @@ def __init__(

self.preprocessor = preprocessor
self.is_discrete = False
self.log_reward_clip = log_reward_clip

@abstractmethod
def make_States_class(self) -> type[States]:
@@ -184,12 +190,15 @@ def backward_step(
return new_states

def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Either this or log_reward needs to be implemented."""
return torch.exp(self.log_reward(final_states))
"""The environment's reward given a state.

This or log_reward must be implemented.
"""
raise NotImplementedError("Reward function is not implemented.")

def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Either this or reward needs to be implemented."""
raise NotImplementedError("log_reward function not implemented")
"""Calculates the log reward (clipping small rewards)."""
return torch.log(self.reward(final_states)).clip(self.log_reward_clip)

@property
def log_partition(self) -> float:
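
For readers scanning the hunk above: a subclass now implements either `reward` or `log_reward`, and the base `Env` derives the other, clipping `log(0)` rewards to `log_reward_clip` (default `-100.0`). Below is a minimal sketch of that contract from the user side; `MyEnv` and its hard-coded rewards are illustrative, not library code.

```python
import torch

class EnvSketch:
    """Stand-in for gfn.env.Env, reduced to the reward logic in the hunk above."""

    def __init__(self, log_reward_clip: float = -100.0):
        self.log_reward_clip = log_reward_clip

    def reward(self, final_states):
        raise NotImplementedError("Reward function is not implemented.")

    def log_reward(self, final_states):
        # log(0) becomes -inf, which the clip maps to log_reward_clip.
        return torch.log(self.reward(final_states)).clip(self.log_reward_clip)

class MyEnv(EnvSketch):
    def reward(self, final_states):
        return torch.tensor([1.0, 0.0])  # hard-coded, for illustration only

print(MyEnv().log_reward(final_states=None))  # 0.0 and -100.0 (clipped -inf)
```
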
@@ -203,8 +212,9 @@ class DiscreteEnv(Env, ABC):
"""
Base class for discrete environments, where actions are represented by a number in
{0, ..., n_actions - 1}, the last one being the exit action.
`DiscreteEnv` allow specifying the validity of actions (forward and backward), via mask tensors, that
are directly attached to `States` objects.

`DiscreteEnv` allows for specifying the validity of actions (forward and backward),
via mask tensors, that are directly attached to `States` objects.
"""

def __init__(
@@ -214,16 +224,22 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.0,
):
"""Initializes a discrete environment.

Args:
n_actions: The number of actions in the environment.

s0: The initial state tensor (shared among all trajectories).
sf: The final state tensor (shared among all trajectories).
device_str: String representation of a torch.device.
preprocessor: An optional preprocessor for intermediate states.
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
"""
self.n_actions = n_actions
super().__init__(s0, sf, device_str, preprocessor)
super().__init__(s0, sf, device_str, preprocessor, log_reward_clip)
self.is_discrete = True
self.log_reward_clip = log_reward_clip
Collaborator comment: unnecessary


def make_Actions_class(self) -> type[Actions]:
env = self
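
A quick reading aid for the `DiscreteEnv` docstring above: forward masks carry one boolean per action, with the last column reserved for the exit action, while backward masks have `n_actions - 1` columns (no exit). The sketch below only illustrates those shapes; the names mirror the diff but the snippet is not library code.

```python
import torch

batch_shape, n_actions, ndim = (4,), 5, 3
states = torch.randint(-1, 2, (*batch_shape, ndim))  # toy states in {-1, 0, 1}

# One boolean per action; by convention the last column is the exit action.
forward_masks = torch.ones(*batch_shape, n_actions, dtype=torch.bool)
# Backward masks exclude the exit action, hence n_actions - 1 columns.
backward_masks = torch.ones(*batch_shape, n_actions - 1, dtype=torch.bool)

# Example exit rule: only fully-specified states may exit (mirrors the
# update_masks logic in the DiscreteEBM diff further down).
forward_masks[..., -1] = torch.all(states != -1, dim=-1)
```
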
8 changes: 5 additions & 3 deletions src/gfn/gym/box.py
@@ -20,6 +20,7 @@ def __init__(
R2: float = 2.0,
epsilon: float = 1e-4,
device_str: Literal["cpu", "cuda"] = "cpu",
log_reward_clip: float = -100.0,
):
assert 0 < delta <= 1, "delta must be in (0, 1]"
self.delta = delta
@@ -30,7 +31,7 @@
self.R1 = R1
self.R2 = R2

super().__init__(s0=s0)
super().__init__(s0=s0, log_reward_clip=log_reward_clip)

def make_States_class(self) -> type[States]:
env = self
@@ -116,14 +117,15 @@ def is_action_valid(

return True

def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
"""Reward is distance from the goal point."""
R0, R1, R2 = (self.R0, self.R1, self.R2)
ax = abs(final_states.tensor - 0.5)
reward = (
R0 + (0.25 < ax).prod(-1) * R1 + ((0.3 < ax) * (ax < 0.4)).prod(-1) * R2
)

return reward.log()
return reward

@property
def log_partition(self) -> float:
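
To make the piecewise Box reward above concrete, here is a small numeric check. The `R0`/`R1` values are assumptions (only `R2 = 2.0` is visible in this diff), and `.all(-1)` stands in for the diff's `.prod(-1)`:

```python
import torch

R0, R1, R2 = 0.1, 0.5, 2.0  # assumed defaults; only R2 = 2.0 appears above
x = torch.tensor([[0.50, 0.50],   # center of the box
                  [0.85, 0.85]])  # |x - 0.5| = 0.35 in both dimensions
ax = (x - 0.5).abs()

in_outer = (0.25 < ax).all(-1)               # adds R1
in_ring = ((0.3 < ax) & (ax < 0.4)).all(-1)  # adds R2 on top
reward = R0 + in_outer * R1 + in_ring * R2
print(reward)  # tensor([0.1000, 2.6000])
```
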
45 changes: 24 additions & 21 deletions src/gfn/gym/discrete_ebm.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import ClassVar, Literal, Tuple, cast
from typing import ClassVar, Literal, Tuple

import torch
import torch.nn as nn
@@ -48,14 +48,19 @@ def __init__(
alpha: float = 1.0,
device_str: Literal["cpu", "cuda"] = "cpu",
preprocessor_name: Literal["Identity", "Enum"] = "Identity",
log_reward_clip: float = -100.0,
):
"""Discrete EBM environment.

Args:
ndim (int, optional): dimension D of the sampling space {0, 1}^D.
energy (EnergyFunction): energy function of the EBM. Defaults to None. If None, the Ising model with Identity matrix is used.
alpha (float, optional): interaction strength the EBM. Defaults to 1.0.
device_str (str, optional): "cpu" or "cuda". Defaults to "cpu".
ndim: dimension D of the sampling space {0, 1}^D.
energy: energy function of the EBM. Defaults to None. If
None, the Ising model with Identity matrix is used.
alpha: interaction strength the EBM. Defaults to 1.0.
device_str: "cpu" or "cuda". Defaults to "cpu".
preprocessor_name: "KHot" or "OneHot" or "Identity".
Defaults to "KHot".
log_reward_clip: Minimum log reward allowable (namely, for log(0)).
"""
self.ndim = ndim

@@ -89,6 +94,7 @@ def __init__(
sf=sf,
device_str=device_str,
preprocessor=preprocessor,
log_reward_clip=log_reward_clip,
)

def make_States_class(self) -> type[DiscreteStates]:
@@ -133,16 +139,7 @@ def make_masks(
return forward_masks, backward_masks

def update_masks(self) -> None:
# The following two lines are for typing only.
self.forward_masks = cast(
TT["batch_shape", "n_actions", torch.bool],
self.forward_masks,
)
self.backward_masks = cast(
TT["batch_shape", "n_actions - 1", torch.bool],
self.backward_masks,
)

self.set_default_typing()
self.forward_masks[..., : env.ndim] = self.tensor == -1
self.forward_masks[..., env.ndim : 2 * env.ndim] = self.tensor == -1
self.forward_masks[..., -1] = torch.all(self.tensor != -1, dim=-1)
@@ -183,16 +180,22 @@ def maskless_backward_step(
# action i in [ndim, 2*ndim-1] corresponds to replacing s[i - ndim] with 1.
# A backward action asks "what index should be set back to -1", hence the fmod
# to enable wrapping of indices.
return states.tensor.scatter(
-1,
actions.tensor.fmod(self.ndim),
-1,
)
return states.tensor.scatter(-1, actions.tensor.fmod(self.ndim), -1)

def reward(self, final_states: DiscreteStates) -> TT["batch_shape"]:
"""Not used during training but provided for completeness.

Note the effect of clipping will be seen in these values.
"""
return torch.exp(self.log_reward(final_states))

def log_reward(self, final_states: DiscreteStates) -> TT["batch_shape"]:
"""The energy weighted by alpha is our log reward."""
raw_states = final_states.tensor
canonical = 2 * raw_states - 1
return -self.alpha * self.energy(canonical)
log_reward = -self.alpha * self.energy(canonical)

return log_reward.clip(self.log_reward_clip)

def get_states_indices(self, states: DiscreteStates) -> TT["batch_shape"]:
"""The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3"""
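
The DiscreteEBM hunks above show the opposite pattern to Box: `log_reward` is implemented directly (and clipped), while `reward` just exponentiates it. A minimal sketch of that pattern with a stand-in energy function; the real environment uses an `EnergyFunction` (e.g. an Ising model), so the energy below is illustrative only.

```python
import torch

alpha, log_reward_clip = 1.0, -100.0
raw_states = torch.tensor([[0.0, 1.0], [1.0, 1.0]])
canonical = 2 * raw_states - 1                  # {0, 1} -> {-1, +1}

# Stand-in energy: negative product of the two spins (not the library's energy).
energy = -(canonical[..., 0] * canonical[..., 1])

log_reward = (-alpha * energy).clip(log_reward_clip)
reward = torch.exp(log_reward)                  # reward() just exponentiates
print(log_reward, reward)                       # tensor([-1., 1.]) and its exp
```
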
3 changes: 2 additions & 1 deletion src/gfn/gym/helpers/box_utils.py
@@ -3,6 +3,7 @@

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Beta, Categorical, Distribution, MixtureSameFamily
from torchtyping import TensorType as TT

@@ -600,8 +601,8 @@ class BoxStateFlowModule(NeuralNet):
"""A deep neural network for the state flow function."""

def __init__(self, logZ_value: torch.Tensor, **kwargs):
self.logZ_value = logZ_value
super().__init__(**kwargs)
self.logZ_value = nn.Parameter(logZ_value)
Collaborator comment:

The only place BoxStateFlowModule is used is in train_box.py:

        logZ = torch.tensor(0.0, device=env.device, requires_grad=True)
        # We need a LogStateFlowEstimator

        module = BoxStateFlowModule(
            input_dim=env.preprocessor.output_dim,
            output_dim=1,
            hidden_dim=args.hidden_dim,
            n_hidden_layers=args.n_hidden,
            torso=None,  # We do not tie the parameters of the flow function to PF
            logZ_value=logZ,
        )

Naive pytorch question: why do we need nn.Parameter ?
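
As background on the `nn.Parameter` question: assigning an `nn.Parameter` as a module attribute registers it in `module.parameters()`, so an optimizer built from those parameters will update it, whereas a plain tensor attribute (even one with `requires_grad=True`) is not registered. A minimal sketch of the difference, not the torchgfn code:

```python
import torch
import torch.nn as nn

class FlowSketch(nn.Module):
    def __init__(self, logZ_value: torch.Tensor):
        super().__init__()
        # Registered: shows up in .parameters() and is seen by the optimizer.
        self.logZ_value = nn.Parameter(logZ_value)
        # Not registered: a plain attribute, invisible to .parameters().
        self.plain_value = torch.tensor(0.0, requires_grad=True)

flow = FlowSketch(torch.tensor(0.0))
print([name for name, _ in flow.named_parameters()])  # ['logZ_value']
```
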


def forward(
self, preprocessed_states: TT["batch_shape", "input_dim", float]