Easier environment definition #143
Merged
Changes from all commits (30 commits):
- dc774e9 fix __repr__ of modules - no more env (saleml)
- ec4b415 Merge pull request #135 from saleml/fixrepr (josephdviviano)
- 6c618f4 bump up version (saleml)
- 298405b convience functions for common mask operations in DiscreteStates (josephdviviano)
- 38ab3dc change name of enforce_exit_masks method for consistency (josephdviviano)
- 6ecb60d added default log reward clipping. Also, reward() is by default not i… (josephdviviano)
- b77ccad black formatting (josephdviviano)
- 4b8cf79 updated scripts to be consistent with new mask and reward defintion p… (josephdviviano)
- a357475 log reward clipping is now -100 (much smaller) (josephdviviano)
- d3a72bb I'm actually trying to trigger actions... (josephdviviano)
- 4f6e965 added flake8 back into deps (josephdviviano)
- e2c74b6 missing quote (josephdviviano)
- d3bcdf7 letting notebooks back in (josephdviviano)
- 382f396 added smiley FM & TB tutorial (josephdviviano)
- 4a2861a bugfix: unrelated, but logz was not being learned during the task any… (josephdviviano)
- 77abec6 black and autoflake reformatting (saleml)
- 0400d04 specify flake8 and black versions (saleml)
- 03f6ca6 ignore one error from flake8 and revert version selection (saleml)
- ad12af3 add wandb for pytest ci, and add quotations around python versions (saleml)
- e0d750b remove explicit wandb, and use pip install.[all] (saleml)
- 8678f9c line environment v1 done (josephdviviano)
- b1f1040 Merge branch 'easier_environment_definition' of github.com:saleml/tor… (josephdviviano)
- f4f69d6 smiley tweaks (josephdviviano)
- 0735ceb isort (josephdviviano)
- 4161dba typo fixed (josephdviviano)
- 4d8a145 documentation nits and using convienience builtins (josephdviviano)
- 459203c using set_nonexit_action_masks properly, with a comment for clarity (josephdviviano)
- fb7e819 Improved method name and documentaiton (josephdviviano)
- 226185f changed test logic (will run on github instead of this terrible compu… (josephdviviano)
- 84bb169 new targets (josephdviviano)
.gitignore:

@@ -182,6 +182,5 @@ wandb/
 scripts.py

-*.ipynb
 models/
 *.DS_Store
 *.DS_Store
pyproject.toml:

@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "torchgfn"
 packages = [{include = "gfn", from = "src"}]
-version = "1.1"
+version = "1.1.1"
 description = "A torch implementation of GFlowNets"
 authors = ["Salem Lahou <[email protected]>", "Joseph Viviano <[email protected]>", "Victor Schmidt <[email protected]>"]
 license = "MIT"
@@ -26,16 +26,17 @@ torch = ">=1.9.0"
 torchtyping = ">=0.1.4"

 # dev dependencies.
-black = { version = "22.3.0", optional = true }
+black = { version = "*", optional = true }
+flake8 = { version = "*", optional = true }
 gitmopy = { version = "*", optional = true }
 myst-parser = { version = "*", optional = true }
 pre-commit = { version = "*", optional = true }
 pytest = { version = "*", optional = true }
 renku-sphinx-theme = { version = "*", optional = true }
 sphinx = { version = "*", optional = true }
-sphinx_rtd_theme = { version = "*", optional = true }
 sphinx-autoapi = { version = "*", optional = true }
 sphinx-math-dollar = { version = "*", optional = true }
+sphinx_rtd_theme = { version = "*", optional = true }
 tox = { version = "*", optional = true }
@@ -52,30 +53,32 @@ dev = [
 "pre-commit",
 "pytest",
 "renku-sphinx-theme",
-"sphinx",
-"sphinx_rtd_theme",
 "sphinx-autoapi",
 "sphinx-math-dollar",
+"sphinx_rtd_theme",
-"tox"
+"sphinx",
+"tox",
+"flake8",
 ]

 scripts = ["tqdm", "wandb", "scikit-learn", "scipy"]

 all = [
+"black",
+"flake8",
+"myst-parser",
 "pre-commit",
 "pytest",
 "renku-sphinx-theme",
+"scikit-learn",
+"scipy",
 "sphinx_rtd_theme",
 "sphinx-autoapi",
 "sphinx-math-dollar",
 "sphinx",
 "tox",
-"black",
-"myst-parser",
 "tqdm",
 "wandb",
-"scikit-learn",
-"scipy"
 ]

 [project.urls]
Environment base classes (Env and DiscreteEnv):

@@ -23,16 +23,21 @@ def __init__(
         sf: Optional[TT["state_shape", torch.float]] = None,
         device_str: Optional[str] = None,
         preprocessor: Optional[Preprocessor] = None,
+        log_reward_clip: Optional[float] = -100.0,
     ):
         """Initializes an environment.

         Args:
-            s0: Representation of the initial state. All individual states would be of the same shape.
-            sf (optional): Representation of the final state. Only used for a human readable representation of
-                the states or trajectories.
-            device_str (Optional[str], optional): 'cpu' or 'cuda'. Defaults to None, in which case the device is inferred from s0.
-            preprocessor (Optional[Preprocessor], optional): a Preprocessor object that converts raw states to a tensor that can be fed
-                into a neural network. Defaults to None, in which case the IdentityPreprocessor is used.
+            s0: Representation of the initial state. All individual states would be of
+                the same shape.
+            sf: Representation of the final state. Only used for a human
+                readable representation of the states or trajectories.
+            device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is
+                inferred from s0.
+            preprocessor: a Preprocessor object that converts raw states to a tensor
+                that can be fed into a neural network. Defaults to None, in which case
+                the IdentityPreprocessor is used.
+            log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
         """
         self.device = torch.device(device_str) if device_str is not None else s0.device
@@ -53,6 +58,7 @@ def __init__(
         self.preprocessor = preprocessor
         self.is_discrete = False
+        self.log_reward_clip = log_reward_clip

     @abstractmethod
     def make_States_class(self) -> type[States]:
@@ -184,12 +190,15 @@ def backward_step(
         return new_states

     def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
-        """Either this or log_reward needs to be implemented."""
-        return torch.exp(self.log_reward(final_states))
+        """The environment's reward given a state.
+
+        This or log_reward must be implemented.
+        """
+        raise NotImplementedError("Reward function is not implemented.")

     def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
-        """Either this or reward needs to be implemented."""
-        raise NotImplementedError("log_reward function not implemented")
+        """Calculates the log reward (clipping small rewards)."""
+        return torch.log(self.reward(final_states)).clip(self.log_reward_clip)

     @property
     def log_partition(self) -> float:
@@ -203,8 +212,9 @@ class DiscreteEnv(Env, ABC):
     """
     Base class for discrete environments, where actions are represented by a number in
     {0, ..., n_actions - 1}, the last one being the exit action.
-    `DiscreteEnv` allow specifying the validity of actions (forward and backward), via mask tensors, that
-    are directly attached to `States` objects.
+
+    `DiscreteEnv` allows for specifying the validity of actions (forward and backward),
+    via mask tensors, that are directly attached to `States` objects.
     """

     def __init__(
@@ -214,16 +224,22 @@ def __init__(
         sf: Optional[TT["state_shape", torch.float]] = None,
         device_str: Optional[str] = None,
         preprocessor: Optional[Preprocessor] = None,
+        log_reward_clip: Optional[float] = -100.0,
     ):
         """Initializes a discrete environment.

         Args:
             n_actions: The number of actions in the environment.
+            s0: The initial state tensor (shared among all trajectories).
+            sf: The final state tensor (shared among all trajectories).
+            device_str: String representation of a torch.device.
+            preprocessor: An optional preprocessor for intermediate states.
+            log_reward_clip: Used to clip small rewards (in particular, log(0) rewards).
         """
         self.n_actions = n_actions
-        super().__init__(s0, sf, device_str, preprocessor)
+        super().__init__(s0, sf, device_str, preprocessor, log_reward_clip)
         self.is_discrete = True
+        self.log_reward_clip = log_reward_clip

     def make_Actions_class(self) -> type[Actions]:
         env = self

Review comment on `self.log_reward_clip = log_reward_clip` in `DiscreteEnv.__init__`: "unnecessary" (the parent class already sets this attribute).
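The diff changes the defaults so that `reward` raises `NotImplementedError` while `log_reward` derives the log reward from `reward` and clips very small values (in particular `log(0) = -inf`) to `log_reward_clip = -100.0`. The wiring can be sketched independently of the library; `MiniEnv` and `MyEnv` below are hypothetical stand-ins, with only the clipping logic taken from the diff:

```python
import torch

class MiniEnv:
    """Hypothetical stand-in for the Env base class above."""

    def __init__(self, log_reward_clip: float = -100.0):
        self.log_reward_clip = log_reward_clip

    def reward(self, final_states: torch.Tensor) -> torch.Tensor:
        # Subclasses must override this (or log_reward).
        raise NotImplementedError("Reward function is not implemented.")

    def log_reward(self, final_states: torch.Tensor) -> torch.Tensor:
        # Default: log of the reward, with log(0) = -inf clipped to -100.
        return torch.log(self.reward(final_states)).clip(self.log_reward_clip)

class MyEnv(MiniEnv):
    def reward(self, final_states: torch.Tensor) -> torch.Tensor:
        # Toy reward: 1.0 for positive states, 0.0 otherwise.
        return (final_states > 0).float()

env = MyEnv()
# log(1) -> 0.0; log(0) -> -inf, clipped to -100.0
print(env.log_reward(torch.tensor([1.0, -1.0])))
```

Subclasses that override `log_reward` directly bypass the clipping, which is why the default clip only guards the `reward`-based path.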
Neural network modules (BoxStateFlowModule):

@@ -3,6 +3,7 @@

 import numpy as np
 import torch
+import torch.nn as nn
 from torch.distributions import Beta, Categorical, Distribution, MixtureSameFamily
 from torchtyping import TensorType as TT
@@ -600,8 +601,8 @@ class BoxStateFlowModule(NeuralNet):
     """A deep neural network for the state flow function."""

     def __init__(self, logZ_value: torch.Tensor, **kwargs):
-        self.logZ_value = logZ_value
         super().__init__(**kwargs)
+        self.logZ_value = nn.Parameter(logZ_value)

     def forward(
         self, preprocessed_states: TT["batch_shape", "input_dim", float]

Review comments:
- The only place …
- Naive pytorch question: why do we need …
Review comment: "good idea!"