From 92f4d8328a86eee620df6b14fd816f9a1aeb8b0b Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 19 Mar 2024 14:59:06 -0400 Subject: [PATCH 01/73] Use a partial instantiation of the environment and use it to create multiple instances of the env in the agent instead of copying it --- gflownet/gflownet.py | 13 +++++++------ main.py | 9 +++++++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 85b8fa28e..dae87edd6 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -37,7 +37,7 @@ class GFlowNetAgent: def __init__( self, - env, + env_maker, seed, device, float_precision, @@ -66,7 +66,8 @@ def __init__( # Float precision self.float = set_float_precision(float_precision) # Environment - self.env = env + self.env_maker = env_maker + self.env = self.env_maker() # Continuous environments self.continuous = hasattr(self.env, "continuous") and self.env.continuous if self.continuous and optimizer.loss in ["flowmatch", "flowmatching"]: @@ -435,7 +436,7 @@ def sample_batch( # ON-POLICY FORWARD trajectories t0_forward = time.time() - envs = [self.env.copy().reset(idx) for idx in range(n_forward)] + envs = [self.env_maker().set_id(idx) for idx in range(n_forward)] batch_forward = Batch(env=self.env, device=self.device, float_type=self.float) while envs: # Sample actions @@ -457,7 +458,7 @@ def sample_batch( # TRAIN BACKWARD trajectories t0_train = time.time() - envs = [self.env.copy().reset(idx) for idx in range(n_train)] + envs = [self.env_maker().set_id(idx) for idx in range(n_train)] batch_train = Batch(env=self.env, device=self.device, float_type=self.float) if n_train > 0 and self.buffer.train_pkl is not None: with open(self.buffer.train_pkl, "rb") as f: @@ -494,7 +495,7 @@ def sample_batch( with open(self.buffer.replay_pkl, "rb") as f: dict_replay = pickle.load(f) n_replay = min(n_replay, len(dict_replay["x"])) - envs = [self.env.copy().reset(idx) for idx in range(n_replay)] + envs = [self.env_maker().set_id(idx) for idx in range(n_replay)] x_replay = self.buffer.select( dict_replay, n_replay, self.replay_sampling, self.rng ) @@ -915,7 +916,7 @@ def estimate_logprobs_data( for state_idx in range(init_batch, end_batch): for traj_idx in range(n_trajectories): idx = int(mult_indices * state_idx + traj_idx) - env = self.env.copy().reset(idx) + env = self.env_maker().set_id(idx) env.set_state(states_term[state_idx], done=True) envs.append(env) # Sample trajectories diff --git a/main.py b/main.py index 25171afc6..506752c66 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ """ Runnable script with hydra capabilities """ + import os import pickle import random @@ -37,12 +38,16 @@ def main(config): float_precision=config.float_precision, ) # The proxy is passed to env and used for computing rewards - env = hydra.utils.instantiate( + # Using Hydra's partial instantiation, see: + # https://hydra.cc/docs/advanced/instantiate_objects/overview/#partial-instantiation + env_maker = hydra.utils.instantiate( config.env, proxy=proxy, device=config.device, float_precision=config.float_precision, + _partial_=True, ) + env = env_maker() # The policy is used to model the probability of a forward/backward action forward_config = parse_policy_config(config, kind="forward") backward_config = parse_policy_config(config, kind="backward") @@ -76,7 +81,7 @@ def main(config): config.gflownet, device=config.device, float_precision=config.float_precision, - env=env, + env_maker=env_maker, forward_policy=forward_policy, 
backward_policy=backward_policy, state_flow=state_flow, From 9d9a3d7ac651f4896062800070f64bdf1f666f76 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 19 Mar 2024 15:26:50 -0400 Subject: [PATCH 02/73] Adapt common tests to new GFlowNetAgent env_maker parameter --- tests/gflownet/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index c499995dc..53d513de5 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -477,7 +477,7 @@ def test__gflownet_minimal_runs(self, n_repeat=1): config.gflownet, device=config.device, float_precision=config.float_precision, - env=self.env, + env_maker=self.env.__class__, forward_policy=forward_policy, backward_policy=backward_policy, buffer=config.env.buffer, From 9349497a874f1a9c695aa9649cd6c120e4167c3a Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Fri, 29 Mar 2024 18:15:14 -0400 Subject: [PATCH 03/73] WIP: progress in implementing reward-handling functionality in base proxy. --- gflownet/proxy/base.py | 86 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 2 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 045ed7841..663d960fa 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -3,9 +3,11 @@ """ from abc import ABC, abstractmethod +from typing import Callable, List, Union import numpy as np import numpy.typing as npt +from torchtyping import TensorType from gflownet.utils.common import set_device, set_float_precision @@ -15,7 +17,19 @@ class Proxy(ABC): Generic proxy class """ - def __init__(self, device, float_precision, higher_is_better=False, **kwargs): + def __init__( + self, + device, + float_precision, + reward_function: Union[Callable, str] = "identity", + reward_function_kwargs: dict = None, + higher_is_better=False, + **kwargs, + ): + # Proxy to reward function + self.reward_function = self._get_reward_function( + reward_function, reward_function_kwargs + ) # Device self.device = set_device(device) # Float precision @@ -27,7 +41,7 @@ def setup(self, env=None): pass @abstractmethod - def __call__(self, states: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: + def __call__(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType: """ Implement this function to call the get_reward method of the appropriate Proxy Class (EI, UCB, Proxy, Oracle etc). @@ -38,6 +52,74 @@ def __call__(self, states: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: """ pass + def rewards(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType: + """ + Computes the rewards of a batch of states. + + The rewards are computing by first calling the proxy function, then + transforming the proxy values according to the reward function. + + Parameters + ---------- + states : tensor or list or array + A batch of states in proxy format. + + Returns + ------- + tensor + The reward of all elements in the batch. + """ + return self.proxy2reward(self(states)) + + # TODO: consider adding option to clip values + # TODO: check that rewards are non-negative + def proxy2reward(proxy_values: TensorType) -> TensorType: + """ + Transform a tensor of proxy values into rewards. + + Parameters + ---------- + proxy_values : tensor + The proxy values corresponding to a batch of states. + + Returns + ------- + tensor + The reward of all elements in the batch. 
+ """ + return self.reward_func(proxy_values) + + def _get_reward_function(reward_function: Union[Callable, str], **kwargs): + r""" + Returns a callable corresponding to the function that transforms proxy values + into rewards. + + If reward_function is callable, it is returned as is. If it is a string, it + must correspond to one of the following options: + + - power: the rewards are the proxy values to the power of beta. See: + :py:meth:`~gflownet.proxy.base._power()` + - boltzmann: the rewards are the negative exponential of the proxy values. + See: :py:meth:`~gflownet.proxy.base._boltzmann()` + - shift: the rewards are the proxy values shifted by beta. + See: :py:meth:`~gflownet.proxy.base._boltzmann()` + + Parameters + ---------- + reward_function : callable or str + A callable or a string corresponding to one of the pre-defined functions. + """ + # If reward_function is callable, return it + if isinstance(reward_function, Callable): + return reward_function + + # Otherwise it must be a string + if not isinstance(reward_function, str): + raise AssertionError( + "reward_func must be a callable or a string; " + f"got {type(reward_function)} instead." + ) + def infer_on_train_set(self): """ Implement this method in specific proxies. From d0db503d380d93a53a9841feff9218b448c357c1 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 10:19:09 -0400 Subject: [PATCH 04/73] Basic test for base proxy and fixes. --- gflownet/proxy/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 663d960fa..c1dae4870 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -22,13 +22,13 @@ def __init__( device, float_precision, reward_function: Union[Callable, str] = "identity", - reward_function_kwargs: dict = None, + reward_function_kwargs: dict = {}, higher_is_better=False, **kwargs, ): # Proxy to reward function self.reward_function = self._get_reward_function( - reward_function, reward_function_kwargs + reward_function, **reward_function_kwargs ) # Device self.device = set_device(device) @@ -89,7 +89,7 @@ def proxy2reward(proxy_values: TensorType) -> TensorType: """ return self.reward_func(proxy_values) - def _get_reward_function(reward_function: Union[Callable, str], **kwargs): + def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): r""" Returns a callable corresponding to the function that transforms proxy values into rewards. 
From c4b6e5c444098e16ea94a9cb1b377caa1a6fad60 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 10:21:07 -0400 Subject: [PATCH 05/73] Remove higher_is_better parameter --- gflownet/proxy/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index c1dae4870..3eba0b38c 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -23,7 +23,6 @@ def __init__( float_precision, reward_function: Union[Callable, str] = "identity", reward_function_kwargs: dict = {}, - higher_is_better=False, **kwargs, ): # Proxy to reward function @@ -34,8 +33,6 @@ def __init__( self.device = set_device(device) # Float precision self.float = set_float_precision(float_precision) - # Reward2Proxy multiplicative factor (1 or -1) - self.higher_is_better = higher_is_better def setup(self, env=None): pass From f1ca7ef40045d5cd88f4fa9e2e5b2016b306f7bd Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 10:22:53 -0400 Subject: [PATCH 06/73] Add basic log_rewards() --- gflownet/proxy/base.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 3eba0b38c..662dc1e5d 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -53,7 +53,7 @@ def rewards(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType: """ Computes the rewards of a batch of states. - The rewards are computing by first calling the proxy function, then + The rewards are computed by first calling the proxy function, then transforming the proxy values according to the reward function. Parameters @@ -68,6 +68,26 @@ def rewards(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType: """ return self.proxy2reward(self(states)) + def log_rewards(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType: + """ + Computes the log(rewards) of a batch of states. + + The rewards are computed by first calling the proxy function, then + transforming the proxy values according to the reward function, then taking the + logarithm. + + Parameters + ---------- + states : tensor or list or array + A batch of states in proxy format. + + Returns + ------- + tensor + The log reward of all elements in the batch. 
+ """ + return torch.log(self.proxy2reward(self(states))) + # TODO: consider adding option to clip values # TODO: check that rewards are non-negative def proxy2reward(proxy_values: TensorType) -> TensorType: From c0be4fb5069a07113e8bda6a280b71976f1ddfb9 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 11:06:10 -0400 Subject: [PATCH 07/73] Implement power proxy2reward --- gflownet/proxy/base.py | 29 ++++++++++++++--- tests/gflownet/proxy/test_base.py | 52 +++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 tests/gflownet/proxy/test_base.py diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 662dc1e5d..74294f6fc 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -26,7 +26,8 @@ def __init__( **kwargs, ): # Proxy to reward function - self.reward_function = self._get_reward_function( + self.reward_function = reward_function + self._reward_function = self._get_reward_function( reward_function, **reward_function_kwargs ) # Device @@ -90,7 +91,7 @@ def log_rewards(self, states: Union[TensorType, List, npt.NDArray]) -> TensorTyp # TODO: consider adding option to clip values # TODO: check that rewards are non-negative - def proxy2reward(proxy_values: TensorType) -> TensorType: + def proxy2reward(self, proxy_values: TensorType) -> TensorType: """ Transform a tensor of proxy values into rewards. @@ -104,7 +105,7 @@ def proxy2reward(proxy_values: TensorType) -> TensorType: tensor The reward of all elements in the batch. """ - return self.reward_func(proxy_values) + return self._reward_function(proxy_values) def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): r""" @@ -133,10 +134,30 @@ def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): # Otherwise it must be a string if not isinstance(reward_function, str): raise AssertionError( - "reward_func must be a callable or a string; " + "reward_function must be a callable or a string; " f"got {type(reward_function)} instead." ) + if reward_function == "power": + return Proxy._power(**kwargs) + + @staticmethod + def _power(beta: float = 1.0) -> Callable: + """ + Returns a lambda expression where the input (proxy values) are raised to the + power of beta. + + Parameters + ---------- + beta : float + The exponent to which the proxy values are raised. + + Returns + ------- + A lambda expression proxy values raised to the power of beta. + """ + return lambda proxy_values: proxy_values**beta + def infer_on_train_set(self): """ Implement this method in specific proxies. 
diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py new file mode 100644 index 000000000..8819b8e6b --- /dev/null +++ b/tests/gflownet/proxy/test_base.py @@ -0,0 +1,52 @@ +import pytest +import torch + +from gflownet.proxy.base import Proxy +from gflownet.proxy.uniform import Uniform +from gflownet.utils.common import tfloat + + +@pytest.fixture() +def uniform(): + return Uniform(device="cpu", float_precision=32) + + +@pytest.fixture() +def uniform_power(beta): + return Uniform( + reward_function="power", + reward_function_kwargs={"beta": beta}, + device="cpu", + float_precision=32, + ) + + +@pytest.mark.parametrize("proxy, beta", [("uniform", None), ("uniform_power", 1)]) +def test__uniform_proxy_initializes_without_errors(proxy, beta, request): + proxy = request.getfixturevalue(proxy) + return proxy + + +@pytest.mark.parametrize( + "beta, proxy_values, rewards_exp", + [ + ( + 1, + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + ), + ( + 2, + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [10000, 100, 1, 0.25, 0.01, 0.0, 0.01, 0.25, 1, 100, 10000], + ), + ], +) +def test_reward_function_power__behaves_as_expected( + uniform_power, beta, proxy_values, rewards_exp +): + proxy = uniform_power + proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) + rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) + assert all(torch.isclose(proxy._reward_function(proxy_values), rewards_exp)) + assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp)) From fb011aeec8443f8c2e761267c4280ee03a2fb707 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 11:51:31 -0400 Subject: [PATCH 08/73] Implement exponential proxy2reward --- gflownet/proxy/base.py | 35 ++++++++++++--- tests/gflownet/proxy/test_base.py | 72 ++++++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 9 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 74294f6fc..040c771e4 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -7,6 +7,7 @@ import numpy as np import numpy.typing as npt +import torch from torchtyping import TensorType from gflownet.utils.common import set_device, set_float_precision @@ -115,12 +116,12 @@ def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): If reward_function is callable, it is returned as is. If it is a string, it must correspond to one of the following options: - - power: the rewards are the proxy values to the power of beta. See: + - pow(er): the rewards are the proxy values to the power of beta. See: :py:meth:`~gflownet.proxy.base._power()` - - boltzmann: the rewards are the negative exponential of the proxy values. - See: :py:meth:`~gflownet.proxy.base._boltzmann()` + - exp(onential) or boltzmann: the rewards are the negative exponential of + the proxy values. See: :py:meth:`~gflownet.proxy.base._exponential()` - shift: the rewards are the proxy values shifted by beta. - See: :py:meth:`~gflownet.proxy.base._boltzmann()` + See: :py:meth:`~gflownet.proxy.base._shift()` Parameters ---------- @@ -138,13 +139,16 @@ def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): f"got {type(reward_function)} instead." 
) - if reward_function == "power": + if reward_function.startswith("pow"): return Proxy._power(**kwargs) + if reward_function.startswith("exp") or reward_function == "boltzmann": + return Proxy._exponential(**kwargs) + @staticmethod def _power(beta: float = 1.0) -> Callable: """ - Returns a lambda expression where the input (proxy values) are raised to the + Returns a lambda expression where the inputs (proxy values) are raised to the power of beta. Parameters @@ -154,10 +158,27 @@ def _power(beta: float = 1.0) -> Callable: Returns ------- - A lambda expression proxy values raised to the power of beta. + A lambda expression where the proxy values raised to the power of beta. """ return lambda proxy_values: proxy_values**beta + @staticmethod + def _exponential(beta: float = 1.0) -> Callable: + """ + Returns a lambda expression where the output is the exponential of the product + of the input (proxy) values and beta. + + Parameters + ---------- + beta : float + The factor by which the proxy values are multiplied. + + Returns + ------- + A lambda expression that takes the exponential of the proxy values * beta. + """ + return lambda proxy_values: torch.exp(proxy_values**beta) + def infer_on_train_set(self): """ Implement this method in specific proxies. diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py index 8819b8e6b..7badc3dcd 100644 --- a/tests/gflownet/proxy/test_base.py +++ b/tests/gflownet/proxy/test_base.py @@ -21,10 +21,31 @@ def uniform_power(beta): ) -@pytest.mark.parametrize("proxy, beta", [("uniform", None), ("uniform_power", 1)]) +@pytest.fixture() +def uniform_exponential(beta): + return Uniform( + reward_function="exponential", + reward_function_kwargs={"beta": beta}, + device="cpu", + float_precision=32, + ) + + +@pytest.mark.parametrize( + "proxy, beta", + [ + ("uniform", None), + ("uniform_power", 1), + ("uniform_power", 2), + ("uniform_exponential", 1), + ("uniform_exponential", -1), + ("uniform_exponential", 2), + ("uniform_exponential", -2), + ], +) def test__uniform_proxy_initializes_without_errors(proxy, beta, request): proxy = request.getfixturevalue(proxy) - return proxy + assert True @pytest.mark.parametrize( @@ -50,3 +71,50 @@ def test_reward_function_power__behaves_as_expected( rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) assert all(torch.isclose(proxy._reward_function(proxy_values), rewards_exp)) assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp)) + + +@pytest.mark.parametrize( + "beta, proxy_values, rewards_exp", + [ + ( + 1.0, + [-10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10], + [ + 4.54e-05, + 3.6788e-01, + 6.0653e-01, + 9.0484e-01, + 1.0, + 1.1052, + 1.6487e00, + 2.7183, + 22026.4648, + ], + ), + ( + -1.0, + [-10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10], + [ + 22026.4648, + 2.7183, + 1.6487, + 1.1052, + 1.0, + 9.0484e-01, + 6.0653e-01, + 3.6788e-01, + 4.54e-05, + ], + ), + ], +) +def test_reward_function_exponential__behaves_as_expected( + uniform_exponential, beta, proxy_values, rewards_exp +): + proxy = uniform_exponential + proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) + rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) + assert all( + torch.isclose(proxy._reward_function(proxy_values), rewards_exp, atol=1e-4) + ) + assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp, atol=1e-4)) From 2c8ef545503d675ac1eefb53a1e0d1c32da188bd Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 
2024 12:16:27 -0400 Subject: [PATCH 09/73] Implement shift proxy2reward; fixes and doscstring --- gflownet/proxy/base.py | 38 +++++++++++++++++++++++++++--- tests/gflownet/proxy/test_base.py | 39 +++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 040c771e4..c3ad0e404 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -28,6 +28,7 @@ def __init__( ): # Proxy to reward function self.reward_function = reward_function + self.reward_function_kwargs = reward_function_kwargs self._reward_function = self._get_reward_function( reward_function, **reward_function_kwargs ) @@ -145,12 +146,19 @@ def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): if reward_function.startswith("exp") or reward_function == "boltzmann": return Proxy._exponential(**kwargs) + if reward_function.startswith("shift"): + return Proxy._shift(**kwargs) + @staticmethod def _power(beta: float = 1.0) -> Callable: - """ + r""" Returns a lambda expression where the inputs (proxy values) are raised to the power of beta. + $$ + R(x) = \varepsilon(x)^{\beta} + $$ + Parameters ---------- beta : float @@ -164,10 +172,14 @@ def _power(beta: float = 1.0) -> Callable: @staticmethod def _exponential(beta: float = 1.0) -> Callable: - """ + r""" Returns a lambda expression where the output is the exponential of the product of the input (proxy) values and beta. + $$ + R(x) = \exp{\beta\varepsilon(x)} + $$ + Parameters ---------- beta : float @@ -177,7 +189,27 @@ def _exponential(beta: float = 1.0) -> Callable: ------- A lambda expression that takes the exponential of the proxy values * beta. """ - return lambda proxy_values: torch.exp(proxy_values**beta) + return lambda proxy_values: torch.exp(proxy_values * beta) + + @staticmethod + def _shift(beta: float = 1.0) -> Callable: + r""" + Returns a lambda expression where the inputs (proxy values) are shifted by beta. + + $$ + R(x) = \varepsilon(x) + \beta + $$ + + Parameters + ---------- + beta : float + The factor by which the proxy values are shifted. + + Returns + ------- + A lambda expression that shifts the proxy values by beta. 
+ """ + return lambda proxy_values: proxy_values + beta def infer_on_train_set(self): """ diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py index 7badc3dcd..baa3a6434 100644 --- a/tests/gflownet/proxy/test_base.py +++ b/tests/gflownet/proxy/test_base.py @@ -31,6 +31,16 @@ def uniform_exponential(beta): ) +@pytest.fixture() +def uniform_shift(beta): + return Uniform( + reward_function="shift", + reward_function_kwargs={"beta": beta}, + device="cpu", + float_precision=32, + ) + + @pytest.mark.parametrize( "proxy, beta", [ @@ -39,8 +49,8 @@ def uniform_exponential(beta): ("uniform_power", 2), ("uniform_exponential", 1), ("uniform_exponential", -1), - ("uniform_exponential", 2), - ("uniform_exponential", -2), + ("uniform_shift", 5), + ("uniform_shift", -5), ], ) def test__uniform_proxy_initializes_without_errors(proxy, beta, request): @@ -118,3 +128,28 @@ def test_reward_function_exponential__behaves_as_expected( torch.isclose(proxy._reward_function(proxy_values), rewards_exp, atol=1e-4) ) assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp, atol=1e-4)) + + +@pytest.mark.parametrize( + "beta, proxy_values, rewards_exp", + [ + ( + 5, + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [-95, -5, 4, 4.5, 4.9, 5.0, 5.1, 5.5, 6, 15, 105], + ), + ( + -5, + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [-105, -15, -6, -5.5, -5.1, -5.0, -4.9, -4.5, -4, 5, 95], + ), + ], +) +def test_reward_function_shift__behaves_as_expected( + uniform_shift, beta, proxy_values, rewards_exp +): + proxy = uniform_shift + proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) + rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) + assert all(torch.isclose(proxy._reward_function(proxy_values), rewards_exp)) + assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp)) From 84409988beab13d6ebad34fb0c3b813aa6a9afb4 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 12:26:05 -0400 Subject: [PATCH 10/73] Implement product proxy2reward --- gflownet/proxy/base.py | 36 ++++++++++++++++++++++++++++-- tests/gflownet/proxy/test_base.py | 37 +++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index c3ad0e404..ef6d257a2 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -123,6 +123,8 @@ def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): the proxy values. See: :py:meth:`~gflownet.proxy.base._exponential()` - shift: the rewards are the proxy values shifted by beta. See: :py:meth:`~gflownet.proxy.base._shift()` + - prod(uct): the rewards are the proxy values multiplied by beta. + See: :py:meth:`~gflownet.proxy.base._product()` Parameters ---------- @@ -143,12 +145,21 @@ def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): if reward_function.startswith("pow"): return Proxy._power(**kwargs) - if reward_function.startswith("exp") or reward_function == "boltzmann": + elif reward_function.startswith("exp") or reward_function == "boltzmann": return Proxy._exponential(**kwargs) - if reward_function.startswith("shift"): + elif reward_function == "shift": return Proxy._shift(**kwargs) + elif reward_function.startswith("prod"): + return Proxy._product(**kwargs) + + else: + raise ValueError( + "reward_function must be one of: pow(er), exp(onential), shift, " + f"prod(uct). Received {reward_function} instead." 
+ ) + @staticmethod def _power(beta: float = 1.0) -> Callable: r""" @@ -211,6 +222,27 @@ def _shift(beta: float = 1.0) -> Callable: """ return lambda proxy_values: proxy_values + beta + @staticmethod + def _product(beta: float = 1.0) -> Callable: + r""" + Returns a lambda expression where the inputs (proxy values) are multiplied by + beta. + + $$ + R(x) = \beta\varepsilon(x) + $$ + + Parameters + ---------- + beta : float + The factor by which the proxy values are multiplied. + + Returns + ------- + A lambda expression that multiplies the proxy values by beta. + """ + return lambda proxy_values: proxy_values * beta + def infer_on_train_set(self): """ Implement this method in specific proxies. diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py index baa3a6434..ab32c9ded 100644 --- a/tests/gflownet/proxy/test_base.py +++ b/tests/gflownet/proxy/test_base.py @@ -41,6 +41,16 @@ def uniform_shift(beta): ) +@pytest.fixture() +def uniform_product(beta): + return Uniform( + reward_function="product", + reward_function_kwargs={"beta": beta}, + device="cpu", + float_precision=32, + ) + + @pytest.mark.parametrize( "proxy, beta", [ @@ -51,6 +61,8 @@ def uniform_shift(beta): ("uniform_exponential", -1), ("uniform_shift", 5), ("uniform_shift", -5), + ("uniform_product", 2), + ("uniform_product", -2), ], ) def test__uniform_proxy_initializes_without_errors(proxy, beta, request): @@ -153,3 +165,28 @@ def test_reward_function_shift__behaves_as_expected( rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) assert all(torch.isclose(proxy._reward_function(proxy_values), rewards_exp)) assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp)) + + +@pytest.mark.parametrize( + "beta, proxy_values, rewards_exp", + [ + ( + 2, + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [-200, -20, -2, -1.0, -0.2, 0.0, 0.2, 1.0, 2, 20, 200], + ), + ( + -2, + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [200, 20, 2, 1.0, 0.2, 0.0, -0.2, -1.0, -2, -20, -200], + ), + ], +) +def test_reward_function_product__behaves_as_expected( + uniform_product, beta, proxy_values, rewards_exp +): + proxy = uniform_product + proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) + rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) + assert all(torch.isclose(proxy._reward_function(proxy_values), rewards_exp)) + assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp)) From f23deadc4826007933e28131c91933a66bd520ea Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 12:29:21 -0400 Subject: [PATCH 11/73] Add identity proxy2reward --- gflownet/proxy/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index ef6d257a2..a3a147da8 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -117,6 +117,7 @@ def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): If reward_function is callable, it is returned as is. If it is a string, it must correspond to one of the following options: + - identity: the rewards are directly the proxy values. - pow(er): the rewards are the proxy values to the power of beta. See: :py:meth:`~gflownet.proxy.base._power()` - exp(onential) or boltzmann: the rewards are the negative exponential of @@ -142,7 +143,10 @@ def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): f"got {type(reward_function)} instead." 
) - if reward_function.startswith("pow"): + if reward_function.startswith("identity"): + return lambda proxy_values: proxy_values + + elif reward_function.startswith("pow"): return Proxy._power(**kwargs) elif reward_function.startswith("exp") or reward_function == "boltzmann": From 615e08cc08e26f3938e34313f90aaba0350db95a Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 13:24:24 -0400 Subject: [PATCH 12/73] WIP: handle log rewards --- gflownet/proxy/base.py | 51 +++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index a3a147da8..3a12c9703 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -3,7 +3,7 @@ """ from abc import ABC, abstractmethod -from typing import Callable, List, Union +from typing import Callable, List, Tuple, Union import numpy as np import numpy.typing as npt @@ -29,7 +29,7 @@ def __init__( # Proxy to reward function self.reward_function = reward_function self.reward_function_kwargs = reward_function_kwargs - self._reward_function = self._get_reward_function( + self._reward_function, _logreward_function = self._get_reward_functions( reward_function, **reward_function_kwargs ) # Device @@ -109,10 +109,29 @@ def proxy2reward(self, proxy_values: TensorType) -> TensorType: """ return self._reward_function(proxy_values) - def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): + # TODO: consider adding option to clip values + def proxy2logreward(self, proxy_values: TensorType) -> TensorType: + """ + Transform a tensor of proxy values into log-rewards. + + Parameters + ---------- + proxy_values : tensor + The proxy values corresponding to a batch of states. + + Returns + ------- + tensor + The log-reward of all elements in the batch. + """ + return self._logreward_function(proxy_values) + + def _get_reward_functions( + self, reward_function: Union[Callable, str], **kwargs + ) -> Tuple[Callable, Callable]: r""" - Returns a callable corresponding to the function that transforms proxy values - into rewards. + Returns a tuple of callable corresponding to the function that transforms proxy + values into rewards and log-rewards. If reward_function is callable, it is returned as is. If it is a string, it must correspond to one of the following options: @@ -131,10 +150,17 @@ def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): ---------- reward_function : callable or str A callable or a string corresponding to one of the pre-defined functions. + + Returns + ------- + Callable + The function the transforms proxy values into rewards. + Callable + The function the transforms proxy values into log-rewards. 
""" # If reward_function is callable, return it if isinstance(reward_function, Callable): - return reward_function + return reward_function, lambda y: torch.log(reward_function) # Otherwise it must be a string if not isinstance(reward_function, str): @@ -144,19 +170,22 @@ def _get_reward_function(self, reward_function: Union[Callable, str], **kwargs): ) if reward_function.startswith("identity"): - return lambda proxy_values: proxy_values + return ( + lambda y: x, + lambda y: torch.log(x), + ) elif reward_function.startswith("pow"): - return Proxy._power(**kwargs) + return Proxy._power(**kwargs), Proxy._power(**kwargs) elif reward_function.startswith("exp") or reward_function == "boltzmann": - return Proxy._exponential(**kwargs) + return Proxy._exponential(**kwargs), Proxy._product(**kwargs) elif reward_function == "shift": - return Proxy._shift(**kwargs) + return Proxy._shift(**kwargs), Proxy._shift(**kwargs) elif reward_function.startswith("prod"): - return Proxy._product(**kwargs) + return Proxy._product(**kwargs), Proxy._product(**kwargs) else: raise ValueError( From e239d66c8a8cf628eb0eae61bb6759736971db59 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 14:45:54 -0400 Subject: [PATCH 13/73] Handle log rewards properly --- gflownet/proxy/base.py | 21 +++- tests/gflownet/proxy/test_base.py | 201 +++++++++++++++++++++++++++--- 2 files changed, 200 insertions(+), 22 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 3a12c9703..3a46b8470 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -29,7 +29,7 @@ def __init__( # Proxy to reward function self.reward_function = reward_function self.reward_function_kwargs = reward_function_kwargs - self._reward_function, _logreward_function = self._get_reward_functions( + self._reward_function, self._logreward_function = self._get_reward_functions( reward_function, **reward_function_kwargs ) # Device @@ -171,21 +171,30 @@ def _get_reward_functions( if reward_function.startswith("identity"): return ( - lambda y: x, - lambda y: torch.log(x), + lambda x: x, + lambda x: torch.log(x), ) elif reward_function.startswith("pow"): - return Proxy._power(**kwargs), Proxy._power(**kwargs) + return ( + Proxy._power(**kwargs), + lambda x: torch.log(Proxy._power(**kwargs)(x)), + ) elif reward_function.startswith("exp") or reward_function == "boltzmann": return Proxy._exponential(**kwargs), Proxy._product(**kwargs) elif reward_function == "shift": - return Proxy._shift(**kwargs), Proxy._shift(**kwargs) + return ( + Proxy._shift(**kwargs), + lambda x: torch.log(Proxy._shift(**kwargs)(x)), + ) elif reward_function.startswith("prod"): - return Proxy._product(**kwargs), Proxy._product(**kwargs) + return ( + Proxy._product(**kwargs), + lambda x: torch.log(Proxy._product(**kwargs)(x)), + ) else: raise ValueError( diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py index ab32c9ded..d7be012e9 100644 --- a/tests/gflownet/proxy/test_base.py +++ b/tests/gflownet/proxy/test_base.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch @@ -11,6 +12,16 @@ def uniform(): return Uniform(device="cpu", float_precision=32) +@pytest.fixture() +def uniform_identity(beta): + return Uniform( + reward_function="power", + reward_function_kwargs={"beta": beta}, + device="cpu", + float_precision=32, + ) + + @pytest.fixture() def uniform_power(beta): return Uniform( @@ -70,33 +81,118 @@ def test__uniform_proxy_initializes_without_errors(proxy, beta, request): assert True +def 
check_proxy2reward(rewards_computed, rewards_expected, atol=1e-3): + comp_nan = rewards_computed.isnan() + exp_nan = rewards_expected.isnan() + notnan_allclose = torch.all( + torch.isclose( + rewards_computed[~comp_nan], rewards_expected[~exp_nan], atol=atol + ) + ) + nan_equal = torch.equal(comp_nan, exp_nan) + return notnan_allclose, nan_equal + + @pytest.mark.parametrize( - "beta, proxy_values, rewards_exp", + "beta, proxy_values, rewards_exp, logrewards_exp", [ ( 1, [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [ + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + -np.inf, + -2.3025, + -0.6931, + 0.0, + 2.3025, + 4.6052, + ], + ), + ], +) +def test_reward_function_identity__behaves_as_expected( + uniform_identity, beta, proxy_values, rewards_exp, logrewards_exp +): + proxy = uniform_identity + proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) + # Rewards + rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) + assert all(check_proxy2reward(proxy._reward_function(proxy_values), rewards_exp)) + assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) + # Log Rewards + logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) + assert all( + check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) + ) + assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) + + +@pytest.mark.parametrize( + "beta, proxy_values, rewards_exp, logrewards_exp", + [ + ( + 1, + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [ + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + -np.inf, + -2.3025, + -0.6931, + 0.0, + 2.3025, + 4.6052, + ], ), ( 2, [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], [10000, 100, 1, 0.25, 0.01, 0.0, 0.01, 0.25, 1, 100, 10000], + [ + 9.2103, + 4.6052, + 0.0, + -1.3863, + -4.6052, + -np.inf, + -4.6052, + -1.3863, + 0.0, + 4.6052, + 9.2103, + ], ), ], ) def test_reward_function_power__behaves_as_expected( - uniform_power, beta, proxy_values, rewards_exp + uniform_power, beta, proxy_values, rewards_exp, logrewards_exp ): proxy = uniform_power proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) + # Rewards rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) - assert all(torch.isclose(proxy._reward_function(proxy_values), rewards_exp)) - assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp)) + assert all(check_proxy2reward(proxy._reward_function(proxy_values), rewards_exp)) + assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) + # Log Rewards + logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) + assert all( + check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) + ) + assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) @pytest.mark.parametrize( - "beta, proxy_values, rewards_exp", + "beta, proxy_values, rewards_exp, logrewards_exp", [ ( 1.0, @@ -112,6 +208,7 @@ def test_reward_function_power__behaves_as_expected( 2.7183, 22026.4648, ], + [-10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10], ), ( -1.0, @@ -127,66 +224,138 @@ def test_reward_function_power__behaves_as_expected( 3.6788e-01, 4.54e-05, ], + [10, 1, 0.5, 0.1, 0.0, -0.1, -0.5, -1, -10], ), ], ) def test_reward_function_exponential__behaves_as_expected( - 
uniform_exponential, beta, proxy_values, rewards_exp + uniform_exponential, beta, proxy_values, rewards_exp, logrewards_exp ): proxy = uniform_exponential proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) + # Rewards rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) + assert all(check_proxy2reward(proxy._reward_function(proxy_values), rewards_exp)) + assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) + # Log Rewards + logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) assert all( - torch.isclose(proxy._reward_function(proxy_values), rewards_exp, atol=1e-4) + check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) ) - assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp, atol=1e-4)) + assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) @pytest.mark.parametrize( - "beta, proxy_values, rewards_exp", + "beta, proxy_values, rewards_exp, logrewards_exp", [ ( 5, [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], [-95, -5, 4, 4.5, 4.9, 5.0, 5.1, 5.5, 6, 15, 105], + [ + np.nan, + np.nan, + 1.3863, + 1.5041, + 1.5892, + 1.6094, + 1.6292, + 1.7047, + 1.7918, + 2.7081, + 4.6540, + ], ), ( -5, [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], [-105, -15, -6, -5.5, -5.1, -5.0, -4.9, -4.5, -4, 5, 95], + [ + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + 1.6094, + 4.5539, + ], ), ], ) def test_reward_function_shift__behaves_as_expected( - uniform_shift, beta, proxy_values, rewards_exp + uniform_shift, beta, proxy_values, rewards_exp, logrewards_exp ): proxy = uniform_shift proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) + # Rewards rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) - assert all(torch.isclose(proxy._reward_function(proxy_values), rewards_exp)) - assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp)) + assert all(check_proxy2reward(proxy._reward_function(proxy_values), rewards_exp)) + assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) + # Log Rewards + logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) + assert all( + check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) + ) + assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) @pytest.mark.parametrize( - "beta, proxy_values, rewards_exp", + "beta, proxy_values, rewards_exp, logrewards_exp", [ ( 2, [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], [-200, -20, -2, -1.0, -0.2, 0.0, 0.2, 1.0, 2, 20, 200], + [ + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + -np.inf, + -1.6094, + 0.0, + 0.6931, + 2.9957, + 5.2983, + ], ), ( -2, [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], [200, 20, 2, 1.0, 0.2, 0.0, -0.2, -1.0, -2, -20, -200], + [ + 5.2983, + 2.9957, + 0.6931, + 0.0, + -1.6094, + -np.inf, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + ], ), ], ) def test_reward_function_product__behaves_as_expected( - uniform_product, beta, proxy_values, rewards_exp + uniform_product, beta, proxy_values, rewards_exp, logrewards_exp ): proxy = uniform_product proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) + # Rewards rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) - assert all(torch.isclose(proxy._reward_function(proxy_values), rewards_exp)) - 
assert all(torch.isclose(proxy.proxy2reward(proxy_values), rewards_exp)) + assert all(check_proxy2reward(proxy._reward_function(proxy_values), rewards_exp)) + assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) + # Log Rewards + logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) + assert all( + check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) + ) + assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) From fbff130b979ef4aa6585463ed5d8efaf9e565b9a Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 15:12:17 -0400 Subject: [PATCH 14/73] Implement functionality for passing callable instead of string --- gflownet/proxy/base.py | 24 +++-- tests/gflownet/proxy/test_base.py | 150 +++++++++++++++++++++++++----- 2 files changed, 145 insertions(+), 29 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 3a46b8470..20cf33046 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -3,7 +3,7 @@ """ from abc import ABC, abstractmethod -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -22,15 +22,17 @@ def __init__( self, device, float_precision, - reward_function: Union[Callable, str] = "identity", - reward_function_kwargs: dict = {}, + reward_function: Optional[Union[Callable, str]] = "identity", + logreward_function: Optional[Callable] = None, + reward_function_kwargs: Optional[dict] = {}, **kwargs, ): # Proxy to reward function self.reward_function = reward_function + self.logreward_function = logreward_function self.reward_function_kwargs = reward_function_kwargs self._reward_function, self._logreward_function = self._get_reward_functions( - reward_function, **reward_function_kwargs + reward_function, logreward_function, **reward_function_kwargs ) # Device self.device = set_device(device) @@ -127,7 +129,10 @@ def proxy2logreward(self, proxy_values: TensorType) -> TensorType: return self._logreward_function(proxy_values) def _get_reward_functions( - self, reward_function: Union[Callable, str], **kwargs + self, + reward_function: Union[Callable, str], + logreward_function: Callable = None, + **kwargs, ) -> Tuple[Callable, Callable]: r""" Returns a tuple of callable corresponding to the function that transforms proxy @@ -150,6 +155,10 @@ def _get_reward_functions( ---------- reward_function : callable or str A callable or a string corresponding to one of the pre-defined functions. + reward_function : callable + A callable of the logreward function, meant to be used to compute the log + rewards in a more numerically stable way. None by default, in which case + the log of the reward function will be taken. 
Returns ------- @@ -160,7 +169,10 @@ def _get_reward_functions( """ # If reward_function is callable, return it if isinstance(reward_function, Callable): - return reward_function, lambda y: torch.log(reward_function) + if isinstance(logreward_function, Callable): + return reward_function, logreward_function + else: + return reward_function, lambda x: torch.log(reward_function(x)) # Otherwise it must be a string if not isinstance(reward_function, str): diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py index d7be012e9..4f033c28a 100644 --- a/tests/gflownet/proxy/test_base.py +++ b/tests/gflownet/proxy/test_base.py @@ -13,7 +13,7 @@ def uniform(): @pytest.fixture() -def uniform_identity(beta): +def proxy_identity(beta): return Uniform( reward_function="power", reward_function_kwargs={"beta": beta}, @@ -23,7 +23,7 @@ def uniform_identity(beta): @pytest.fixture() -def uniform_power(beta): +def proxy_power(beta): return Uniform( reward_function="power", reward_function_kwargs={"beta": beta}, @@ -33,7 +33,7 @@ def uniform_power(beta): @pytest.fixture() -def uniform_exponential(beta): +def proxy_exponential(beta): return Uniform( reward_function="exponential", reward_function_kwargs={"beta": beta}, @@ -43,7 +43,7 @@ def uniform_exponential(beta): @pytest.fixture() -def uniform_shift(beta): +def proxy_shift(beta): return Uniform( reward_function="shift", reward_function_kwargs={"beta": beta}, @@ -53,7 +53,7 @@ def uniform_shift(beta): @pytest.fixture() -def uniform_product(beta): +def proxy_product(beta): return Uniform( reward_function="product", reward_function_kwargs={"beta": beta}, @@ -62,18 +62,28 @@ def uniform_product(beta): ) +@pytest.fixture() +def proxy_callable(reward_function, logreward_function): + return Uniform( + reward_function=reward_function, + logreward_function=logreward_function, + device="cpu", + float_precision=32, + ) + + @pytest.mark.parametrize( "proxy, beta", [ ("uniform", None), - ("uniform_power", 1), - ("uniform_power", 2), - ("uniform_exponential", 1), - ("uniform_exponential", -1), - ("uniform_shift", 5), - ("uniform_shift", -5), - ("uniform_product", 2), - ("uniform_product", -2), + ("proxy_power", 1), + ("proxy_power", 2), + ("proxy_exponential", 1), + ("proxy_exponential", -1), + ("proxy_shift", 5), + ("proxy_shift", -5), + ("proxy_product", 2), + ("proxy_product", -2), ], ) def test__uniform_proxy_initializes_without_errors(proxy, beta, request): @@ -81,6 +91,20 @@ def test__uniform_proxy_initializes_without_errors(proxy, beta, request): assert True +@pytest.mark.parametrize( + "proxy, reward_function, logreward_function", + [ + ("proxy_callable", lambda x: x + 1, None), + ("proxy_callable", lambda x: torch.exp(x - 1), lambda x: x - 1), + ], +) +def test__uniform_proxy_callable_initializes_without_errors( + proxy, reward_function, logreward_function, request +): + proxy = request.getfixturevalue(proxy) + assert True + + def check_proxy2reward(rewards_computed, rewards_expected, atol=1e-3): comp_nan = rewards_computed.isnan() exp_nan = rewards_expected.isnan() @@ -117,9 +141,9 @@ def check_proxy2reward(rewards_computed, rewards_expected, atol=1e-3): ], ) def test_reward_function_identity__behaves_as_expected( - uniform_identity, beta, proxy_values, rewards_exp, logrewards_exp + proxy_identity, beta, proxy_values, rewards_exp, logrewards_exp ): - proxy = uniform_identity + proxy = proxy_identity proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) # Rewards rewards_exp = tfloat(rewards_exp, 
device=proxy.device, float_type=proxy.float) @@ -175,9 +199,9 @@ def test_reward_function_identity__behaves_as_expected( ], ) def test_reward_function_power__behaves_as_expected( - uniform_power, beta, proxy_values, rewards_exp, logrewards_exp + proxy_power, beta, proxy_values, rewards_exp, logrewards_exp ): - proxy = uniform_power + proxy = proxy_power proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) # Rewards rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) @@ -229,9 +253,9 @@ def test_reward_function_power__behaves_as_expected( ], ) def test_reward_function_exponential__behaves_as_expected( - uniform_exponential, beta, proxy_values, rewards_exp, logrewards_exp + proxy_exponential, beta, proxy_values, rewards_exp, logrewards_exp ): - proxy = uniform_exponential + proxy = proxy_exponential proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) # Rewards rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) @@ -287,9 +311,9 @@ def test_reward_function_exponential__behaves_as_expected( ], ) def test_reward_function_shift__behaves_as_expected( - uniform_shift, beta, proxy_values, rewards_exp, logrewards_exp + proxy_shift, beta, proxy_values, rewards_exp, logrewards_exp ): - proxy = uniform_shift + proxy = proxy_shift proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) # Rewards rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) @@ -345,9 +369,89 @@ def test_reward_function_shift__behaves_as_expected( ], ) def test_reward_function_product__behaves_as_expected( - uniform_product, beta, proxy_values, rewards_exp, logrewards_exp + proxy_product, beta, proxy_values, rewards_exp, logrewards_exp +): + proxy = proxy_product + proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) + # Rewards + rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) + assert all(check_proxy2reward(proxy._reward_function(proxy_values), rewards_exp)) + assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) + # Log Rewards + logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) + assert all( + check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) + ) + assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) + + +@pytest.mark.parametrize( + "reward_function, logreward_function, proxy_values, rewards_exp, logrewards_exp", + [ + ( + lambda x: x + 1, + None, + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [-99, -9, 0, 0.5, 0.9, 1.0, 1.1, 1.5, 2, 11, 101], + [ + np.nan, + np.nan, + -np.inf, + -0.6931, + -0.1054, + 0.0, + 0.0953, + 0.4055, + 0.6931, + 2.3979, + 4.6151, + ], + ), + ( + lambda x: torch.exp(x - 1), + lambda x: x - 1, + [-10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10], + [ + 1.6702e-05, + 1.3534e-01, + 2.2313e-01, + 3.3287e-01, + 3.6788e-01, + 4.0657e-01, + 6.0653e-01, + 1.0, + 8.1031e03, + ], + [-11, -2, -1.5, -1.1, -1, -0.9, -0.5, 0, 9], + ), + ( + lambda x: torch.exp(x - 1), + None, + [-10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10], + [ + 1.6702e-05, + 1.3534e-01, + 2.2313e-01, + 3.3287e-01, + 3.6788e-01, + 4.0657e-01, + 6.0653e-01, + 1.0, + 8.1031e03, + ], + [-11, -2, -1.5, -1.1, -1, -0.9, -0.5, 0, 9], + ), + ], +) +def test_reward_function_callable__behaves_as_expected( + proxy_callable, + reward_function, + logreward_function, + proxy_values, + rewards_exp, + logrewards_exp, ): - proxy = 
uniform_product + proxy = proxy_callable proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) # Rewards rewards_exp = tfloat(rewards_exp, device=proxy.device, float_type=proxy.float) From c4144d0ec75cb60ad4c0efafcf5c5bc0b2452186 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 15:16:04 -0400 Subject: [PATCH 15/73] log_rewards uses proxy2logreward --- gflownet/proxy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 20cf33046..8ed1e4d7b 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -91,7 +91,7 @@ def log_rewards(self, states: Union[TensorType, List, npt.NDArray]) -> TensorTyp tensor The log reward of all elements in the batch. """ - return torch.log(self.proxy2reward(self(states))) + return self.proxy2logreward(self(states)) # TODO: consider adding option to clip values # TODO: check that rewards are non-negative From ecc9c56f26d2e32a34e825fd8cd8fd09d38d2004 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 17:12:07 -0400 Subject: [PATCH 16/73] Add min reward as attribute --- gflownet/proxy/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 8ed1e4d7b..96d978b50 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -25,6 +25,7 @@ def __init__( reward_function: Optional[Union[Callable, str]] = "identity", logreward_function: Optional[Callable] = None, reward_function_kwargs: Optional[dict] = {}, + reward_min: float = 1e-8, **kwargs, ): # Proxy to reward function @@ -34,6 +35,8 @@ def __init__( self._reward_function, self._logreward_function = self._get_reward_functions( reward_function, logreward_function, **reward_function_kwargs ) + self.reward_min = reward_min + self.logreward_min = torch.log(logreward_min) # Device self.device = set_device(device) # Float precision From 4c53cfb3f2ef0c4d262b9000e84bba227f1f72c7 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 19:04:21 -0400 Subject: [PATCH 17/73] Add absolute proxy-to-reward function; integrate log_rewards() into rewards() via a log parameter; fixes --- gflownet/proxy/base.py | 51 +++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 96d978b50..b45311143 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -22,7 +22,7 @@ def __init__( self, device, float_precision, - reward_function: Optional[Union[Callable, str]] = "identity", + reward_function: Optional[Union[Callable, str]] = "absolute", logreward_function: Optional[Callable] = None, reward_function_kwargs: Optional[dict] = {}, reward_min: float = 1e-8, @@ -36,7 +36,7 @@ def __init__( reward_function, logreward_function, **reward_function_kwargs ) self.reward_min = reward_min - self.logreward_min = torch.log(logreward_min) + self.logreward_min = np.log(reward_min) # Device self.device = set_device(device) # Float precision @@ -57,7 +57,9 @@ def __call__(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType: """ pass - def rewards(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType: + def rewards( + self, states: Union[TensorType, List, npt.NDArray], log: bool = False + ) -> TensorType: """ Computes the rewards of a batch of states. @@ -68,33 +70,19 @@ def rewards(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType: ---------- states : tensor or list or array A batch of states in proxy format. 
+ log : bool + If True, returns the logarithm of the rewards. If False (default), returns + the natural rewards. Returns ------- tensor The reward of all elements in the batch. """ - return self.proxy2reward(self(states)) - - def log_rewards(self, states: Union[TensorType, List, npt.NDArray]) -> TensorType: - """ - Computes the log(rewards) of a batch of states. - - The rewards are computed by first calling the proxy function, then - transforming the proxy values according to the reward function, then taking the - logarithm. - - Parameters - ---------- - states : tensor or list or array - A batch of states in proxy format. - - Returns - ------- - tensor - The log reward of all elements in the batch. - """ - return self.proxy2logreward(self(states)) + if log: + return self.proxy2logreward(self(states)) + else: + return self.proxy2reward(self(states)) # TODO: consider adding option to clip values # TODO: check that rewards are non-negative @@ -144,7 +132,8 @@ def _get_reward_functions( If reward_function is callable, it is returned as is. If it is a string, it must correspond to one of the following options: - - identity: the rewards are directly the proxy values. + - id(entity): the rewards are directly the proxy values. + - abs(olute): the rewards are the absolute value of the proxy values. - pow(er): the rewards are the proxy values to the power of beta. See: :py:meth:`~gflownet.proxy.base._power()` - exp(onential) or boltzmann: the rewards are the negative exponential of @@ -184,12 +173,18 @@ def _get_reward_functions( f"got {type(reward_function)} instead." ) - if reward_function.startswith("identity"): + if reward_function.startswith("id"): return ( lambda x: x, lambda x: torch.log(x), ) + if reward_function.startswith("abs"): + return ( + lambda x: torch.abs(x), + lambda x: torch.log(torch.abs(x)), + ) + elif reward_function.startswith("pow"): return ( Proxy._power(**kwargs), @@ -213,8 +208,8 @@ def _get_reward_functions( else: raise ValueError( - "reward_function must be one of: pow(er), exp(onential), shift, " - f"prod(uct). Received {reward_function} instead." + "reward_function must be one of: id(entity), abs(olute) pow(er), " + f"exp(onential), shift, prod(uct). Received {reward_function} instead." ) @staticmethod From 34275f35eddded42f4f274eef40b4c88e3ba0734 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 19:05:01 -0400 Subject: [PATCH 18/73] Make default base proxy config file and set it as default for other proxy config files --- config/proxy/base.yaml | 19 +++++++++++++++++++ config/proxy/corners.yaml | 3 +++ config/proxy/crystals/dave.yaml | 3 +++ config/proxy/crystals/lattice_parameters.yaml | 3 +++ config/proxy/crystals/spacegroup.yaml | 3 +++ config/proxy/length.yaml | 3 +++ config/proxy/molecule.yaml | 5 ++++- config/proxy/scrabble.yaml | 3 +++ config/proxy/tetris.yaml | 3 +++ config/proxy/torus.yaml | 3 +++ config/proxy/tree.yaml | 3 +++ config/proxy/uniform.yaml | 3 +++ 12 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 config/proxy/base.yaml diff --git a/config/proxy/base.yaml b/config/proxy/base.yaml new file mode 100644 index 000000000..8563bbe91 --- /dev/null +++ b/config/proxy/base.yaml @@ -0,0 +1,19 @@ +_target_: gflownet.proxy.base.Proxy + +# Reward function: string identifier of the proxy-to-reward function: +# - identity +# - absolute (default) +# - power +# - exponential +# - shift +# - product +# Alternatively, it can be a callable of the function itself. +reward_function: absolute +# A callable of the proxy-to-logreward function. 
+# None by default, which takes the log of the proxy-to-reward function +logreward_function: null +# Arguments of the proxy-to-reward function. +# The default functions use an argument with key beta +reward_function_kwargs: {} +# Minimum reward. Used to clip the rewards. +reward_min: 1e-8 diff --git a/config/proxy/corners.yaml b/config/proxy/corners.yaml index 081004490..546902b7d 100644 --- a/config/proxy/corners.yaml +++ b/config/proxy/corners.yaml @@ -1,3 +1,6 @@ +defaults: + - base + _target_: gflownet.proxy.corners.Corners mu: 0.75 diff --git a/config/proxy/crystals/dave.yaml b/config/proxy/crystals/dave.yaml index d61c938fa..0b174f396 100644 --- a/config/proxy/crystals/dave.yaml +++ b/config/proxy/crystals/dave.yaml @@ -1,3 +1,6 @@ +defaults: + - base + _target_: gflownet.proxy.crystals.dave.DAVE release: 0.3.4 diff --git a/config/proxy/crystals/lattice_parameters.yaml b/config/proxy/crystals/lattice_parameters.yaml index 06dacd5ba..eced8f81d 100644 --- a/config/proxy/crystals/lattice_parameters.yaml +++ b/config/proxy/crystals/lattice_parameters.yaml @@ -1,3 +1,6 @@ +defaults: + - base + _target_: gflownet.proxy.crystals.lattice_parameters.LatticeParameters min_value: -100 diff --git a/config/proxy/crystals/spacegroup.yaml b/config/proxy/crystals/spacegroup.yaml index a7fc9b1a9..24e9f90fe 100644 --- a/config/proxy/crystals/spacegroup.yaml +++ b/config/proxy/crystals/spacegroup.yaml @@ -1,3 +1,6 @@ +defaults: + - base + _target_: gflownet.proxy.crystals.spacegroup.SpaceGroup normalize: True diff --git a/config/proxy/length.yaml b/config/proxy/length.yaml index f63ead4c4..e771b7452 100644 --- a/config/proxy/length.yaml +++ b/config/proxy/length.yaml @@ -1,3 +1,6 @@ +defaults: + - base + _target_: gflownet.proxy.aptamers.Aptamers oracle_id: length diff --git a/config/proxy/molecule.yaml b/config/proxy/molecule.yaml index 1283c09e7..da818a84d 100644 --- a/config/proxy/molecule.yaml +++ b/config/proxy/molecule.yaml @@ -1,4 +1,7 @@ +defaults: + - base + _target_: gflownet.proxy.molecule.RFMoleculeEnergy path_to_model: './data/random_forest_reward_100.pkl' -url_to_model: 'https://drive.google.com/uc?id=1OpQNC8WWIsMh8K4olfSaQRFlj3emYThF' \ No newline at end of file +url_to_model: 'https://drive.google.com/uc?id=1OpQNC8WWIsMh8K4olfSaQRFlj3emYThF' diff --git a/config/proxy/scrabble.yaml b/config/proxy/scrabble.yaml index bbf818fed..ea1f4cda8 100644 --- a/config/proxy/scrabble.yaml +++ b/config/proxy/scrabble.yaml @@ -1,3 +1,6 @@ +defaults: + - base + _target_: gflownet.proxy.scrabble.ScrabbleScorer vocabulary_check: False diff --git a/config/proxy/tetris.yaml b/config/proxy/tetris.yaml index 9774e1dbb..4d323203a 100644 --- a/config/proxy/tetris.yaml +++ b/config/proxy/tetris.yaml @@ -1,3 +1,6 @@ +defaults: + - base + _target_: gflownet.proxy.tetris.Tetris normalize: True diff --git a/config/proxy/torus.yaml b/config/proxy/torus.yaml index 9739c85b0..7623badd0 100644 --- a/config/proxy/torus.yaml +++ b/config/proxy/torus.yaml @@ -1,3 +1,6 @@ +defaults: + - base + _target_: gflownet.proxy.torus.Torus normalize: True diff --git a/config/proxy/tree.yaml b/config/proxy/tree.yaml index 3aead4e73..cccc9dad1 100644 --- a/config/proxy/tree.yaml +++ b/config/proxy/tree.yaml @@ -1,3 +1,6 @@ +defaults: + - base + _target_: gflownet.proxy.tree.TreeProxy use_prior: False diff --git a/config/proxy/uniform.yaml b/config/proxy/uniform.yaml index b09272dea..5a034f850 100644 --- a/config/proxy/uniform.yaml +++ b/config/proxy/uniform.yaml @@ -1 +1,4 @@ +defaults: + - base + _target_: 
gflownet.proxy.uniform.Uniform From 911e0b79abfd97fbf24abccf43eb325b540473c8 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 19:05:59 -0400 Subject: [PATCH 19/73] Adapt Batch to compute rewards from proxy and handle log rewards. Tests are WIP and everything needs further testing. --- gflownet/utils/batch.py | 163 +++++++++++++++++++++-------- tests/gflownet/utils/test_batch.py | 10 +- 2 files changed, 125 insertions(+), 48 deletions(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 1884fa08d..dacfb13ed 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -7,6 +7,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv +from gflownet.proxy.base import Proxy from gflownet.utils.common import ( concat_items, copy, @@ -37,20 +38,23 @@ class Batch: def __init__( self, env: Optional[GFlowNetEnv] = None, + proxy: Optional[Proxy] = None, device: Union[str, torch.device] = "cpu", float_type: Union[int, torch.dtype] = 32, ): """ + Arguments + --------- env : GFlowNetEnv An instance of the environment that will be used to form the batch. - + proxy : Proxy + An instance of a GFlowNet proxy that will be used to compute proxy values + and rewards. device : str or torch.device torch.device or string indicating the device to use ("cpu" or "cuda") - float_type : torch.dtype or int One of float torch.dtype or an int indicating the float precision (16, 32 or 64). - """ # Device self.device = set_device(device) @@ -65,6 +69,8 @@ def __init__( self.source = None self.conditional = None self.continuous = None + # Proxy + self.proxy = proxy # Initialize batch size 0 self.size = 0 # Initialize empty batch variables @@ -96,6 +102,9 @@ def __init__( self.rewards_available = False self.rewards_parents_available = False self.rewards_source_available = False + self.logrewards_available = False + self.logrewards_parents_available = False + self.logrewards_source_available = False def __len__(self): return self.size @@ -144,6 +153,12 @@ def set_env(self, env: GFlowNetEnv): self.conditional = self.env.conditional self.continuous = self.env.continuous + def set_proxy(self, proxy: Proxy): + """ + Sets the proxy, used to compute rewards from a batch of states. + """ + self.proxy = proxy + def add_to_batch( self, envs: List[GFlowNetEnv], @@ -241,6 +256,7 @@ def add_to_batch( self.parents_policy_available = False self.parents_all_available = False self.rewards_available = False + self.logrewards_available = False def get_n_trajectories(self) -> int: """ @@ -841,110 +857,158 @@ def _compute_masks_backward(self): def get_rewards( self, + log: bool = False, force_recompute: Optional[bool] = False, do_non_terminating: Optional[bool] = False, ) -> TensorType["n_states"]: """ Returns the rewards of all states in the batch (including not done). - Args - ---- + Parameters + ---------- + log : bool + If True, return the logarithm of the rewards. force_recompute : bool If True, the rewards are recomputed even if they are available. - do_non_terminating : bool - If True, compute the rewards of the non-terminating states instead of - assigning reward 0. + If True, compute the actual rewards of the non-terminating states. If + False, non-terminating states will be assigned reward 0. 
""" if self.rewards_available is False or force_recompute is True: - self._compute_rewards(do_non_terminating) - return self.rewards + self._compute_rewards(log, do_non_terminating) + if log: + return self.logrewards + else: + return self.rewards - def _compute_rewards(self, do_non_terminating: Optional[bool] = False): + def _compute_rewards( + self, log: bool = False, do_non_terminating: Optional[bool] = False + ): """ Computes rewards for all self.states by first converting the states into proxy format. The result is stored in self.rewards as a torch.tensor - Args - ---- + Parameters + ---------- + log : bool + If True, compute the logarithm of the rewards. do_non_terminating : bool If True, compute the rewards of the non-terminating states instead of assigning reward 0. """ if do_non_terminating: - self.rewards = self.env.proxy2reward(self.env.proxy(self.states2proxy())) + rewards = self.proxy.rewards(self.states2proxy(), log) else: - self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device) + rewards = torch.zeros(len(self), dtype=self.float, device=self.device) done = self.get_done() if len(done) > 0: states_proxy_done = self.get_terminating_states(proxy=True) - self.rewards[done] = self.env.proxy2reward( - self.env.proxy(states_proxy_done) - ) - self.rewards_available = True + rewards[done] = self.proxy.rewards(states_proxy_done, log) + if log: + self.logrewards = rewards + self.logrewards_available = True + else: + self.rewards = rewards + self.rewards_available = True - def get_rewards_parents(self) -> TensorType["n_states"]: + def get_rewards_parents(self, log: bool = False) -> TensorType["n_states"]: """ Returns the rewards of all parents in the batch. + Parameters + ---------- + log : bool + If True, return the logarithm of the rewards. + Returns ------- - self.rewards_parents + self.rewards_parents or self.logrewards_parents A tensor containing the rewards of the parents of self.states. """ if not self.rewards_parents_available: - self._compute_rewards_parents() - return self.rewards_parents + self._compute_rewards_parents(log) + if log: + return self.logrewards_parents + else: + return self.rewards_parents - def _compute_rewards_parents(self): + def _compute_rewards_parents(self, log: bool = False): """ Computes the rewards of self.parents by reusing the rewards of the states (self.rewards). - Stores the result in self.rewards_parents. + Stores the result in self.rewards_parents or self.logrewards_parents. + + Parameters + ---------- + log : bool + If True, compute the logarithm of the rewards. 
""" # TODO: this may return zero rewards for all parents if before # rewards for states were computed with do_non_terminating=False - state_rewards = self.get_rewards(do_non_terminating=True) - self.rewards_parents = torch.zeros_like(state_rewards) + state_rewards = self.get_rewards(log=log, do_non_terminating=True) + rewards_parents = torch.zeros_like(state_rewards) parent_indices = self.get_parents_indices() parent_is_source = parent_indices == -1 - self.rewards_parents[~parent_is_source] = self.rewards[ + rewards_parents[~parent_is_source] = state_rewards[ parent_indices[~parent_is_source] ] - rewards_source = self.get_rewards_source() - self.rewards_parents[parent_is_source] = rewards_source[parent_is_source] - self.rewards_parents_available = True + rewards_source = self.get_rewards_source(log) + rewards_parents[parent_is_source] = rewards_source[parent_is_source] + if log: + self.logrewards_parents = rewards_parents + self.logrewards_parents_available = True + else: + self.rewards_parents = rewards_parents + self.rewards_parents_available = True - def get_rewards_source(self) -> TensorType["n_states"]: + def get_rewards_source(self, log: bool = False) -> TensorType["n_states"]: """ Returns rewards of the corresponding source states for each state in the batch. + Parameters + ---------- + log : bool + If True, return the logarithm of the rewards. + Returns ------- - self.rewards_source + self.rewards_source or self.logrewards_source A tensor containing the rewards the source states. """ if not self.rewards_source_available: - self._compute_rewards_source() - return self.rewards_source + self._compute_rewards_source(log) + if log: + return self.logrewards_source + else: + return self.rewards_source - def _compute_rewards_source(self): + def _compute_rewards_source(self, log: bool = False): """ Computes a tensor of length len(self.states) with the rewards of the corresponding source states. - Stores the result in self.rewards_source. + Stores the result in self.rewards_source or self.logrewards_source. + + Parameters + ---------- + log : bool + If True, compute the logarithm of the rewards. """ # This will not work if source is randomised if not self.conditional: source_proxy = self.env.state2proxy(self.env.source) - reward_source = self.env.proxy2reward(self.env.proxy(source_proxy)) - self.rewards_source = reward_source.expand(len(self)) + reward_source = self.proxy.rewards(source_proxy, log) + rewards_source = reward_source.expand(len(self)) else: raise NotImplementedError - self.rewards_source_available = True + if log: + self.logrewards_source = rewards_source + self.logrewards_source_available = True + else: + self.rewards_source = rewards_source + self.rewards_source_available = True def get_terminating_states( self, @@ -1013,6 +1077,7 @@ def get_terminating_states( def get_terminating_rewards( self, sort_by: str = "insertion", + log: bool = False, force_recompute: Optional[bool] = False, ) -> TensorType["n_trajectories"]: """ @@ -1021,15 +1086,16 @@ def get_terminating_rewards( (sort_by = "insert[ion]", default) or by trajectory index (sort_by = "traj[ectory]". - Args - ---- + Parameters + ---------- sort_by : str Indicates how to sort the output: - insert[ion]: sort by order of insertion (rewards of trajectories that reached the terminating state first come first) - traj[ectory]: sort by trajectory index (the order in the ordered dict self.trajectories) - + log : bool + If True, return the logarithm of the rewards. 
force_recompute : bool If True, the rewards are recomputed even if they are available. """ @@ -1040,9 +1106,12 @@ def get_terminating_rewards( else: raise ValueError("sort_by must be either insert[ion] or traj[ectory]") if self.rewards_available is False or force_recompute is True: - self._compute_rewards() + self._compute_rewards(log, do_non_terminating=False) done = self.get_done()[indices] - return self.rewards[indices][done] + if log: + return self.logrewards[indices][done] + else: + return self.rewards[indices][done] def get_actions_trajectories(self) -> List[List[Tuple]]: """ @@ -1166,6 +1235,10 @@ def merge(self, batches: List): self.rewards = extend(self.rewards, batch.rewards) else: self.rewards = None + if self.logrewards_available and batch.logrewards_available: + self.logrewards = extend(self.logrewards, batch.logrewards) + else: + self.logrewards = None assert self.is_valid() return self diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 803704ecc..15a6dbb14 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -240,11 +240,12 @@ def test__get_masks_backward__single_env_returns_expected(env, batch, request): def test__get_rewards__single_env_returns_expected(env, proxy, batch, request): env = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) + proxy.setup(env) env = env.reset() - env.proxy = proxy - env.setup_proxy() batch.set_env(env) + batch.set_proxy(proxy) + rewards_from_env = [] rewards = [] while not env.done: parent = env.state @@ -253,7 +254,10 @@ def test__get_rewards__single_env_returns_expected(env, proxy, batch, request): # Add to batch batch.add_to_batch([env], [action], [valid]) if valid: - rewards.append(env.reward()) + if env.done: + rewards.append(proxy.rewards(env.state2proxy())[0]) + else: + rewards.append(tfloat(0.0, float_type=batch.float, device=batch.device)) rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) assert torch.equal( From 060a82e6bc76e45f809046cf48bbffc700969bee Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 20:23:40 -0400 Subject: [PATCH 20/73] Implement get_min_reward() in base proxy. --- gflownet/proxy/base.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index b45311143..80fa017d2 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -36,7 +36,6 @@ def __init__( reward_function, logreward_function, **reward_function_kwargs ) self.reward_min = reward_min - self.logreward_min = np.log(reward_min) # Device self.device = set_device(device) # Float precision @@ -119,6 +118,26 @@ def proxy2logreward(self, proxy_values: TensorType) -> TensorType: """ return self._logreward_function(proxy_values) + def get_min_reward(self, log: bool = False) -> float: + """ + Returns the minimum value of the (log) reward, retrieved from self.reward_min. + + Parameters + ---------- + log : bool + If True, returns the logarithm of the minimum reward. If False (default), + returns the natural minimum reward. + + Returns + ------- + float + The mimnimum (log) reward. + """ + if log: + return np.log(self.reward_min) + else: + return self.reward_min + def _get_reward_functions( self, reward_function: Union[Callable, str], From 25cf127eb3b7150a17be99099a92b9cd932c2c89 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 20:24:39 -0400 Subject: [PATCH 21/73] Use get_min_reward() in batch. 
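Non-terminating states in Batch._compute_rewards() are now filled with the
minimum (log) reward returned by the proxy instead of a hard-coded 0.0, so
that the fill value remains meaningful when log=True (a 0.0 fill would
otherwise read as log-reward 0, i.e. reward 1).

Illustrative sketch only (not part of the diff): it mimics the new fill
logic, assuming the reward_min default at this point in the series (1e-8 in
config/proxy/base.yaml); variable names below are made up for the example.

    import numpy as np
    import torch

    reward_min = 1e-8                  # default in config/proxy/base.yaml
    log, n_states = True, 4

    # Equivalent of proxy.get_min_reward(log) * torch.ones(n_states, ...)
    fill = float(np.log(reward_min)) if log else reward_min
    rewards = fill * torch.ones(n_states, dtype=torch.float32)
    print(rewards)  # approx. tensor([-18.4207, -18.4207, -18.4207, -18.4207])
    # Terminating states are then overwritten with proxy.rewards(states, log).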
--- gflownet/utils/batch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index dacfb13ed..3c9188efc 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -900,7 +900,9 @@ def _compute_rewards( if do_non_terminating: rewards = self.proxy.rewards(self.states2proxy(), log) else: - rewards = torch.zeros(len(self), dtype=self.float, device=self.device) + rewards = self.proxy.get_min_reward(log) * torch.ones( + len(self), dtype=self.float, device=self.device + ) done = self.get_done() if len(done) > 0: states_proxy_done = self.get_terminating_states(proxy=True) From 26fc8ea34371ecaf60cb5ace4f1815ae48c20e6a Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 20:59:18 -0400 Subject: [PATCH 22/73] Adapt Batch tests --- tests/gflownet/utils/test_batch.py | 326 +++++++++++++++++++++++++---- 1 file changed, 281 insertions(+), 45 deletions(-) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 15a6dbb14..2b3206fba 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -236,7 +236,6 @@ def test__get_masks_backward__single_env_returns_expected(env, batch, request): "env, proxy", [("grid2d", "corners"), ("tetris6x4", "tetris_score"), ("ctorus2d5l", "corners")], ) -# @pytest.mark.skip(reason="skip while developping other tests") def test__get_rewards__single_env_returns_expected(env, proxy, batch, request): env = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) @@ -245,7 +244,6 @@ def test__get_rewards__single_env_returns_expected(env, proxy, batch, request): batch.set_env(env) batch.set_proxy(proxy) - rewards_from_env = [] rewards = [] while not env.done: parent = env.state @@ -257,7 +255,13 @@ def test__get_rewards__single_env_returns_expected(env, proxy, batch, request): if env.done: rewards.append(proxy.rewards(env.state2proxy())[0]) else: - rewards.append(tfloat(0.0, float_type=batch.float, device=batch.device)) + rewards.append( + tfloat( + proxy.get_min_reward(), + float_type=batch.float, + device=batch.device, + ) + ) rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) assert torch.equal( @@ -266,6 +270,45 @@ def test__get_rewards__single_env_returns_expected(env, proxy, batch, request): ), (rewards, rewards_batch) +@pytest.mark.repeat(N_REPETITIONS) +@pytest.mark.parametrize( + "env, proxy", + [("grid2d", "corners"), ("tetris6x4", "tetris_score"), ("ctorus2d5l", "corners")], +) +def test__get_logrewards__single_env_returns_expected(env, proxy, batch, request): + env = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + proxy.setup(env) + env = env.reset() + batch.set_env(env) + batch.set_proxy(proxy) + + logrewards = [] + while not env.done: + parent = env.state + # Sample random action + _, action, valid = env.step_random() + # Add to batch + batch.add_to_batch([env], [action], [valid]) + if valid: + if env.done: + logrewards.append(proxy.rewards(env.state2proxy(), log=True)[0]) + else: + logrewards.append( + tfloat( + proxy.get_min_reward(log=True), + float_type=batch.float, + device=batch.device, + ) + ) + logrewards_batch = batch.get_rewards(log=True) + logrewards = torch.stack(logrewards) + assert torch.equal( + logrewards_batch, + tfloat(logrewards, device=batch.device, float_type=batch.float), + ), (logrewards, logrewards_batch) + + @pytest.mark.repeat(N_REPETITIONS) @pytest.mark.parametrize( "env, proxy", @@ -276,16 +319,14 @@ def 
test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ batch_size = BATCH_SIZE env_ref = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) - env_ref.proxy = proxy - env_ref.setup_proxy() + proxy.setup(env_ref) batch.set_env(env_ref) + batch.set_proxy(proxy) # Make list of envs envs = [] for idx in range(batch_size): env_aux = env_ref.copy().reset(idx) - env_aux.proxy = proxy - env_aux.setup_proxy() envs.append(env_aux) # Initialize empty lists for checks @@ -330,7 +371,16 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ env_parents, env_parents_a = env.get_parents() parents_all.extend(env_parents) parents_all_a.extend(env_parents_a) - rewards.append(env.reward()) + if env.done: + rewards.append(proxy.rewards(env.state2proxy())[0]) + else: + rewards.append( + tfloat( + proxy.get_min_reward(), + float_type=batch.float, + device=batch.device, + ) + ) traj_indices.append(env.id) state_indices.append(env.n_actions) if env.done: @@ -427,9 +477,9 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req batch_size = BATCH_SIZE env_ref = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) - env_ref.proxy = proxy - env_ref.setup_proxy() + proxy.setup(env_ref) batch.set_env(env_ref) + batch.set_proxy(proxy) # Sample terminating states and build list of envs x_batch = env_ref.get_random_terminating_states(n_states=batch_size) @@ -438,8 +488,6 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req env_aux = env_ref.copy().reset(idx) env_aux = env_aux.set_state(state=x, done=True) env_aux.n_actions = env_aux.get_max_traj_length() - env_aux.proxy = proxy - env_aux.setup_proxy() envs.append(env_aux) # Initialize empty lists for checks @@ -470,7 +518,12 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req mask_backward = env.get_mask_invalid_actions_backward() if not env.continuous: env_parents, env_parents_a = env.get_parents() - reward = env.reward() + if env.done: + reward = proxy.rewards(env.state2proxy())[0] + else: + reward = tfloat( + proxy.get_min_reward(), float_type=batch.float, device=batch.device + ) if env.done: states_term_sorted[env.id] = env.state # Sample random action @@ -585,9 +638,9 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques # Initialize fixtures and batch env_ref = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) - env_ref.proxy = proxy - env_ref.setup_proxy() + proxy.setup(env_ref) batch.set_env(env_ref) + batch.set_proxy(proxy) # Initialize empty lists for checks states = [] @@ -611,8 +664,6 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques envs = [] for idx in range(batch_size_forward): env_aux = env_ref.copy().reset(idx) - env_aux.proxy = proxy - env_aux.setup_proxy() envs.append(env_aux) states_term_sorted.extend([None for _ in range(batch_size_forward)]) @@ -644,7 +695,16 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques env_parents, env_parents_a = env.get_parents() parents_all.extend(env_parents) parents_all_a.extend(env_parents_a) - rewards.append(env.reward()) + if env.done: + rewards.append(proxy.rewards(env.state2proxy())[0]) + else: + rewards.append( + tfloat( + proxy.get_min_reward(), + float_type=batch.float, + device=batch.device, + ) + ) traj_indices.append(env.id) state_indices.append(env.n_actions) if env.done: @@ -683,7 +743,12 @@ def 
test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques mask_backward = env.get_mask_invalid_actions_backward() if not env.continuous: env_parents, env_parents_a = env.get_parents() - reward = env.reward() + if env.done: + reward = proxy.rewards(env.state2proxy())[0] + else: + reward = tfloat( + proxy.get_min_reward(), float_type=batch.float, device=batch.device + ) if env.done: states_term_sorted[env.id] = env.state # Sample random action @@ -800,10 +865,9 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): # Initialize fixtures and batch env_ref = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) - env_ref.proxy = proxy - env_ref.setup_proxy() - batch_fw = Batch(env=env_ref) - batch_bw = Batch(env=env_ref) + proxy.setup(env_ref) + batch_fw = Batch(env=env_ref, proxy=proxy) + batch_bw = Batch(env=env_ref, proxy=proxy) # Initialize empty lists for checks states = [] @@ -827,8 +891,6 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): envs = [] for idx in range(batch_size_forward): env_aux = env_ref.copy().reset(idx) - env_aux.proxy = proxy - env_aux.setup_proxy() envs.append(env_aux) states_term_sorted.extend([None for _ in range(batch_size_forward)]) @@ -860,7 +922,16 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): env_parents, env_parents_a = env.get_parents() parents_all.extend(env_parents) parents_all_a.extend(env_parents_a) - rewards.append(env.reward()) + if env.done: + rewards.append(proxy.rewards(env.state2proxy())[0]) + else: + rewards.append( + tfloat( + proxy.get_min_reward(), + float_type=batch_fw.float, + device=batch_fw.device, + ) + ) traj_indices.append(env.id) state_indices.append(env.n_actions) if env.done: @@ -899,7 +970,14 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): mask_backward = env.get_mask_invalid_actions_backward() if not env.continuous: env_parents, env_parents_a = env.get_parents() - reward = env.reward() + if env.done: + reward = proxy.rewards(env.state2proxy())[0] + else: + reward = tfloat( + proxy.get_min_reward(), + float_type=batch_bw.float, + device=batch_bw.device, + ) if env.done: states_term_sorted[env.id + batch_size_forward] = env.state # Sample random action @@ -929,7 +1007,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): ### MERGE BATCHES ### - batch = Batch(env=env_ref) + batch = Batch(env=env_ref, proxy=proxy) batch = batch.merge([batch_fw, batch_bw]) ### CHECKS ### @@ -1208,16 +1286,15 @@ def test__make_indices_consecutive__multiplied_indices_become_consecutive( "env, proxy", [("grid2d", "corners"), ("tetris6x4", "tetris_score"), ("ctorus2d5l", "corners")], ) -# @pytest.mark.skip(reason="skip while developping other tests") def test__get_rewards__single_env_returns_expected_non_terminating( env, proxy, batch, request ): env = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) + proxy.setup(env) env = env.reset() - env.proxy = proxy - env.setup_proxy() batch.set_env(env) + batch.set_proxy(proxy) rewards = [] while not env.done: @@ -1227,7 +1304,7 @@ def test__get_rewards__single_env_returns_expected_non_terminating( # Add to batch batch.add_to_batch([env], [action], [valid]) if valid: - rewards.append(env.reward(do_non_terminating=True)) + rewards.append(proxy.rewards(env.state2proxy())[0]) rewards_batch = batch.get_rewards(do_non_terminating=True) rewards = torch.stack(rewards) assert torch.equal( @@ -1237,7 +1314,40 @@ def 
test__get_rewards__single_env_returns_expected_non_terminating( @pytest.mark.repeat(N_REPETITIONS) -# @pytest.mark.skip(reason="skip while developping other tests") +@pytest.mark.parametrize( + "env, proxy", + [("grid2d", "corners"), ("tetris6x4", "tetris_score"), ("ctorus2d5l", "corners")], +) +def test__get_logrewards__single_env_returns_expected_non_terminating( + env, proxy, batch, request +): + env = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + proxy.setup(env) + env = env.reset() + batch.set_env(env) + batch.set_proxy(proxy) + + logrewards = [] + while not env.done: + parent = env.state + # Sample random action + _, action, valid = env.step_random() + # Add to batch + batch.add_to_batch([env], [action], [valid]) + if valid: + logrewards.append(proxy.rewards(env.state2proxy(), log=True)[0]) + logrewards_batch = batch.get_rewards(log=True, do_non_terminating=True) + logrewards = torch.stack(logrewards) + assert torch.all( + torch.isclose( + logrewards_batch, + tfloat(logrewards, device=batch.device, float_type=batch.float), + ) + ), (logrewards, logrewards_batch) + + +@pytest.mark.repeat(N_REPETITIONS) @pytest.mark.parametrize( "env, proxy", [("grid2d", "corners"), ("tetris6x4", "tetris_score_norm")], @@ -1248,12 +1358,11 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( batch_size = BATCH_SIZE env_ref = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) + proxy.setup(env_ref) env_ref = env_ref.reset() - env_ref.proxy = proxy - env_ref.setup_proxy() - env_ref.reward_func = "boltzmann" batch.set_env(env_ref) + batch.set_proxy(proxy) # Make list of envs envs = [] @@ -1262,7 +1371,6 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( envs.append(env_aux) rewards = [] - proxy_values = [] # Iterate until envs is empty while envs: @@ -1277,8 +1385,7 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( # Add to iter lists actions_iter.append(action) valids_iter.append(valid) - rewards.append(env.reward(do_non_terminating=True)) - proxy_values.append(env.proxy(env.state2proxy(env.state))[0]) + rewards.append(proxy.rewards(env.state2proxy())[0]) # Add all envs, actions and valids to batch batch.add_to_batch(envs, actions_iter, valids_iter) # Remove done envs @@ -1295,6 +1402,61 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( ), rewards_batch +@pytest.mark.repeat(N_REPETITIONS) +@pytest.mark.parametrize( + "env, proxy", + [("grid2d", "corners"), ("tetris6x4", "tetris_score_norm")], +) +def test__get_logrewards_multiple_env_returns_expected_non_zero_non_terminating( + env, proxy, batch, request +): + batch_size = BATCH_SIZE + env_ref = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + proxy.setup(env_ref) + env_ref = env_ref.reset() + + batch.set_env(env_ref) + batch.set_proxy(proxy) + + # Make list of envs + envs = [] + for idx in range(batch_size): + env_aux = env_ref.copy().reset(idx) + envs.append(env_aux) + + logrewards = [] + + # Iterate until envs is empty + while envs: + actions_iter = [] + valids_iter = [] + # Make step env by env (different to GFN Agent) to have full control + for env in envs: + parent = copy(env.state) + # Sample random action + state, action, valid = env.step_random() + if valid: + # Add to iter lists + actions_iter.append(action) + valids_iter.append(valid) + logrewards.append(proxy.rewards(env.state2proxy(), log=True)[0]) + # Add all envs, actions and valids to batch + 
batch.add_to_batch(envs, actions_iter, valids_iter) + # Remove done envs + envs = [env for env in envs if not env.done] + + logrewards_batch = batch.get_rewards(log=True, do_non_terminating=True) + logrewards = torch.stack(logrewards) + assert torch.equal( + logrewards_batch, + tfloat(logrewards, device=batch.device, float_type=batch.float), + ), (logrewards, logrewards_batch) + assert ~torch.any( + torch.isclose(logrewards_batch, torch.zeros_like(logrewards_batch)) + ), logrewards_batch + + @pytest.mark.repeat(N_REPETITIONS) # @pytest.mark.skip(reason="skip while developping other tests") @pytest.mark.parametrize( @@ -1311,11 +1473,11 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminating( batch_size = BATCH_SIZE env_ref = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) + proxy.setup(env_ref) env_ref = env_ref.reset() - env_ref.proxy = proxy - env_ref.setup_proxy() batch.set_env(env_ref) + batch.set_proxy(proxy) # Make list of envs envs = [] @@ -1341,10 +1503,8 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminating( # Add to iter lists actions_iter.append(action) valids_iter.append(valid) - rewards_parents.append( - env.reward(state=parent, done=False, do_non_terminating=True) - ) - rewards.append(env.reward(do_non_terminating=True)) + rewards_parents.append(proxy.rewards(env.states2proxy([parent]))[0]) + rewards.append(proxy.rewards(env.state2proxy())[0]) # Add all envs, actions and valids to batch batch.add_to_batch(envs, actions_iter, valids_iter) # Remove done envs @@ -1369,3 +1529,79 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminating( tfloat(rewards, device=batch.device, float_type=batch.float), ) ), (rewards, rewards_batch) + + +@pytest.mark.repeat(N_REPETITIONS) +# @pytest.mark.skip(reason="skip while developping other tests") +@pytest.mark.parametrize( + "env, proxy", + [ + ("grid2d", "corners"), + ("tetris6x4", "tetris_score_norm"), + ("ctorus2d5l", "corners"), + ], +) +def test__get_logrewards_parents_multiple_env_returns_expected_non_terminating( + env, proxy, batch, request +): + batch_size = BATCH_SIZE + env_ref = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + proxy.setup(env_ref) + env_ref = env_ref.reset() + + batch.set_env(env_ref) + batch.set_proxy(proxy) + + # Make list of envs + envs = [] + for idx in range(batch_size): + env_aux = env_ref.copy().reset(idx) + envs.append(env_aux) + + logrewards_parents = [] + logrewards = [] + + # Iterate until envs is empty + while envs: + actions_iter = [] + valids_iter = [] + # Make step env by env (different to GFN Agent) to have full control + for env in envs: + parent = copy(env.state) + assert env.done is False + + # Sample random action + state, action, valid = env.step_random() + if valid: + # Add to iter lists + actions_iter.append(action) + valids_iter.append(valid) + logrewards_parents.append( + proxy.rewards(env.states2proxy([parent]), log=True)[0] + ) + logrewards.append(proxy.rewards(env.state2proxy(), log=True)[0]) + # Add all envs, actions and valids to batch + batch.add_to_batch(envs, actions_iter, valids_iter) + # Remove done envs + envs = [env for env in envs if not env.done] + + logrewards_parents_batch = batch.get_rewards_parents(log=True) + logrewards_parents = torch.stack(logrewards_parents) + + logrewards_batch = batch.get_rewards(log=True, do_non_terminating=True) + logrewards = torch.stack(logrewards) + + assert torch.all( + torch.isclose( + logrewards_parents_batch, + 
tfloat(logrewards_parents, device=batch.device, float_type=batch.float), + ) + ), (logrewards_parents, logrewards_parents_batch) + + assert torch.all( + torch.isclose( + logrewards_batch, + tfloat(logrewards, device=batch.device, float_type=batch.float), + ) + ), (logrewards, logrewards_batch) From a9174eb1148aaacef4b807cc6bf708c9eedb1e56 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 20:59:43 -0400 Subject: [PATCH 23/73] Handle clipping of rewards --- config/proxy/base.yaml | 4 +++- gflownet/proxy/base.py | 26 ++++++++++++++++++++++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/config/proxy/base.yaml b/config/proxy/base.yaml index 8563bbe91..1dd2b1f5f 100644 --- a/config/proxy/base.yaml +++ b/config/proxy/base.yaml @@ -16,4 +16,6 @@ logreward_function: null # The default functions use an argument with key beta reward_function_kwargs: {} # Minimum reward. Used to clip the rewards. -reward_min: 1e-8 +reward_min: 0.0 +# Flag to control whether rewards are clipped +do_clip_rewards: False diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 80fa017d2..addd09006 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -25,7 +25,8 @@ def __init__( reward_function: Optional[Union[Callable, str]] = "absolute", logreward_function: Optional[Callable] = None, reward_function_kwargs: Optional[dict] = {}, - reward_min: float = 1e-8, + reward_min: float = 0.0, + do_clip_rewards: bool = False, **kwargs, ): # Proxy to reward function @@ -36,6 +37,7 @@ def __init__( reward_function, logreward_function, **reward_function_kwargs ) self.reward_min = reward_min + self.do_clip_rewards = do_clip_rewards # Device self.device = set_device(device) # Float precision @@ -65,6 +67,10 @@ def rewards( The rewards are computed by first calling the proxy function, then transforming the proxy values according to the reward function. + If log is True, nan values are set to self.logreward_min. + + If do_clip_rewards is True, rewards are clipped to self.reward_min. + Parameters ---------- states : tensor or list or array @@ -79,9 +85,14 @@ def rewards( The reward of all elements in the batch. """ if log: - return self.proxy2logreward(self(states)) + logrewards = self.proxy2logreward(self(states)) + logrewards[logrewards.isnan()] = self.get_min_reward(log) + return logrewards else: - return self.proxy2reward(self(states)) + rewards = self.proxy2reward(self(states)) + if self.do_clip_rewards: + rewards = torch.clip(rewards, min=self.reward_min, max=None) + return rewards # TODO: consider adding option to clip values # TODO: check that rewards are non-negative @@ -122,6 +133,8 @@ def get_min_reward(self, log: bool = False) -> float: """ Returns the minimum value of the (log) reward, retrieved from self.reward_min. + If self.reward_min is exactly 0, then self.logreward_min is set to -inf. + Parameters ---------- log : bool @@ -134,7 +147,12 @@ def get_min_reward(self, log: bool = False) -> float: The mimnimum (log) reward. 
""" if log: - return np.log(self.reward_min) + if not hasattr(self, "logreward_min"): + if self.reward_min == 0.0: + self.logreward_min = -np.inf + else: + self.logreward_min = np.log(self.reward_min) + return self.logreward_min else: return self.reward_min From 99060938f763931b96c9f94a2989fcc59247f17c Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 21:57:35 -0400 Subject: [PATCH 24/73] Add TODO --- gflownet/utils/batch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 3c9188efc..2ae0b9842 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -855,6 +855,7 @@ def _compute_masks_backward(self): ].get_mask_invalid_actions_backward(state, done) self.masks_backward_available = True + # TODO: better handling of availability of rewards, logrewards, proxy_values. def get_rewards( self, log: bool = False, From 5d4476468294e4693856cc792bda2ad9fa79abc1 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 21:58:39 -0400 Subject: [PATCH 25/73] Clip and replace nans in (log) rewards in proxy2reward and proxy2logreward instead of in rewards() --- gflownet/proxy/base.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index addd09006..7bbd2c9ce 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -67,10 +67,6 @@ def rewards( The rewards are computed by first calling the proxy function, then transforming the proxy values according to the reward function. - If log is True, nan values are set to self.logreward_min. - - If do_clip_rewards is True, rewards are clipped to self.reward_min. - Parameters ---------- states : tensor or list or array @@ -85,21 +81,16 @@ def rewards( The reward of all elements in the batch. """ if log: - logrewards = self.proxy2logreward(self(states)) - logrewards[logrewards.isnan()] = self.get_min_reward(log) - return logrewards + return self.proxy2logreward(self(states)) else: - rewards = self.proxy2reward(self(states)) - if self.do_clip_rewards: - rewards = torch.clip(rewards, min=self.reward_min, max=None) - return rewards + return self.proxy2reward(self(states)) - # TODO: consider adding option to clip values - # TODO: check that rewards are non-negative def proxy2reward(self, proxy_values: TensorType) -> TensorType: """ Transform a tensor of proxy values into rewards. + If do_clip_rewards is True, rewards are clipped to self.reward_min. + Parameters ---------- proxy_values : tensor @@ -110,13 +101,17 @@ def proxy2reward(self, proxy_values: TensorType) -> TensorType: tensor The reward of all elements in the batch. """ - return self._reward_function(proxy_values) + rewards = self._reward_function(proxy_values) + if self.do_clip_rewards: + rewards = torch.clip(rewards, min=self.reward_min, max=None) + return rewards - # TODO: consider adding option to clip values def proxy2logreward(self, proxy_values: TensorType) -> TensorType: """ Transform a tensor of proxy values into log-rewards. + NaN values are set to self.logreward_min. + Parameters ---------- proxy_values : tensor @@ -127,7 +122,9 @@ def proxy2logreward(self, proxy_values: TensorType) -> TensorType: tensor The log-reward of all elements in the batch. 
""" - return self._logreward_function(proxy_values) + logrewards = self._logreward_function(proxy_values) + logrewards[logrewards.isnan()] = self.get_min_reward(log=True) + return logrewards def get_min_reward(self, log: bool = False) -> float: """ From 7ff0206c729c21e0699b2c7c08f6434a22f156b2 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 21:59:36 -0400 Subject: [PATCH 26/73] Adapt gflownet.py --- gflownet/gflownet.py | 95 +++++++++++++++++++++++++++++++------------- 1 file changed, 67 insertions(+), 28 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index dae87edd6..8f361e21f 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -38,6 +38,7 @@ class GFlowNetAgent: def __init__( self, env_maker, + proxy, seed, device, float_precision, @@ -68,6 +69,8 @@ def __init__( # Environment self.env_maker = env_maker self.env = self.env_maker() + # Proxy + self.proxy = proxy # Continuous environments self.continuous = hasattr(self.env, "continuous") and self.env.continuous if self.continuous and optimizer.loss in ["flowmatch", "flowmatching"]: @@ -437,7 +440,9 @@ def sample_batch( # ON-POLICY FORWARD trajectories t0_forward = time.time() envs = [self.env_maker().set_id(idx) for idx in range(n_forward)] - batch_forward = Batch(env=self.env, device=self.device, float_type=self.float) + batch_forward = Batch( + env=self.env, proxy=self.proxy, device=self.device, float_type=self.float + ) while envs: # Sample actions t0_a_envs = time.time() @@ -459,7 +464,9 @@ def sample_batch( # TRAIN BACKWARD trajectories t0_train = time.time() envs = [self.env_maker().set_id(idx) for idx in range(n_train)] - batch_train = Batch(env=self.env, device=self.device, float_type=self.float) + batch_train = Batch( + env=self.env, proxy=self.proxy, device=self.device, float_type=self.float + ) if n_train > 0 and self.buffer.train_pkl is not None: with open(self.buffer.train_pkl, "rb") as f: dict_train = pickle.load(f) @@ -490,7 +497,9 @@ def sample_batch( # REPLAY BACKWARD trajectories t0_replay = time.time() - batch_replay = Batch(env=self.env, device=self.device, float_type=self.float) + batch_replay = Batch( + env=self.env, proxy=self.proxy, device=self.device, float_type=self.float + ) if n_replay > 0 and self.buffer.replay_pkl is not None: with open(self.buffer.replay_pkl, "rb") as f: dict_replay = pickle.load(f) @@ -608,8 +617,7 @@ def flowmatch_loss(self, it, batch): done = batch.get_done() masks_sf = batch.get_masks_forward() parents_a_idx = self.env.actions2indices(parents_actions) - rewards = batch.get_rewards() - assert torch.all(rewards[done] > 0) + logrewards = batch.get_rewards(log=True) # In-flows inflow_logits = torch.full( (states.shape[0], self.env.policy_output_dim), @@ -626,7 +634,7 @@ def flowmatch_loss(self, it, batch): outflow_logits[masks_sf] = -torch.inf outflow = torch.logsumexp(outflow_logits, dim=1) # Loss at terminating nodes - loss_term = (inflow[done] - torch.log(rewards[done])).pow(2).mean() + loss_term = (inflow[done] - logrewards[done]).pow(2).mean() contrib_term = done.eq(1).to(self.float).mean() # Loss at intermediate nodes loss_interm = (inflow[~done] - outflow[~done]).pow(2).mean() @@ -661,14 +669,10 @@ def trajectorybalance_loss(self, it, batch): logprobs_f = self.compute_logprobs_trajectories(batch, backward=False) logprobs_b = self.compute_logprobs_trajectories(batch, backward=True) # Get rewards from batch - rewards = batch.get_terminating_rewards(sort_by="trajectory") + logrewards = batch.get_terminating_rewards(log=True, 
sort_by="trajectory") # Trajectory balance loss - loss = ( - (self.logZ.sum() + logprobs_f - logprobs_b - torch.log(rewards)) - .pow(2) - .mean() - ) + loss = (self.logZ.sum() + logprobs_f - logprobs_b - logrewards).pow(2).mean() return loss, loss, loss def detailedbalance_loss(self, it, batch): @@ -704,7 +708,7 @@ def detailedbalance_loss(self, it, batch): parents = batch.get_parents(policy=False) parents_policy = batch.get_parents(policy=True) done = batch.get_done() - rewards = batch.get_terminating_rewards(sort_by="insertion") + logrewards = batch.get_terminating_rewards(log=True, sort_by="insertion") # Get logprobs masks_f = batch.get_masks_forward(of_parents=True) @@ -720,7 +724,7 @@ def detailedbalance_loss(self, it, batch): # Get logflows logflows_states = self.state_flow(states_policy) - logflows_states[done.eq(1)] = torch.log(rewards) + logflows_states[done.eq(1)] = logrewards # TODO: Optimise by reusing logflows_states and batch.get_parent_indices logflows_parents = self.state_flow(parents_policy) @@ -763,8 +767,8 @@ def forwardlooking_loss(self, it, batch): actions = batch.get_actions() parents = batch.get_parents(policy=False) parents_policy = batch.get_parents(policy=True) - rewards_states = batch.get_rewards(do_non_terminating=True) - rewards_parents = batch.get_rewards_parents() + logrewards_states = batch.get_rewards(log=True, do_non_terminating=True) + logrewards_parents = batch.get_rewards_parents(log=True) done = batch.get_done() # Get logprobs @@ -787,7 +791,7 @@ def forwardlooking_loss(self, it, batch): logflflows_parents = self.state_flow(parents_policy) # Get energies transitions - energies_transitions = torch.log(rewards_parents) - torch.log(rewards_states) + energies_transitions = logrewards_parents - logrewards_states # Forward-looking loss loss_all = ( @@ -910,7 +914,12 @@ def estimate_logprobs_data( ) pbar = tqdm(total=n_states) while init_batch < n_states: - batch = Batch(env=self.env, device=self.device, float_type=self.float) + batch = Batch( + env=self.env, + proxy=self.proxy, + device=self.device, + float_type=self.float, + ) # Create an environment for each data point and trajectory and set the state envs = [] for state_idx in range(init_batch, end_batch): @@ -1024,7 +1033,12 @@ def train(self): self.logger.log_metrics(metrics, use_context=self.use_context, step=it) self.logger.log_summary(summary) t0_iter = time.time() - batch = Batch(env=self.env, device=self.device, float_type=self.float) + batch = Batch( + env=self.env, + proxy=self.proxy, + device=self.device, + float_type=self.float, + ) for j in range(self.sttr): sub_batch, times = self.sample_batch( n_forward=self.batch_size.forward, @@ -1065,11 +1079,26 @@ def train(self): all_losses.append([i.item() for i in losses]) # Buffer t0_buffer = time.time() + # TODO: the current implementation recomputes the proxy values of the + # terminating states in order to store the proxy values in the Buffer. + # Depending on the computational cost of the proxy, this may be very + # inneficient. For example, proxy.rewards() could return the proxy values, + # which could be stored in the Batch. + if it == 0: + print( + "IMPORTANT: The current implementation recomputes the proxy " + "values of the terminating states in order to store the proxy " + "values in the Buffer. Depending on the computational cost of " + "the proxy, this may be very inneficient." 
+ ) states_term = batch.get_terminating_states(sort_by="trajectory") - rewards = batch.get_terminating_rewards(sort_by="trajectory") - actions_trajectories = batch.get_actions_trajectories() - proxy_vals = self.env.reward2proxy(rewards).tolist() + states_proxy_term = batch.get_terminating_states( + proxy=True, sort_by="trajectory" + ) + proxy_vals = self.proxy(states_proxy_term) + rewards = self.proxy.proxy2reward(proxy_vals) rewards = rewards.tolist() + actions_trajectories = batch.get_actions_trajectories() self.buffer.add(states_term, actions_trajectories, rewards, proxy_vals, it) self.buffer.add( states_term, @@ -1180,7 +1209,7 @@ def test(self, **plot_kwargs): ) mean_logprobs_std = logprobs_std.mean().item() mean_probs_std = probs_std.mean().item() - rewards_x_tt = self.env.reward_batch(x_tt) + rewards_x_tt = self.proxy.rewards(self.env.states2proxy(x_tt)) corr_prob_traj_rewards = np.corrcoef( np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt )[0, 1] @@ -1200,9 +1229,9 @@ def test(self, **plot_kwargs): if "density_true" in dict_tt: density_true = dict_tt["density_true"] else: - rewards = self.env.reward_batch(x_tt) - z_true = rewards.sum() - density_true = rewards / z_true + # This is hacky but it is re-done anyway in the Evaluator class. + z_true = rewards_x_tt.sum().item() + density_true = rewards_x_tt.numpy() / z_true with open(self.buffer.test_pkl, "wb") as f: dict_tt["density_true"] = density_true pickle.dump(dict_tt, f) @@ -1337,7 +1366,12 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): print() if not gfn_states: # sample states from the current gfn - batch = Batch(env=self.env, device=self.device, float_type=self.float) + batch = Batch( + env=self.env, + proxy=self.proxy, + device=self.device, + float_type=self.float, + ) self.random_action_prob = 0 t = time.time() print("Sampling from GFN...", end="\r") @@ -1360,7 +1394,12 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): if do_random: # sample random states from uniform actions if not random_states: - batch = Batch(env=self.env, device=self.device, float_type=self.float) + batch = Batch( + env=self.env, + proxy=self.proxy, + device=self.device, + float_type=self.float, + ) self.random_action_prob = 1.0 print("[test_top_k] Sampling at random...", end="\r") for b in batch_with_rest( From 84dbf463fcd7dee8c2d1671488b68cf53a9b6ae4 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 21:59:44 -0400 Subject: [PATCH 27/73] Adapt main.py --- main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 506752c66..a1c2720bf 100644 --- a/main.py +++ b/main.py @@ -82,6 +82,7 @@ def main(config): device=config.device, float_precision=config.float_precision, env_maker=env_maker, + proxy=proxy, forward_policy=forward_policy, backward_policy=backward_policy, state_flow=state_flow, @@ -96,7 +97,7 @@ def main(config): if config.n_samples > 0 and config.n_samples <= 1e5: batch, times = gflownet.sample_batch(n_forward=config.n_samples, train=False) x_sampled = batch.get_terminating_states(proxy=True) - energies = env.proxy(x_sampled) + energies = proxy(x_sampled) x_sampled = batch.get_terminating_states() df = pd.DataFrame( { From 51939bb4f2667e71db6713cf400e56709e69bc5f Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 22:00:34 -0400 Subject: [PATCH 28/73] Replace torch.equal by isclose in Batch tests of rewards --- tests/gflownet/utils/test_batch.py | 56 +++++++++++++++++++----------- 1 file changed, 35 insertions(+), 21 
deletions(-) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 2b3206fba..6d59d68e8 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -451,9 +451,11 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) - assert torch.equal( - rewards_batch, - tfloat(rewards, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ) ), (rewards, rewards_batch) # Check terminating states (sorted by trajectory) states_term_batch = batch.get_terminating_states(sort_by="traj") @@ -612,9 +614,11 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) - assert torch.equal( - rewards_batch, - tfloat(rewards, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ) ), (rewards, rewards_batch) # Check terminating states (sorted by trajectory) states_term_batch = batch.get_terminating_states(sort_by="traj") @@ -839,9 +843,11 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) - assert torch.equal( - rewards_batch, - tfloat(rewards, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ) ), (rewards, rewards_batch) # Check terminating states (sorted by trajectory) states_term_batch = batch.get_terminating_states(sort_by="traj") @@ -1073,9 +1079,11 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) - assert torch.equal( - rewards_batch, - tfloat(rewards, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ) ), (rewards, rewards_batch) # Check terminating states (sorted by trajectory) states_term_batch = batch.get_terminating_states(sort_by="traj") @@ -1307,9 +1315,11 @@ def test__get_rewards__single_env_returns_expected_non_terminating( rewards.append(proxy.rewards(env.state2proxy())[0]) rewards_batch = batch.get_rewards(do_non_terminating=True) rewards = torch.stack(rewards) - assert torch.equal( - rewards_batch, - tfloat(rewards, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ) ), (rewards, rewards_batch) @@ -1393,9 +1403,11 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( rewards_batch = batch.get_rewards(do_non_terminating=True) rewards = torch.stack(rewards) - assert torch.equal( - rewards_batch, - tfloat(rewards, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ) ), (rewards, rewards_batch) assert ~torch.any( torch.isclose(rewards_batch, torch.zeros_like(rewards_batch)) @@ -1448,9 +1460,11 @@ def test__get_logrewards_multiple_env_returns_expected_non_zero_non_terminating( logrewards_batch = 
batch.get_rewards(log=True, do_non_terminating=True) logrewards = torch.stack(logrewards) - assert torch.equal( - logrewards_batch, - tfloat(logrewards, device=batch.device, float_type=batch.float), + assert torch.all( + torch.isclose( + logrewards_batch, + tfloat(logrewards, device=batch.device, float_type=batch.float), + ) ), (logrewards, logrewards_batch) assert ~torch.any( torch.isclose(logrewards_batch, torch.zeros_like(logrewards_batch)) From 8a7f6ab199ccc540a3eca65a88b96c92a62ef927 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 1 Apr 2024 22:04:08 -0400 Subject: [PATCH 29/73] Adapt env common tests --- tests/gflownet/envs/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index 53d513de5..7e234e940 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -467,7 +467,6 @@ def test__gflownet_minimal_runs(self, n_repeat=1): float_precision=config.float_precision, base=forward_policy, ) - self.env.proxy = proxy # Set proxy in env. config.env.buffer.train = None # No buffers config.env.buffer.test = None config.env.buffer.replay_capacity = 0 # No replay buffer @@ -478,6 +477,7 @@ def test__gflownet_minimal_runs(self, n_repeat=1): device=config.device, float_precision=config.float_precision, env_maker=self.env.__class__, + proxy=proxy, forward_policy=forward_policy, backward_policy=backward_policy, buffer=config.env.buffer, From f55e62b563977489726f5d0294cac9ad75cc86f9 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 22:51:32 -0400 Subject: [PATCH 30/73] Remove reward and proxy stuff from base env. --- gflownet/envs/base.py | 111 ------------------------------------------ 1 file changed, 111 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 3a9f6579c..b6d349589 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -34,15 +34,11 @@ def __init__( device: str = "cpu", float_precision: int = 32, env_id: Union[int, str] = "env", - reward_min: float = 1e-8, - reward_beta: float = 1.0, reward_norm: float = 1.0, reward_norm_std_mult: float = 0.0, - reward_func: str = "identity", energies_stats: List[int] = None, denorm_proxy: bool = False, proxy=None, - proxy_state_format: str = "oracle", fixed_distr_params: Optional[dict] = None, random_distr_params: Optional[dict] = None, skip_mask_check: bool = False, @@ -61,21 +57,11 @@ def __init__( # Float precision self.float = set_float_precision(float_precision) # Reward settings - self.min_reward = reward_min - assert self.min_reward > 0 - self.reward_beta = reward_beta - assert self.reward_beta > 0 self.reward_norm = reward_norm assert self.reward_norm > 0 self.reward_norm_std_mult = reward_norm_std_mult - self.reward_func = reward_func self.energies_stats = energies_stats self.denorm_proxy = denorm_proxy - # Proxy - self.proxy = proxy - self.setup_proxy() - self.proxy_factor = -1.0 - self.proxy_state_format = proxy_state_format # Flag to skip checking if action is valid (computing mask) before step self.skip_mask_check = skip_mask_check # Log SoftMax function @@ -767,99 +753,6 @@ def traj2readable(self, traj=None): """ return str(traj).replace("(", "[").replace(")", "]").replace(",", "") - def reward(self, state=None, done=None, do_non_terminating=False): - """ - Computes the reward of a state - """ - state = self._get_state(state) - done = self._get_done(done) - if not done and not do_non_terminating: - return tfloat(0.0, float_type=self.float, device=self.device) - return 
self.proxy2reward(self.proxy(self.state2proxy(state))[0]) - - # TODO: cleanup - def reward_batch(self, states: List[List], done=None): - """ - Computes the rewards of a batch of states, given a list of states and 'dones' - """ - if done is None: - done = np.ones(len(states), dtype=bool) - states_proxy = self.states2proxy(states) - if isinstance(states_proxy, torch.Tensor): - states_proxy = states_proxy[list(done), :] - elif isinstance(states_proxy, list): - states_proxy = [states_proxy[i] for i in range(len(done)) if done[i]] - rewards = np.zeros(len(done)) - if len(states_proxy) > 0: - rewards[list(done)] = self.proxy2reward(self.proxy(states_proxy)).tolist() - return rewards - - def proxy2reward(self, proxy_vals): - """ - Prepares the output of a proxy for GFlowNet: the inputs proxy_vals is expected - to be a negative value (energy), unless self.denorm_proxy is True. If the - latter, the proxy values are first de-normalized according to the mean and - standard deviation in self.energies_stats. The output of the function is a - strictly positive reward - provided self.reward_norm and self.reward_beta are - positive - and larger than self.min_reward. - """ - if self.denorm_proxy: - # TODO: do with torch - # TODO: review - proxy_vals = ( - proxy_vals * (self.energies_stats[1] - self.energies_stats[0]) - + self.energies_stats[0] - ) - # proxy_vals = proxy_vals * self.energies_stats[3] + self.energies_stats[2] - if self.reward_func == "power": - return torch.clamp( - (self.proxy_factor * proxy_vals / self.reward_norm) ** self.reward_beta, - min=self.min_reward, - max=None, - ) - elif self.reward_func == "boltzmann": - return torch.clamp( - torch.exp(self.proxy_factor * self.reward_beta * proxy_vals), - min=self.min_reward, - max=None, - ) - elif self.reward_func == "identity": - return torch.clamp( - self.proxy_factor * proxy_vals, - min=self.min_reward, - max=None, - ) - elif self.reward_func == "shift": - return torch.clamp( - self.proxy_factor * proxy_vals + self.reward_beta, - min=self.min_reward, - max=None, - ) - else: - raise NotImplementedError - - def reward2proxy(self, reward): - """ - Converts a "GFlowNet reward" into a (negative) energy or values as returned by - a proxy. - """ - if self.reward_func == "power": - return self.proxy_factor * torch.exp( - ( - torch.log(reward) - + self.reward_beta * torch.log(torch.as_tensor(self.reward_norm)) - ) - / self.reward_beta - ) - elif self.reward_func == "boltzmann": - return self.proxy_factor * torch.log(reward) / self.reward_beta - elif self.reward_func == "identity": - return self.proxy_factor * reward - elif self.reward_func == "shift": - return self.proxy_factor * (reward - self.reward_beta) - else: - raise NotImplementedError - def reset(self, env_id: Union[int, str] = None): """ Resets the environment. 
@@ -993,10 +886,6 @@ def get_trajectories( ) return traj_list, traj_actions_list - def setup_proxy(self): - if self.proxy: - self.proxy.setup(self) - @torch.no_grad() def compute_train_energy_proxy_and_rewards(self): """ From 729ed6925545228010f8239999fc2e68ccdf768b Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 22:52:19 -0400 Subject: [PATCH 31/73] Add TODOs --- gflownet/envs/cube.py | 4 +++- gflownet/envs/htorus.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index f83e62383..7a7d6fd0d 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -1418,7 +1418,9 @@ def get_uniform_terminating_states( states = rng.uniform(low=kappa, high=1.0 - kappa, size=(n_states, self.n_dim)) return states.tolist() - # TODO: make generic for all environments + # TODO: make generic for all environments, or rather elsewhere + # TODO: fix because it currently uses reward_batch, proxy2reward, etc.. For + # example, this could be done by passing the proxy as a parameter. def sample_from_reward( self, n_samples: int, epsilon=1e-4 ) -> TensorType["n_samples", "state_dim"]: diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index 005d380a7..f15b1cdde 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -516,7 +516,9 @@ def get_uniform_terminating_states( states = np.concatenate((angles, np.ones((n_states, 1))), axis=1) return states.tolist() - # TODO: make generic for all environments + # TODO: make generic for all environments, or rather elsewhere + # TODO: fix because it currently uses reward_batch, proxy2reward, etc.. For + # example, this could be done by passing the proxy as a parameter. def sample_from_reward( self, n_samples: int, epsilon=1e-4 ) -> TensorType["n_samples", "state_dim"]: From f71d57afaa521702e4b0006d768fa3b82764e3a8 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 22:52:41 -0400 Subject: [PATCH 32/73] Remove old stuff --- gflownet/envs/grid.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 600f54a23..375488fb8 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -79,10 +79,6 @@ def __init__( self.eos = tuple([0 for _ in range(self.n_dim)]) # Base class init super().__init__(**kwargs) - # Proxy format - # TODO: assess if really needed - if self.proxy_state_format == "ohe": - self.states2proxy = self.states2policy def get_action_space(self): """ From f6bfaa1e33145627bba79531a94eb3ddf5e7e5a4 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 22:52:58 -0400 Subject: [PATCH 33/73] Add proxy to Buffer --- gflownet/gflownet.py | 6 +++++- gflownet/utils/buffer.py | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 8f361e21f..385025a9d 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -107,7 +107,11 @@ def __init__( self.replay_sampling = replay_sampling self.train_sampling = train_sampling self.buffer = Buffer( - **buffer, env=self.env, make_train_test=not sample_only, logger=logger + **buffer, + env=self.env, + proxy=self.proxy, + make_train_test=not sample_only, + logger=logger, ) # Train set statistics and reward normalization constant if self.buffer.train is not None: diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index 75edc4d44..4615d840f 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -19,6 +19,7 @@ class Buffer: def 
__init__( self, env, + proxy, make_train_test=False, replay_capacity=0, output_csv=None, @@ -30,6 +31,7 @@ def __init__( ): self.logger = logger self.env = env + self.proxy = proxy self.replay_capacity = replay_capacity self.main = pd.DataFrame(columns=["state", "traj", "reward", "energy", "iter"]) self.replay = pd.DataFrame( @@ -252,7 +254,7 @@ def make_data_set(self, config): samples = self.env.get_random_terminating_states(config.n) else: return None, None - energies = self.env.proxy(self.env.states2proxy(samples)).tolist() + energies = self.proxy(self.env.states2proxy(samples)).tolist() df = pd.DataFrame( { "samples": [self.env.state2readable(s) for s in samples], From 1020d6c5f3802a040df5d619b272df27adc0ce85 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 23:14:44 -0400 Subject: [PATCH 34/73] Fixes --- gflownet/gflownet.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 385025a9d..774922f4f 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -71,6 +71,7 @@ def __init__( self.env = self.env_maker() # Proxy self.proxy = proxy + self.proxy.setup(self.env) # Continuous environments self.continuous = hasattr(self.env, "continuous") and self.env.continuous if self.continuous and optimizer.loss in ["flowmatch", "flowmatching"]: @@ -1102,6 +1103,7 @@ def train(self): proxy_vals = self.proxy(states_proxy_term) rewards = self.proxy.proxy2reward(proxy_vals) rewards = rewards.tolist() + proxy_vals = proxy_vals.tolist() actions_trajectories = batch.get_actions_trajectories() self.buffer.add(states_term, actions_trajectories, rewards, proxy_vals, it) self.buffer.add( From 065fa4f0f0bded0bd1c1cc7a115e20fa09d6f4df Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 23:14:58 -0400 Subject: [PATCH 35/73] Adapt sanity checks config --- mila/dev/sanity_check_runs.yaml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mila/dev/sanity_check_runs.yaml b/mila/dev/sanity_check_runs.yaml index 443c235ed..94ae95b7f 100644 --- a/mila/dev/sanity_check_runs.yaml +++ b/mila/dev/sanity_check_runs.yaml @@ -67,7 +67,8 @@ jobs: height: 10 pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True - reward_func: boltzmann + proxy: + reward_func: exponential gflownet: flowmatch proxy: tetris - slurm: @@ -79,7 +80,8 @@ jobs: height: 10 pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True - reward_func: boltzmann + proxy: + reward_func: exponential gflownet: trajectorybalance proxy: tetris - slurm: @@ -91,7 +93,8 @@ jobs: height: 10 pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True - reward_func: boltzmann + proxy: + reward_func: exponential gflownet: forwardlooking proxy: tetris # Ctorus From 68be0cb652956900f0ebbfce2e2c88c01f92b23b Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 1 Apr 2024 23:36:35 -0400 Subject: [PATCH 36/73] Fix --- mila/dev/sanity_check_runs.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mila/dev/sanity_check_runs.yaml b/mila/dev/sanity_check_runs.yaml index 94ae95b7f..6c70471e4 100644 --- a/mila/dev/sanity_check_runs.yaml +++ b/mila/dev/sanity_check_runs.yaml @@ -68,7 +68,7 @@ jobs: pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True proxy: - reward_func: exponential + reward_function: exponential gflownet: flowmatch proxy: tetris - slurm: @@ -81,7 +81,7 @@ jobs: pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True proxy: - reward_func: exponential + reward_function: exponential gflownet: 
trajectorybalance proxy: tetris - slurm: @@ -94,7 +94,7 @@ jobs: pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True proxy: - reward_func: exponential + reward_function: exponential gflownet: forwardlooking proxy: tetris # Ctorus From 0553b0277c8fc5a417271f559ee0026331268735 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 3 Apr 2024 14:34:41 -0400 Subject: [PATCH 37/73] Do not pass proxy to env anymore --- gflownet/envs/base.py | 3 --- main.py | 1 - 2 files changed, 4 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index b6d349589..83c51a809 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -37,8 +37,6 @@ def __init__( reward_norm: float = 1.0, reward_norm_std_mult: float = 0.0, energies_stats: List[int] = None, - denorm_proxy: bool = False, - proxy=None, fixed_distr_params: Optional[dict] = None, random_distr_params: Optional[dict] = None, skip_mask_check: bool = False, @@ -61,7 +59,6 @@ def __init__( assert self.reward_norm > 0 self.reward_norm_std_mult = reward_norm_std_mult self.energies_stats = energies_stats - self.denorm_proxy = denorm_proxy # Flag to skip checking if action is valid (computing mask) before step self.skip_mask_check = skip_mask_check # Log SoftMax function diff --git a/main.py b/main.py index a1c2720bf..99096aed4 100644 --- a/main.py +++ b/main.py @@ -42,7 +42,6 @@ def main(config): # https://hydra.cc/docs/advanced/instantiate_objects/overview/#partial-instantiation env_maker = hydra.utils.instantiate( config.env, - proxy=proxy, device=config.device, float_precision=config.float_precision, _partial_=True, From 146cace2db25a0b90ce13819b0b71a553c7d9ba5 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 3 Apr 2024 14:36:12 -0400 Subject: [PATCH 38/73] Remove old variables from base env config --- config/env/base.yaml | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/config/env/base.yaml b/config/env/base.yaml index 04f4105c7..ace90b1a7 100644 --- a/config/env/base.yaml +++ b/config/env/base.yaml @@ -1,19 +1,9 @@ _target_: gflownet.envs.base.GFlowNetEnv -# Reward function: power or boltzmann -# boltzmann: exp(-1.0 * reward_beta * proxy) -# power: (-1.0 * proxy / reward_norm) ** self.reward_beta -# identity: proxy -reward_func: identity -# Minimum reward -reward_min: 1e-8 -# Beta parameter of the reward function -reward_beta: 1.0 # Reward normalization for "power" reward function reward_norm: 1.0 # If > 0, reward_norm = reward_norm_std_mult * std(energies) reward_norm_std_mult: 0.0 -proxy_state_format: oracle # Check if action valid with mask before step skip_mask_check: False # Whether the environment has conditioning variables From 9a62803872001e426bfdc4bc2c879257de897a6e Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 1 May 2024 17:52:36 -0400 Subject: [PATCH 39/73] Fix tests minimal runs --- gflownet/gflownet.py | 3 ++- tests/gflownet/envs/common.py | 8 ++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 774922f4f..cc52d7acc 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -9,6 +9,7 @@ import pickle import time from collections import defaultdict +from functools import partial from pathlib import Path from typing import List, Optional, Tuple, Union @@ -37,7 +38,7 @@ class GFlowNetAgent: def __init__( self, - env_maker, + env_maker: partial, proxy, seed, device, diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index 7e234e940..108bb6a49 100644 --- 
a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -9,6 +9,7 @@ import inspect import warnings +from functools import partial import hydra import numpy as np @@ -16,7 +17,6 @@ import torch import yaml from hydra import compose, initialize -from omegaconf import OmegaConf from gflownet.utils.common import copy, tbool, tfloat from gflownet.utils.policy import parse_policy_config @@ -444,7 +444,10 @@ def test__gflownet_minimal_runs(self, n_repeat=1): ): config = compose(config_name="tests") + # Logger logger = hydra.utils.instantiate(config.logger, config, _recursive_=False) + + # Proxy proxy = hydra.utils.instantiate( config.proxy, device=config.device, @@ -472,11 +475,12 @@ def test__gflownet_minimal_runs(self, n_repeat=1): config.env.buffer.replay_capacity = 0 # No replay buffer config.gflownet.optimizer.n_train_steps = 1 # Set 1 training step + # GFlowNet agent gflownet = hydra.utils.instantiate( config.gflownet, device=config.device, float_precision=config.float_precision, - env_maker=self.env.__class__, + env_maker=partial(self.env.copy), proxy=proxy, forward_policy=forward_policy, backward_policy=backward_policy, From af904c3fa61f0721bc93999226d9d73530b8621c Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Fri, 3 May 2024 00:40:35 -0400 Subject: [PATCH 40/73] WIP: reorganise plotting methods --- gflownet/envs/cube.py | 49 +++++++---------------------- gflownet/gflownet.py | 71 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 47 deletions(-) diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index 7a7d6fd0d..d2d470d6e 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -15,7 +15,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv -from gflownet.utils.common import copy, tbool, tfloat +from gflownet.utils.common import copy, tbool, tfloat, torch2np class CubeBase(GFlowNetEnv, ABC): @@ -1418,49 +1418,24 @@ def get_uniform_terminating_states( states = rng.uniform(low=kappa, high=1.0 - kappa, size=(n_states, self.n_dim)) return states.tolist() - # TODO: make generic for all environments, or rather elsewhere - # TODO: fix because it currently uses reward_batch, proxy2reward, etc.. For - # example, this could be done by passing the proxy as a parameter. - def sample_from_reward( - self, n_samples: int, epsilon=1e-4 - ) -> TensorType["n_samples", "state_dim"]: - """ - Rejection sampling with proposal the uniform distribution in [0, 1]^n_dim. - - Returns a tensor in GFloNet (state) format. 
- """ - samples_final = [] - max_reward = self.proxy2reward(self.proxy.min) - while len(samples_final) < n_samples: - samples_uniform = self.states2proxy( - self.get_uniform_terminating_states(n_samples) - ) - rewards = self.proxy2reward(self.proxy(samples_uniform)) - mask = ( - torch.rand(n_samples, dtype=self.float, device=self.device) - * (max_reward + epsilon) - < rewards - ) - samples_accepted = samples_uniform[mask] - samples_final.extend(samples_accepted[-(n_samples - len(samples_final)) :]) - return torch.vstack(samples_final) - - # TODO: make generic for all envs def fit_kde(self, samples, kernel="gaussian", bandwidth=0.1): + samples = torch2np(self.states2proxy(samples)) return KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(samples) - def plot_reward_samples( + def plot_samples_reward( self, - samples, - alpha=0.5, - cell_min=-1.0, - cell_max=1.0, - dpi=150, - max_samples=500, + samples: List[List], + rewards: TensorType["batch_size"], + alpha: float = 0.5, + cell_min: float = -1.0, + cell_max: float = 1.0, + dpi: int = 150, + max_samples: int = 500, **kwargs, ): if self.n_dim != 2: return None + samples = torch2np(self.states2proxy(samples)) # Sample a grid of points in the state space and obtain the rewards x = np.linspace(cell_min, cell_max, 201) y = np.linspace(cell_min, cell_max, 201) @@ -1469,7 +1444,6 @@ def plot_reward_samples( states_mesh = torch.tensor( X.reshape(-1, 2), device=self.device, dtype=self.float ) - rewards = self.proxy2reward(self.proxy(states_mesh)) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) @@ -1488,7 +1462,6 @@ def plot_reward_samples( plt.tight_layout() return fig - # TODO: make generic for all envs def plot_kde( self, kde, diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index cc52d7acc..5a6eab4af 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -11,16 +11,18 @@ from collections import defaultdict from functools import partial from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn from scipy.special import logsumexp from torch.distributions import Bernoulli +from torchtyping import TensorType from tqdm import tqdm from gflownet.envs.base import GFlowNetEnv +from gflownet.proxy.base import Proxy from gflownet.utils.batch import Batch from gflownet.utils.buffer import Buffer from gflownet.utils.common import ( @@ -1254,7 +1256,6 @@ def test(self, **plot_kwargs): assert batch.is_valid() x_sampled = batch.get_terminating_states() # TODO make it work with conditional env - x_sampled = torch2np(self.env.states2proxy(x_sampled)) x_tt = torch2np(self.env.states2proxy(x_tt)) kde_pred = self.env.fit_kde( x_sampled, @@ -1266,10 +1267,7 @@ def test(self, **plot_kwargs): kde_true = dict_tt["kde_true"] else: # Sample from reward via rejection sampling - x_from_reward = self.env.sample_from_reward( - n_samples=self.logger.test.n - ) - x_from_reward = torch2np(self.env.states2proxy(x_from_reward)) + x_from_reward = self.sample_from_reward(n_samples=self.logger.test.n) # Fit KDE with samples from reward kde_true = self.env.fit_kde( x_from_reward, @@ -1317,9 +1315,11 @@ def test(self, **plot_kwargs): jsd += 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens)) # Plots - - if hasattr(self.env, "plot_reward_samples"): - fig_reward_samples = self.env.plot_reward_samples(x_sampled, **plot_kwargs) + if hasattr(self.env, "plot_samples_reward"): + rewards = 
self.proxy.rewards(self.env.states2proxy(x_sampled)) + fig_reward_samples = self.env.plot_samples_reward( + x_sampled, rewards, **plot_kwargs + ) else: fig_reward_samples = None if hasattr(self.env, "plot_kde"): @@ -1457,6 +1457,59 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): print() return metrics, figs, fig_names, summary + # TODO: implement other proposal distributions + # TODO: rethink whether it is needed to convert to reward + def sample_from_reward( + self, + n_samples: int, + proposal_distribution: str = "uniform", + epsilon=1e-4, + ) -> Union[List, Dict, TensorType["n_samples", "state_dim"]]: + """ + Rejection sampling with proposal the uniform distribution defined over the + sample space. + + Returns a tensor in GFloNet (state) format. + + Parameters + ---------- + n_samples : int + The number of samples to draw from the reward distribution. + proposal_distribution : str + Identifier of the proposal distribution. Currently only `uniform` is + implemented. + epsilon : float + Small epsilon parameter for rejection sampling. + + Returns + ------- + samples_final : list + The list of samples drawn from the reward distribution. + """ + samples_final = [] + max_reward = self.proxy.proxy2reward(self.proxy.min) + while len(samples_final) < n_samples: + if proposal_distribution == "uniform": + # TODO: sample only the remaining number of samples + samples_uniform = self.env.get_uniform_terminating_states(n_samples) + else: + raise NotImplementedError("The proposal distribution must be uniform") + rewards = self.proxy.proxy2reward( + self.proxy(self.env.states2proxy(samples_uniform)) + ) + indices_accept = ( + ( + torch.rand(n_samples, dtype=self.float, device=self.device) + * (max_reward + epsilon) + < rewards + ) + .flatten() + .tolist() + ) + samples_accepted = [samples_uniform[idx] for idx in indices_accept] + samples_final.extend(samples_accepted[-(n_samples - len(samples_final)) :]) + return samples_final + def get_log_corr(self, times): data_logq = [] times.update( From be2be7911f575e6dfb27c92df52671ed717b82da Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Fri, 3 May 2024 12:10:34 -0400 Subject: [PATCH 41/73] Adapt plot_reward_samples in GFlowNet agent and Cube --- config/logger/base.yaml | 2 ++ gflownet/envs/cube.py | 76 +++++++++++++++++++++++++++++++---------- gflownet/gflownet.py | 23 ++++++++++--- 3 files changed, 78 insertions(+), 23 deletions(-) diff --git a/config/logger/base.yaml b/config/logger/base.yaml index a50b04479..7d3a6cf2e 100644 --- a/config/logger/base.yaml +++ b/config/logger/base.yaml @@ -26,6 +26,8 @@ test: logprobs_bootstrap_size: 10000 # Maximum number of test data points to compute log likelihood probs. 
max_data_logprobs: 1e5 + # Number of points tor obtain a grid to estimate the reward density + n_grid: 40401 # Oracle metrics oracle: period: 100000 diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index d2d470d6e..dd0bc9b70 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -17,6 +17,9 @@ from gflownet.envs.base import GFlowNetEnv from gflownet.utils.common import copy, tbool, tfloat, torch2np +CELL_MIN = -1.0 +CELL_MAX = 1.0 + class CubeBase(GFlowNetEnv, ABC): """ @@ -136,7 +139,7 @@ def states2proxy( ) -> TensorType["batch", "state_dim"]: """ Prepares a batch of states in "environment format" for a proxy: clips the - states into [0, 1] and maps them to [-1.0, 1.0] + states into [0, 1] and maps them to [CELL_MIN, CELL_MAX] Args ---- @@ -149,7 +152,7 @@ def states2proxy( A tensor containing all the states in the batch. """ states = tfloat(states, device=self.device, float_type=self.float) - return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 + return 2.0 * torch.clip(states, min=0.0, max=CELL_MAX) - CELL_MAX def states2policy( self, states: Union[List, TensorType["batch", "state_dim"]] @@ -1418,13 +1421,33 @@ def get_uniform_terminating_states( states = rng.uniform(low=kappa, high=1.0 - kappa, size=(n_states, self.n_dim)) return states.tolist() - def fit_kde(self, samples, kernel="gaussian", bandwidth=0.1): - samples = torch2np(self.states2proxy(samples)) + def fit_kde( + self, + samples: TensorType["batch_size", "state_proxy_dim"], + kernel: str = "gaussian", + bandwidth: float = 0.1, + ): + r""" + Fits a Kernel Density Estimator on a batch of samples. + + Parameters + ---------- + samples : tensor + A batch of samples in proxy format. + kernel : str + An identifier of the kernel to use for the density estimation. It must be a + valid kernel for the scikit-learn method + :py:meth:`sklearn.neighbors.KernelDensity`. + bandwidth : float + The bandwidth of the kernel. + """ + samples = torch2np(samples) return KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(samples) - def plot_samples_reward( + def plot_reward_samples( self, - samples: List[List], + samples: TensorType["batch_size", "state_proxy_dim"], + samples_reward: TensorType["batch_size", "state_proxy_dim"], rewards: TensorType["batch_size"], alpha: float = 0.5, cell_min: float = -1.0, @@ -1433,22 +1456,39 @@ def plot_samples_reward( max_samples: int = 500, **kwargs, ): + """ + Plots the reward contour alongside a batch of samples. + + Parameters + ---------- + samples : tensor + A batch of samples from the GFlowNet policy in proxy format. These samples + will be plotted on top of the reward density. + samples_reward : tensor + A batch of samples containing a grid over the sample space, from which the + reward has been obtained. These samples are used to plot the contour of + reward density. + rewards : tensor + The reward of samples_reward. + alpha : float + Transparency of the reward contour. + dpi : int + Dots per inch, indicating the resolution of the plot. 
+ """ if self.n_dim != 2: return None - samples = torch2np(self.states2proxy(samples)) - # Sample a grid of points in the state space and obtain the rewards - x = np.linspace(cell_min, cell_max, 201) - y = np.linspace(cell_min, cell_max, 201) + samples = torch2np(samples) + samples_reward = torch2np(samples_reward) + rewards = torch2np(rewards) + # Create mesh from samples_reward + x = np.unique(samples_reward[:, 0]) + y = np.unique(samples_reward[:, 1]) xx, yy = np.meshgrid(x, y) - X = np.stack([xx, yy], axis=-1) - states_mesh = torch.tensor( - X.reshape(-1, 2), device=self.device, dtype=self.float - ) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) # Plot reward contour - h = ax.contourf(xx, yy, rewards.reshape(xx.shape).cpu().numpy(), alpha=alpha) + h = ax.contourf(xx, yy, rewards.reshape(xx.shape), alpha=alpha) ax.axis("scaled") fig.colorbar(h, ax=ax) # Plot samples @@ -1456,9 +1496,9 @@ def plot_samples_reward( ax.scatter(samples[random_indices, 0], samples[random_indices, 1], alpha=alpha) # Figure settings ax.grid() - padding = 0.05 * (cell_max - cell_min) - ax.set_xlim([cell_min - padding, cell_max + padding]) - ax.set_ylim([cell_min - padding, cell_max + padding]) + padding = 0.05 * (CELL_MAX - CELL_MIN) + ax.set_xlim([CELL_MIN - padding, CELL_MAX + padding]) + ax.set_ylim([CELL_MIN - padding, CELL_MAX + padding]) plt.tight_layout() return fig diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 5a6eab4af..2569bfeee 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1254,7 +1254,7 @@ def test(self, **plot_kwargs): elif self.continuous and hasattr(self.env, "fit_kde"): batch, _ = self.sample_batch(n_forward=self.logger.test.n, train=False) assert batch.is_valid() - x_sampled = batch.get_terminating_states() + x_sampled = batch.get_terminating_states(proxy=True) # TODO make it work with conditional env x_tt = torch2np(self.env.states2proxy(x_tt)) kde_pred = self.env.fit_kde( @@ -1315,10 +1315,23 @@ def test(self, **plot_kwargs): jsd += 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens)) # Plots - if hasattr(self.env, "plot_samples_reward"): - rewards = self.proxy.rewards(self.env.states2proxy(x_sampled)) - fig_reward_samples = self.env.plot_samples_reward( - x_sampled, rewards, **plot_kwargs + if hasattr(self.env, "plot_reward_samples"): + if hasattr(self.env, "get_all_terminating_states"): + samples_reward = self.env.get_all_terminating_states() + elif hasattr(self.env, "get_grid_terminating_states"): + samples_reward = self.env.get_grid_terminating_states( + self.logger.test.n_grid + ) + else: + raise NotImplementedError( + "In order to plot the reward density and the samples, the " + "environment must implement either get_all_terminating_states() " + "or get_grid_terminating_states()" + ) + samples_reward = self.env.states2proxy(samples_reward) + rewards = self.proxy.rewards(samples_reward) + fig_reward_samples = self.env.plot_reward_samples( + x_sampled, samples_reward, rewards, **plot_kwargs ) else: fig_reward_samples = None From 0d4db56c246cb804fbc38d515bac4012e1141cfb Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Fri, 3 May 2024 15:24:56 -0400 Subject: [PATCH 42/73] Fix issue with inf logprobs --- gflownet/proxy/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 7bbd2c9ce..81648eef2 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -124,6 +124,7 @@ def proxy2logreward(self, proxy_values: TensorType) -> TensorType: """ 
logrewards = self._logreward_function(proxy_values) logrewards[logrewards.isnan()] = self.get_min_reward(log=True) + logrewards[~logrewards.isfinite()] = self.get_min_reward(log=True) return logrewards def get_min_reward(self, log: bool = False) -> float: @@ -146,7 +147,7 @@ def get_min_reward(self, log: bool = False) -> float: if log: if not hasattr(self, "logreward_min"): if self.reward_min == 0.0: - self.logreward_min = -np.inf + self.logreward_min = -1e3 else: self.logreward_min = np.log(self.reward_min) return self.logreward_min From 1615a5ab552a63b21fe517d15d121e3c7ea66460 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sat, 4 May 2024 20:52:01 -0400 Subject: [PATCH 43/73] Adapt plot_kde --- gflownet/envs/cube.py | 37 +++++++++++++++++++-------- gflownet/gflownet.py | 59 ++++++++++++++++++++++++++++++------------- 2 files changed, 68 insertions(+), 28 deletions(-) diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index dd0bc9b70..46b2395d7 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -1450,8 +1450,6 @@ def plot_reward_samples( samples_reward: TensorType["batch_size", "state_proxy_dim"], rewards: TensorType["batch_size"], alpha: float = 0.5, - cell_min: float = -1.0, - cell_max: float = 1.0, dpi: int = 150, max_samples: int = 500, **kwargs, @@ -1474,6 +1472,8 @@ def plot_reward_samples( Transparency of the reward contour. dpi : int Dots per inch, indicating the resolution of the plot. + max_samples : int + Maximum of number of samples to include in the plot. """ if self.n_dim != 2: return None @@ -1504,22 +1504,37 @@ def plot_reward_samples( def plot_kde( self, + samples: TensorType["batch_size", "state_proxy_dim"], kde, - alpha=0.5, - cell_min=-1.0, - cell_max=1.0, + alpha: float = 0.5, dpi=150, - colorbar=True, + colorbar: bool = True, **kwargs, ): + """ + Plots the density previously estimated from a batch of samples via KDE over the + entire sample space. + + Parameters + ---------- + samples : tensor + A batch of samples containing a grid over the sample space. These samples + are used to plot the contour of the estimated density. + kde : KDE + A scikit-learn KDE object fit with a batch of samples. + alpha : float + Transparency of the density contour. + dpi : int + Dots per inch, indicating the resolution of the plot. 
+ """ if self.n_dim != 2: return None - # Sample a grid of points in the state space and score them with the KDE - x = np.linspace(cell_min, cell_max, 201) - y = np.linspace(cell_min, cell_max, 201) + samples = torch2np(samples) + # Create mesh from samples_reward + x = np.unique(samples[:, 0]) + y = np.unique(samples[:, 1]) xx, yy = np.meshgrid(x, y) - X = np.stack([xx, yy], axis=-1) - Z = np.exp(kde.score_samples(X.reshape(-1, 2))).reshape(xx.shape) + Z = np.exp(kde.score_samples(samples)).reshape(xx.shape) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 2569bfeee..ce6d39154 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1316,28 +1316,53 @@ def test(self, **plot_kwargs): # Plots if hasattr(self.env, "plot_reward_samples"): - if hasattr(self.env, "get_all_terminating_states"): - samples_reward = self.env.get_all_terminating_states() - elif hasattr(self.env, "get_grid_terminating_states"): - samples_reward = self.env.get_grid_terminating_states( - self.logger.test.n_grid - ) - else: - raise NotImplementedError( - "In order to plot the reward density and the samples, the " - "environment must implement either get_all_terminating_states() " - "or get_grid_terminating_states()" - ) - samples_reward = self.env.states2proxy(samples_reward) - rewards = self.proxy.rewards(samples_reward) + # TODO: improve to not repeat code + if not hasattr(self, "sample_space_batch"): + if hasattr(self.env, "get_all_terminating_states"): + self.sample_space_batch = self.env.get_all_terminating_states() + elif hasattr(self.env, "get_grid_terminating_states"): + self.sample_space_batch = self.env.get_grid_terminating_states( + self.logger.test.n_grid + ) + else: + raise NotImplementedError( + "In order to plot the reward density and the samples, the " + "environment must implement either get_all_terminating_states() " + "or get_grid_terminating_states()" + ) + self.sample_space_batch = self.env.states2proxy(self.sample_space_batch) + if not hasattr(self, "rewards_sample_space"): + self.rewards_sample_space = self.proxy.rewards(self.sample_space_batch) fig_reward_samples = self.env.plot_reward_samples( - x_sampled, samples_reward, rewards, **plot_kwargs + x_sampled, + self.sample_space_batch, + self.rewards_sample_space, + **plot_kwargs, ) else: fig_reward_samples = None if hasattr(self.env, "plot_kde"): - fig_kde_pred = self.env.plot_kde(kde_pred, **plot_kwargs) - fig_kde_true = self.env.plot_kde(kde_true, **plot_kwargs) + # TODO: improve to not repeat code + if not hasattr(self, "sample_space_batch"): + if hasattr(self.env, "get_all_terminating_states"): + self.sample_space_batch = self.env.get_all_terminating_states() + elif hasattr(self.env, "get_grid_terminating_states"): + self.sample_space_batch = self.env.get_grid_terminating_states( + self.logger.test.n_grid + ) + else: + raise NotImplementedError( + "In order to plot the KDEs over the sample space, the " + "environment must implement either get_all_terminating_states() " + "or get_grid_terminating_states()" + ) + self.sample_space_batch = self.env.states2proxy(self.sample_space_batch) + fig_kde_pred = self.env.plot_kde( + self.sample_space_batch, kde_pred, **plot_kwargs + ) + fig_kde_true = self.env.plot_kde( + self.sample_space_batch, kde_true, **plot_kwargs + ) else: fig_kde_pred = None fig_kde_true = None From d62cefcf92221d99deef45de1ba44998e6275934 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sat, 4 May 2024 20:53:39 -0400 Subject: [PATCH 44/73] WIP: 
plot methods for grid (broken) --- gflownet/envs/grid.py | 88 ++++++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 35 deletions(-) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 375488fb8..a8db1c410 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -323,43 +323,61 @@ def get_uniform_terminating_states( states = rng.integers(low=0, high=self.length, size=(n_states, self.n_dim)) return states.tolist() - # TODO: review - def plot_samples_frequency(self, samples, ax=None, title=None, rescale=1): + def plot_reward_samples( + self, + samples, + ax=None, + title=None, + rescale=1, + dpi=150, + n_ticks_max=50, + reward_norm=True, + ): """ Plot 2D histogram of samples. """ - if self.n_dim > 2: + # Only available for 2D grids + if self.n_dim != 2: return None - if ax is None: - fig, ax = plt.subplots() - standalone = True - else: - standalone = False - # make a list of integers from 0 to n_dim - if rescale != 1: - step = int(self.length / rescale) - else: - step = 1 - ax.set_xticks(np.arange(start=0, stop=self.length, step=step)) - ax.set_yticks(np.arange(start=0, stop=self.length, step=step)) - # check if samples is on GPU - if torch.is_tensor(samples) and samples.is_cuda: - samples = samples.detach().cpu() - states = np.array(samples).astype(int) - grid = np.zeros((self.length, self.length)) - if title == None: - ax.set_title("Frequency of Coordinates Sampled") - else: - ax.set_title(title) - # TODO: optimize - for state in states: - grid[state[0], state[1]] += 1 - im = ax.imshow(grid) + # Init figure + fig, axes = plt.subplots(ncols=2, dpi=dpi) + step_ticks = np.ceil(self.length / n_ticks_max).astype(int) + # Get all states and their reward + if not hasattr(self, "_rewards_all_2d"): + states_all = self.get_all_terminating_states() + rewards_all = self.proxy2reward( + self.proxy(self.statebatch2proxy(states_all)) + ) + if reward_norm: + rewards_all = rewards_all / rewards_all.sum() + self._rewards_all_2d = torch.empty( + (self.length, self.length), device=self.device, dtype=self.float + ) + for row in range(self.length): + for col in range(self.length): + idx = states_all.index([row, col]) + self._rewards_all_2d[row, col] = rewards_all[idx] + self._rewards_all_2d = self._rewards_all_2d.detach().cpu().numpy() + # 2D histogram of samples + samples = np.array(samples) + samples_hist, xedges, yedges = np.histogram2d( + samples[:, 0], samples[:, 1], bins=(self.length, self.length), density=True + ) + # Transpose and reverse rows so that [0, 0] is at bottom left + samples_hist = samples_hist.T[::-1, :] + # Plot reward + self._plot_grid_2d(self._rewards_all_2d, axes[0], step_ticks) + # Plot samples histogram + self._plot_grid_2d(samples_hist, axes[1], step_ticks) + fig.tight_layout() + return fig + + @staticmethod + def _plot_grid_2d(img, ax, step_ticks): + ax_img = ax.imshow(img) divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) - plt.show() - if standalone == True: - plt.tight_layout() - plt.close() - return ax + cax = divider.append_axes("top", size="5%", pad=0.05) + ax.set_xticks(np.arange(start=0, stop=img.shape[0], step=step_ticks)) + ax.set_yticks(np.arange(start=0, stop=img.shape[1], step=step_ticks)[::-1]) + plt.colorbar(ax_img, cax=cax, orientation="horizontal") + cax.xaxis.set_ticks_position("top") From 4f1e7141b2fa99581a4371ed8eee25acdd72f445 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sat, 4 May 2024 21:20:59 -0400 Subject: [PATCH 45/73] WIP:: adapt 
plotting methods of htorus --- gflownet/envs/cube.py | 1 + gflownet/envs/htorus.py | 179 +++++++++++++++++++++++----------------- 2 files changed, 102 insertions(+), 78 deletions(-) diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index 46b2395d7..16745bee5 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -1534,6 +1534,7 @@ def plot_kde( x = np.unique(samples[:, 0]) y = np.unique(samples[:, 1]) xx, yy = np.meshgrid(x, y) + # Score samples with KDE Z = np.exp(kde.score_samples(samples)).reshape(xx.shape) # Init figure fig, ax = plt.subplots() diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index f15b1cdde..7377639fa 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -516,78 +516,83 @@ def get_uniform_terminating_states( states = np.concatenate((angles, np.ones((n_states, 1))), axis=1) return states.tolist() - # TODO: make generic for all environments, or rather elsewhere - # TODO: fix because it currently uses reward_batch, proxy2reward, etc.. For - # example, this could be done by passing the proxy as a parameter. - def sample_from_reward( - self, n_samples: int, epsilon=1e-4 - ) -> TensorType["n_samples", "state_dim"]: - """ - Rejection sampling with proposal the uniform distribution in [0, 2pi]]^n_dim. - - Returns a tensor in GFloNet (state) format. + def fit_kde( + self, + samples: TensorType["batch_size", "state_proxy_dim"], + kernel: str = "gaussian", + bandwidth: float = 0.1, + ): + r""" + Fits a Kernel Density Estimator on a batch of samples. + + The samples are previously augmented in order to visualise the periodic aspect + of the sample space. + + Parameters + ---------- + samples : tensor + A batch of samples in proxy format. + kernel : str + An identifier of the kernel to use for the density estimation. It must be a + valid kernel for the scikit-learn method + :py:meth:`sklearn.neighbors.KernelDensity`. + bandwidth : float + The bandwidth of the kernel. 
""" - samples_final = [] - max_reward = self.proxy2reward(torch.tensor([self.proxy.min])).to(self.device) - while len(samples_final) < n_samples: - angles_uniform = ( - torch.rand( - (n_samples, self.n_dim), dtype=self.float, device=self.device - ) - * 2 - * np.pi - ) - samples = torch.cat( - ( - angles_uniform, - torch.ones((angles_uniform.shape[0], 1)).to(angles_uniform), - ), - axis=1, - ) - rewards = tfloat( - self.reward_batch(samples), device=self.device, float_type=self.float - ) - mask = ( - torch.rand(n_samples, dtype=self.float, device=self.device) - * (max_reward + epsilon) - < rewards - ) - samples_accepted = samples[mask, :] - samples_final.extend(samples_accepted[-(n_samples - len(samples_final)) :]) - return torch.vstack(samples_final) - - def fit_kde(self, samples, kernel="gaussian", bandwidth=0.1): - aug_samples = [] + samples = torch2np(samples) + samples_aug = [] for add_0 in [0, -2 * np.pi, 2 * np.pi]: for add_1 in [0, -2 * np.pi, 2 * np.pi]: - aug_samples.append( + samples_aug.append( np.stack([samples[:, 0] + add_0, samples[:, 1] + add_1], axis=1) ) - aug_samples = np.concatenate(aug_samples) - kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(aug_samples) + samples_aug = np.concatenate(samples_aug) + kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(samples_aug) return kde def plot_reward_samples( self, - samples, - alpha=0.5, - low=-np.pi * 0.5, - high=2.5 * np.pi, - dpi=150, - limit_n_samples=500, + samples: TensorType["batch_size", "state_proxy_dim"], + samples_reward: TensorType["batch_size", "state_proxy_dim"], + rewards: TensorType["batch_size"], + alpha: float = 0.5, + dpi: int = 150, + max_samples: int = 500, **kwargs, ): - x = np.linspace(low, high, 201) - y = np.linspace(low, high, 201) + """ + Plots the reward contour alongside a batch of samples. + + The samples are previously augmented in order to visualise the periodic aspect + of the sample space. + + Parameters + ---------- + samples : tensor + A batch of samples from the GFlowNet policy in proxy format. These samples + will be plotted on top of the reward density. + samples_reward : tensor + A batch of samples containing a grid over the sample space, from which the + reward has been obtained. These samples are used to plot the contour of + reward density. + rewards : tensor + The reward of samples_reward. + alpha : float + Transparency of the reward contour. + dpi : int + Dots per inch, indicating the resolution of the plot. + max_samples : int + Maximum of number of samples to include in the plot. 
+ """ + if self.n_dim != 2: + return None + samples = torch2np(samples) + samples_reward = torch2np(samples_reward) + rewards = torch2np(rewards) + # Create mesh from samples_reward + x = np.unique(samples_reward[:, 0]) + y = np.unique(samples_reward[:, 1]) xx, yy = np.meshgrid(x, y) - X = np.stack([xx, yy], axis=-1) - samples_mesh = torch.tensor(X.reshape(-1, 2), dtype=self.float) - states_mesh = torch.cat( - [samples_mesh, torch.ones(samples_mesh.shape[0], 1)], 1 - ).to(self.device) - rewards = torch2np( - self.proxy2reward(self.proxy(self.states2proxy(states_mesh))) - ) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) @@ -599,25 +604,24 @@ def plot_reward_samples( ax.plot([0, 2 * np.pi], [0, 0], "-w", alpha=alpha) ax.plot([2 * np.pi, 2 * np.pi], [2 * np.pi, 0], "-w", alpha=alpha) ax.plot([2 * np.pi, 0], [2 * np.pi, 2 * np.pi], "-w", alpha=alpha) - # Plot samples - extra_samples = [] + # Augment samples + samples_aug = [] for add_0 in [0, -2 * np.pi, 2 * np.pi]: for add_1 in [0, -2 * np.pi, 2 * np.pi]: if not (add_0 == add_1 == 0): - extra_samples.append( + samples_aug.append( np.stack( [ - samples[:limit_n_samples, 0] + add_0, - samples[:limit_n_samples, 1] + add_1, + samples[:max_samples, 0] + add_0, + samples[:max_samples, 1] + add_1, ], axis=1, ) ) - extra_samples = np.concatenate(extra_samples) - ax.scatter( - samples[:limit_n_samples, 0], samples[:limit_n_samples, 1], alpha=alpha - ) - ax.scatter(extra_samples[:, 0], extra_samples[:, 1], alpha=alpha, color="white") + samples_aug = np.concatenate(samples_aug) + # Plot samples + ax.scatter(samples[:max_samples, 0], samples[:max_samples, 1], alpha=alpha) + ax.scatter(samples_aug[:, 0], samples_aug[:, 1], alpha=alpha, color="white") ax.grid() # Set tight layout plt.tight_layout() @@ -625,19 +629,38 @@ def plot_reward_samples( def plot_kde( self, + samples: TensorType["batch_size", "state_proxy_dim"], kde, - alpha=0.5, - low=-np.pi * 0.5, - high=2.5 * np.pi, + alpha: float = 0.5, dpi=150, - colorbar=True, + colorbar: bool = True, **kwargs, ): - x = np.linspace(0, 2 * np.pi, 101) - y = np.linspace(0, 2 * np.pi, 101) + """ + Plots the density previously estimated from a batch of samples via KDE over the + entire sample space. + + Parameters + ---------- + samples : tensor + A batch of samples containing a grid over the sample space. These samples + are used to plot the contour of the estimated density. + kde : KDE + A scikit-learn KDE object fit with a batch of samples. + alpha : float + Transparency of the density contour. + dpi : int + Dots per inch, indicating the resolution of the plot. 
+ """ + if self.n_dim != 2: + return None + samples = torch2np(samples) + # Create mesh from samples_reward + x = np.unique(samples[:, 0]) + y = np.unique(samples[:, 1]) xx, yy = np.meshgrid(x, y) - X = np.stack([xx, yy], axis=-1) - Z = np.exp(kde.score_samples(X.reshape(-1, 2))).reshape(xx.shape) + # Score samples with KDE + Z = np.exp(kde.score_samples(samples)).reshape(xx.shape) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) From b84011de00ca6bf00330d9e6e813032c606ba3c9 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sat, 4 May 2024 23:48:44 -0400 Subject: [PATCH 46/73] Create mesh coordinates by reshaping samples --- gflownet/envs/cube.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index 16745bee5..bb57407c6 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -1480,15 +1480,17 @@ def plot_reward_samples( samples = torch2np(samples) samples_reward = torch2np(samples_reward) rewards = torch2np(rewards) - # Create mesh from samples_reward - x = np.unique(samples_reward[:, 0]) - y = np.unique(samples_reward[:, 1]) - xx, yy = np.meshgrid(x, y) + # Create mesh grid from samples_reward + n_per_dim = int(np.sqrt(samples_reward.shape[0])) + assert n_per_dim**2 == samples_reward.shape[0] + x_coords = samples_reward[:, 0].reshape((n_per_dim, n_per_dim)) + y_coords = samples_reward[:, 1].reshape((n_per_dim, n_per_dim)) + rewards = rewards.reshape((n_per_dim, n_per_dim)) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) # Plot reward contour - h = ax.contourf(xx, yy, rewards.reshape(xx.shape), alpha=alpha) + h = ax.contourf(x_coords, y_coords, rewards, alpha=alpha) ax.axis("scaled") fig.colorbar(h, ax=ax) # Plot samples @@ -1530,17 +1532,18 @@ def plot_kde( if self.n_dim != 2: return None samples = torch2np(samples) - # Create mesh from samples_reward - x = np.unique(samples[:, 0]) - y = np.unique(samples[:, 1]) - xx, yy = np.meshgrid(x, y) + # Create mesh grid from samples + n_per_dim = int(np.sqrt(samples.shape[0])) + assert n_per_dim**2 == samples.shape[0] + x_coords = samples[:, 0].reshape((n_per_dim, n_per_dim)) + y_coords = samples[:, 1].reshape((n_per_dim, n_per_dim)) # Score samples with KDE - Z = np.exp(kde.score_samples(samples)).reshape(xx.shape) + Z = np.exp(kde.score_samples(samples)).reshape((n_per_dim, n_per_dim)) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) # Plot KDE - h = ax.contourf(xx, yy, Z, alpha=alpha) + h = ax.contourf(x_coords, y_coords, Z, alpha=alpha) ax.axis("scaled") if colorbar: fig.colorbar(h, ax=ax) From 602c1333f48020ada848cc5da8ae76b567ab5b78 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 12:30:05 -0400 Subject: [PATCH 47/73] Small fix --- gflownet/gflownet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index ce6d39154..18c561c9a 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1267,7 +1267,9 @@ def test(self, **plot_kwargs): kde_true = dict_tt["kde_true"] else: # Sample from reward via rejection sampling - x_from_reward = self.sample_from_reward(n_samples=self.logger.test.n) + x_from_reward = self.env.states2proxy( + self.sample_from_reward(n_samples=self.logger.test.n) + ) # Fit KDE with samples from reward kde_true = self.env.fit_kde( x_from_reward, @@ -1522,7 +1524,8 @@ def sample_from_reward( Returns ------- samples_final : list - The list of samples drawn from the reward distribution. 
+ The list of samples drawn from the reward distribution in environment + format. """ samples_final = [] max_reward = self.proxy.proxy2reward(self.proxy.min) From f87df71f5866784fc87f19a5017d714818142420 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 12:30:22 -0400 Subject: [PATCH 48/73] Adapt plot methods of htorus --- gflownet/envs/htorus.py | 111 ++++++++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 43 deletions(-) diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index 7377639fa..d7e813ff0 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -525,8 +525,8 @@ def fit_kde( r""" Fits a Kernel Density Estimator on a batch of samples. - The samples are previously augmented in order to visualise the periodic aspect - of the sample space. + The samples are previously augmented in order to account for the periodic + aspect of the sample space. Parameters ---------- @@ -540,13 +540,7 @@ def fit_kde( The bandwidth of the kernel. """ samples = torch2np(samples) - samples_aug = [] - for add_0 in [0, -2 * np.pi, 2 * np.pi]: - for add_1 in [0, -2 * np.pi, 2 * np.pi]: - samples_aug.append( - np.stack([samples[:, 0] + add_0, samples[:, 1] + add_1], axis=1) - ) - samples_aug = np.concatenate(samples_aug) + samples_aug = self.augment_samples(samples) kde = KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(samples_aug) return kde @@ -555,6 +549,8 @@ def plot_reward_samples( samples: TensorType["batch_size", "state_proxy_dim"], samples_reward: TensorType["batch_size", "state_proxy_dim"], rewards: TensorType["batch_size"], + min_domain: float = -np.pi, + max_domain: float = 3 * np.pi, alpha: float = 0.5, dpi: int = 150, max_samples: int = 500, @@ -564,7 +560,7 @@ def plot_reward_samples( Plots the reward contour alongside a batch of samples. The samples are previously augmented in order to visualise the periodic aspect - of the sample space. + of the sample space. It is assumed that the samples and the rewards are sorted. Parameters ---------- @@ -577,6 +573,10 @@ def plot_reward_samples( reward density. rewards : tensor The reward of samples_reward. + min_domain : float + Minimum value of the domain to keep in the plot. + max_domain : float + Maximum value of the domain to keep in the plot. alpha : float Transparency of the reward contour. 
dpi : int @@ -589,39 +589,43 @@ def plot_reward_samples( samples = torch2np(samples) samples_reward = torch2np(samples_reward) rewards = torch2np(rewards) - # Create mesh from samples_reward - x = np.unique(samples_reward[:, 0]) - y = np.unique(samples_reward[:, 1]) - xx, yy = np.meshgrid(x, y) + n_per_dim = int(np.sqrt(samples_reward.shape[0])) + assert n_per_dim**2 == samples_reward.shape[0] + # Augment rewards to apply periodic boundary conditions + rewards = rewards.reshape((n_per_dim, n_per_dim)) + rewards = np.tile(rewards, (3, 3)) + # Create mesh grid from samples_reward + x = np.linspace(-2 * np.pi, 4 * np.pi, 3 * n_per_dim) + y = np.linspace(-2 * np.pi, 4 * np.pi, 3 * n_per_dim) + x_coords, y_coords = np.meshgrid(x, y) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) # Plot reward contour - h = ax.contourf(xx, yy, rewards.reshape(xx.shape), alpha=alpha) + h = ax.contourf(x_coords, y_coords, rewards, alpha=alpha) ax.axis("scaled") fig.colorbar(h, ax=ax) ax.plot([0, 0], [0, 2 * np.pi], "-w", alpha=alpha) ax.plot([0, 2 * np.pi], [0, 0], "-w", alpha=alpha) ax.plot([2 * np.pi, 2 * np.pi], [2 * np.pi, 0], "-w", alpha=alpha) ax.plot([2 * np.pi, 0], [2 * np.pi, 2 * np.pi], "-w", alpha=alpha) + # Randomize and subsample samples + random_indices = np.random.permutation(samples.shape[0])[:max_samples] + samples = samples[random_indices, :] # Augment samples - samples_aug = [] - for add_0 in [0, -2 * np.pi, 2 * np.pi]: - for add_1 in [0, -2 * np.pi, 2 * np.pi]: - if not (add_0 == add_1 == 0): - samples_aug.append( - np.stack( - [ - samples[:max_samples, 0] + add_0, - samples[:max_samples, 1] + add_1, - ], - axis=1, - ) - ) - samples_aug = np.concatenate(samples_aug) - # Plot samples - ax.scatter(samples[:max_samples, 0], samples[:max_samples, 1], alpha=alpha) - ax.scatter(samples_aug[:, 0], samples_aug[:, 1], alpha=alpha, color="white") + samples_aug = self.augment_samples(samples, exclude_original=True) + ax.scatter( + samples_aug[:, 0], samples_aug[:, 1], alpha=1.5 * alpha, color="white" + ) + ax.scatter(samples[:, 0], samples[:, 1], alpha=alpha) + # Set axes limits + ax.set_xlim([min_domain, max_domain]) + ax.set_ylim([min_domain, max_domain]) + # Set ticks and labels + ticks = [0.0, np.pi / 2, np.pi, (3 * np.pi) / 2, 2 * np.pi] + labels = ["0.0", r"$\frac{\pi}{2}$", r"$\pi$", r"$\frac{3\pi}{3}$", f"$2\pi$"] + ax.set_xticks(ticks, labels) + ax.set_yticks(ticks, labels) ax.grid() # Set tight layout plt.tight_layout() @@ -655,12 +659,13 @@ def plot_kde( if self.n_dim != 2: return None samples = torch2np(samples) - # Create mesh from samples_reward - x = np.unique(samples[:, 0]) - y = np.unique(samples[:, 1]) - xx, yy = np.meshgrid(x, y) + # Create mesh grid from samples + n_per_dim = int(np.sqrt(samples.shape[0])) + assert n_per_dim**2 == samples.shape[0] + x_coords = samples[:, 0].reshape((n_per_dim, n_per_dim)) + y_coords = samples[:, 1].reshape((n_per_dim, n_per_dim)) # Score samples with KDE - Z = np.exp(kde.score_samples(samples)).reshape(xx.shape) + Z = np.exp(kde.score_samples(samples)).reshape((n_per_dim, n_per_dim)) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) @@ -669,14 +674,34 @@ def plot_kde( ax.axis("scaled") if colorbar: fig.colorbar(h, ax=ax) - ax.set_xticks([]) - ax.set_yticks([]) - ax.text(0, -0.3, r"$0$", fontsize=15) - ax.text(-0.28, 0, r"$0$", fontsize=15) - ax.text(2 * np.pi - 0.4, -0.3, r"$2\pi$", fontsize=15) - ax.text(-0.45, 2 * np.pi - 0.3, r"$2\pi$", fontsize=15) + # Set ticks and labels + ticks = [0.0, np.pi / 2, np.pi, (3 * np.pi) / 2, 2 * np.pi] 
+ labels = ["0.0", r"$\frac{\pi}{2}$", r"$\pi$", r"$\frac{3\pi}{3}$", f"$2\pi$"] + ax.set_xticks(ticks, labels) + ax.set_yticks(ticks, labels) for spine in ax.spines.values(): spine.set_visible(False) # Set tight layout plt.tight_layout() return fig + + @staticmethod + def augment_samples(samples: np.array, exclude_original: bool = False) -> np.array: + """ + Augments a batch of samples by applying the periodic boundary conditions from + [0, 2pi) to [-2pi, 4pi) for all dimensions. + """ + samples_aug = [] + for offsets in itertools.product( + [-2 * np.pi, 0.0, 2 * np.pi], repeat=samples.shape[-1] + ): + if exclude_original and all([offset == 0.0 for offset in offsets]): + continue + samples_aug.append( + np.stack( + [samples[:, dim] + offset for dim, offset in enumerate(offsets)], + axis=-1, + ) + ) + samples_aug = np.concatenate(samples_aug, axis=0) + return samples_aug From b88b6b1e7cd8d3075c1c0359d9b7ffedc1272090 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 13:14:46 -0400 Subject: [PATCH 49/73] Change period of testing in icml ctorus config from 25 to 500 --- config/experiments/icml23/ctorus.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/experiments/icml23/ctorus.yaml b/config/experiments/icml23/ctorus.yaml index eb33c5d41..76e0c0b5a 100644 --- a/config/experiments/icml23/ctorus.yaml +++ b/config/experiments/icml23/ctorus.yaml @@ -56,7 +56,7 @@ logger: - continuous - ctorus test: - period: 25 + period: 500 n: 1000 checkpoints: period: 500 From 667b1508fe451b1de296c9fa4b234403f587c459 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 13:24:49 -0400 Subject: [PATCH 50/73] Adapt get_grid_terminating_states so that the ordering of the dimensions matches the plotting (reshaping) ordering --- gflownet/envs/htorus.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index d7e813ff0..7b1007da7 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -502,10 +502,27 @@ def copy(self): return deepcopy(self) def get_grid_terminating_states(self, n_states: int) -> List[List]: + """ + Samples n terminating states by sub-sampling the state space as a grid, where n + / n_dim points are obtained for each dimension. + + Parameters + ---------- + n_states : int + The number of terminating states to sample. + + Returns + ------- + states : list + A list of randomly sampled terminating states. 
+ """ n_per_dim = int(np.ceil(n_states ** (1 / self.n_dim))) - linspaces = [np.linspace(0, 2 * np.pi, n_per_dim) for _ in range(self.n_dim)] - angles = list(itertools.product(*linspaces)) - states = [list(el) + [self.length_traj] for el in angles] + linspace = np.linspace(0, 2 * np.pi, n_per_dim) + angles = np.meshgrid(*[linspace] * self.n_dim) + angles = np.stack(angles).reshape((self.n_dim, -1)).T + states = np.concatenate( + (angles, self.length_traj * np.ones((angles.shape[0], 1))), axis=1 + ).tolist() return states def get_uniform_terminating_states( @@ -664,13 +681,13 @@ def plot_kde( assert n_per_dim**2 == samples.shape[0] x_coords = samples[:, 0].reshape((n_per_dim, n_per_dim)) y_coords = samples[:, 1].reshape((n_per_dim, n_per_dim)) - # Score samples with KDE + # Score samples with KDE and reshape Z = np.exp(kde.score_samples(samples)).reshape((n_per_dim, n_per_dim)) # Init figure fig, ax = plt.subplots() fig.set_dpi(dpi) # Plot KDE - h = ax.contourf(xx, yy, Z, alpha=alpha) + h = ax.contourf(x_coords, y_coords, Z, alpha=alpha) ax.axis("scaled") if colorbar: fig.colorbar(h, ax=ax) From f32d999175af1d62c28a6c38b58f2f006e7425ec Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 16:14:43 -0400 Subject: [PATCH 51/73] Adapt plotting methods of grid --- gflownet/envs/grid.py | 89 +++++++++++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 32 deletions(-) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index a8db1c410..b7426552c 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -9,11 +9,12 @@ import numpy as np import numpy.typing as npt import torch +from matplotlib.axes import Axes from mpl_toolkits.axes_grid1 import make_axes_locatable from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv -from gflownet.utils.common import tfloat, tlong +from gflownet.utils.common import tfloat, tlong, torch2np class Grid(GFlowNetEnv): @@ -311,9 +312,8 @@ def get_max_traj_length(self): return self.n_dim * self.length def get_all_terminating_states(self) -> List[List]: - all_x = np.int32( - list(itertools.product(*[list(range(self.length))] * self.n_dim)) - ) + grid = np.meshgrid(*[range(self.length)] * self.n_dim) + all_x = np.stack(grid).reshape((self.n_dim, -1)).T return all_x.tolist() def get_uniform_terminating_states( @@ -325,59 +325,84 @@ def get_uniform_terminating_states( def plot_reward_samples( self, - samples, - ax=None, - title=None, - rescale=1, - dpi=150, - n_ticks_max=50, - reward_norm=True, + samples: TensorType["batch_size", "state_proxy_dim"], + samples_reward: TensorType["batch_size", "state_proxy_dim"], + rewards: TensorType["batch_size"], + dpi: int = 150, + n_ticks_max: int = 50, + reward_norm: bool = True, + **kwargs, ): """ - Plot 2D histogram of samples. + Plots the reward density as a 2D histogram on the grid, alongside a histogram + representing the samples density. + + Parameters + ---------- + samples : tensor + A batch of samples from the GFlowNet policy in proxy format. These samples + will be plotted on top of the reward density. + samples_reward : tensor + A batch of samples containing a grid over the sample space, from which the + reward has been obtained. These samples are used to plot the contour of + reward density. + rewards : tensor + The reward of samples_reward. + dpi : int + Dots per inch, indicating the resolution of the plot. + n_ticks_max : int + Maximum of number of ticks to include in the axes. + reward_norm : bool + Whether to normalize the histogram. 
True by default. """ # Only available for 2D grids if self.n_dim != 2: return None + samples = torch2np(samples) + samples_reward = torch2np(samples_reward) + rewards = torch2np(rewards) # Init figure fig, axes = plt.subplots(ncols=2, dpi=dpi) step_ticks = np.ceil(self.length / n_ticks_max).astype(int) - # Get all states and their reward - if not hasattr(self, "_rewards_all_2d"): - states_all = self.get_all_terminating_states() - rewards_all = self.proxy2reward( - self.proxy(self.statebatch2proxy(states_all)) - ) - if reward_norm: - rewards_all = rewards_all / rewards_all.sum() - self._rewards_all_2d = torch.empty( - (self.length, self.length), device=self.device, dtype=self.float - ) - for row in range(self.length): - for col in range(self.length): - idx = states_all.index([row, col]) - self._rewards_all_2d[row, col] = rewards_all[idx] - self._rewards_all_2d = self._rewards_all_2d.detach().cpu().numpy() # 2D histogram of samples - samples = np.array(samples) samples_hist, xedges, yedges = np.histogram2d( samples[:, 0], samples[:, 1], bins=(self.length, self.length), density=True ) # Transpose and reverse rows so that [0, 0] is at bottom left samples_hist = samples_hist.T[::-1, :] + # Normalize and reshape reward into a grid with [0, 0] at the bottom left + if reward_norm: + rewards = rewards / rewards.sum() + rewards_2d = rewards.reshape(self.length, self.length).T[::-1, :] # Plot reward - self._plot_grid_2d(self._rewards_all_2d, axes[0], step_ticks) + self._plot_grid_2d(rewards_2d, axes[0], step_ticks, title="True reward") # Plot samples histogram - self._plot_grid_2d(samples_hist, axes[1], step_ticks) + self._plot_grid_2d(samples_hist, axes[1], step_ticks, title="Samples density") fig.tight_layout() return fig @staticmethod - def _plot_grid_2d(img, ax, step_ticks): + def _plot_grid_2d(img: np.array, ax: Axes, step_ticks: int, title: str): + """ + Plots a 2D histogram of a grid environment as an image. + + Parameters + ---------- + img : np.array + An array containing a 2D histogram over a grid. + ax : Axes + A matplotlib Axes object on which the image will be plotted. + step_ticks : int + The step value to add ticks to the axes. For example, if it is 2, the ticks + will be at 0, 2, 4, ... + title : str + Title for the axes. 
+ """ ax_img = ax.imshow(img) divider = make_axes_locatable(ax) cax = divider.append_axes("top", size="5%", pad=0.05) ax.set_xticks(np.arange(start=0, stop=img.shape[0], step=step_ticks)) ax.set_yticks(np.arange(start=0, stop=img.shape[1], step=step_ticks)[::-1]) + cax.set_title(title) plt.colorbar(ax_img, cax=cax, orientation="horizontal") cax.xaxis.set_ticks_position("top") From 22e5de1d6994179727ac146e73e46c82be5d467e Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 16:19:03 -0400 Subject: [PATCH 52/73] Do not start progress bar if there is only one batch --- gflownet/gflownet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 18c561c9a..9383b157d 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -920,7 +920,8 @@ def estimate_logprobs_data( "Sampling backward actions from test data to estimate logprobs...", flush=True, ) - pbar = tqdm(total=n_states) + if n_states > batch_size: + pbar = tqdm(total=n_states) while init_batch < n_states: batch = Batch( env=self.env, @@ -970,7 +971,8 @@ def estimate_logprobs_data( # Increment batch indices init_batch += batch_size end_batch = min(end_batch + batch_size, n_states) - pbar.update(end_batch - init_batch) + if n_states > batch_size: + pbar.update(end_batch - init_batch) # Compute log of the average probabilities of the ratio PF / PB logprobs_estimates = torch.logsumexp( From 4ffa0bceaef285b200c4f1c159a0ff6c78e0353b Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 17:15:31 -0400 Subject: [PATCH 53/73] Improve handling of logzero in proxy and fix tests --- gflownet/proxy/base.py | 18 ++-- tests/gflownet/proxy/test_base.py | 174 ++++++++++++++++++++++++++---- 2 files changed, 162 insertions(+), 30 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 81648eef2..d93de928f 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -12,6 +12,8 @@ from gflownet.utils.common import set_device, set_float_precision +LOGZERO = -1e3 + class Proxy(ABC): """ @@ -36,7 +38,13 @@ def __init__( self._reward_function, self._logreward_function = self._get_reward_functions( reward_function, logreward_function, **reward_function_kwargs ) + # Set minimum reward and log reward. If the minimum reward is exactly 0, + # the minimum log reward is set to -1000 in order to avoid -inf. self.reward_min = reward_min + if self.reward_min == 0: + self.logreward_min = LOGZERO + else: + self.logreward_min = np.log(self.reward_min) self.do_clip_rewards = do_clip_rewards # Device self.device = set_device(device) @@ -129,9 +137,8 @@ def proxy2logreward(self, proxy_values: TensorType) -> TensorType: def get_min_reward(self, log: bool = False) -> float: """ - Returns the minimum value of the (log) reward, retrieved from self.reward_min. - - If self.reward_min is exactly 0, then self.logreward_min is set to -inf. + Returns the minimum value of the (log) reward, retrieved from self.reward_min + and self.logreward_min. Parameters ---------- @@ -145,11 +152,6 @@ def get_min_reward(self, log: bool = False) -> float: The mimnimum (log) reward. 
""" if log: - if not hasattr(self, "logreward_min"): - if self.reward_min == 0.0: - self.logreward_min = -1e3 - else: - self.logreward_min = np.log(self.reward_min) return self.logreward_min else: return self.reward_min diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py index 4f033c28a..4633adacc 100644 --- a/tests/gflownet/proxy/test_base.py +++ b/tests/gflownet/proxy/test_base.py @@ -2,7 +2,7 @@ import pytest import torch -from gflownet.proxy.base import Proxy +from gflownet.proxy.base import LOGZERO, Proxy from gflownet.proxy.uniform import Uniform from gflownet.utils.common import tfloat @@ -118,7 +118,7 @@ def check_proxy2reward(rewards_computed, rewards_expected, atol=1e-3): @pytest.mark.parametrize( - "beta, proxy_values, rewards_exp, logrewards_exp", + "beta, proxy_values, rewards_exp, logrewards_exp, logrewards_exp_clipped", [ ( 1, @@ -137,11 +137,29 @@ def check_proxy2reward(rewards_computed, rewards_expected, atol=1e-3): 2.3025, 4.6052, ], + [ + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + -2.3025, + -0.6931, + 0.0, + 2.3025, + 4.6052, + ], ), ], ) def test_reward_function_identity__behaves_as_expected( - proxy_identity, beta, proxy_values, rewards_exp, logrewards_exp + proxy_identity, + beta, + proxy_values, + rewards_exp, + logrewards_exp, + logrewards_exp_clipped, ): proxy = proxy_identity proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) @@ -151,14 +169,19 @@ def test_reward_function_identity__behaves_as_expected( assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) # Log Rewards logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) + logrewards_exp_clipped = tfloat( + logrewards_exp_clipped, device=proxy.device, float_type=proxy.float + ) assert all( check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) ) - assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) + assert all( + check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp_clipped) + ) @pytest.mark.parametrize( - "beta, proxy_values, rewards_exp, logrewards_exp", + "beta, proxy_values, rewards_exp, logrewards_exp, logrewards_exp_clipped", [ ( 1, @@ -177,6 +200,19 @@ def test_reward_function_identity__behaves_as_expected( 2.3025, 4.6052, ], + [ + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + -2.3025, + -0.6931, + 0.0, + 2.3025, + 4.6052, + ], ), ( 2, @@ -195,11 +231,24 @@ def test_reward_function_identity__behaves_as_expected( 4.6052, 9.2103, ], + [ + 9.2103, + 4.6052, + 0.0, + -1.3863, + -4.6052, + LOGZERO, + -4.6052, + -1.3863, + 0.0, + 4.6052, + 9.2103, + ], ), ], ) def test_reward_function_power__behaves_as_expected( - proxy_power, beta, proxy_values, rewards_exp, logrewards_exp + proxy_power, beta, proxy_values, rewards_exp, logrewards_exp, logrewards_exp_clipped ): proxy = proxy_power proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) @@ -209,14 +258,19 @@ def test_reward_function_power__behaves_as_expected( assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) # Log Rewards logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) + logrewards_exp_clipped = tfloat( + logrewards_exp_clipped, device=proxy.device, float_type=proxy.float + ) assert all( check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) ) - assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), 
logrewards_exp)) + assert all( + check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp_clipped) + ) @pytest.mark.parametrize( - "beta, proxy_values, rewards_exp, logrewards_exp", + "beta, proxy_values, rewards_exp, logrewards_exp, logrewards_exp_clipped", [ ( 1.0, @@ -233,6 +287,7 @@ def test_reward_function_power__behaves_as_expected( 22026.4648, ], [-10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10], + [-10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10], ), ( -1.0, @@ -249,11 +304,17 @@ def test_reward_function_power__behaves_as_expected( 4.54e-05, ], [10, 1, 0.5, 0.1, 0.0, -0.1, -0.5, -1, -10], + [10, 1, 0.5, 0.1, 0.0, -0.1, -0.5, -1, -10], ), ], ) def test_reward_function_exponential__behaves_as_expected( - proxy_exponential, beta, proxy_values, rewards_exp, logrewards_exp + proxy_exponential, + beta, + proxy_values, + rewards_exp, + logrewards_exp, + logrewards_exp_clipped, ): proxy = proxy_exponential proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) @@ -263,14 +324,19 @@ def test_reward_function_exponential__behaves_as_expected( assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) # Log Rewards logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) + logrewards_exp_clipped = tfloat( + logrewards_exp_clipped, device=proxy.device, float_type=proxy.float + ) assert all( check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) ) - assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) + assert all( + check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp_clipped) + ) @pytest.mark.parametrize( - "beta, proxy_values, rewards_exp, logrewards_exp", + "beta, proxy_values, rewards_exp, logrewards_exp, logrewards_exp_clipped", [ ( 5, @@ -289,6 +355,19 @@ def test_reward_function_exponential__behaves_as_expected( 2.7081, 4.6540, ], + [ + LOGZERO, + LOGZERO, + 1.3863, + 1.5041, + 1.5892, + 1.6094, + 1.6292, + 1.7047, + 1.7918, + 2.7081, + 4.6540, + ], ), ( -5, @@ -307,11 +386,24 @@ def test_reward_function_exponential__behaves_as_expected( 1.6094, 4.5539, ], + [ + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + 1.6094, + 4.5539, + ], ), ], ) def test_reward_function_shift__behaves_as_expected( - proxy_shift, beta, proxy_values, rewards_exp, logrewards_exp + proxy_shift, beta, proxy_values, rewards_exp, logrewards_exp, logrewards_exp_clipped ): proxy = proxy_shift proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) @@ -321,14 +413,19 @@ def test_reward_function_shift__behaves_as_expected( assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) # Log Rewards logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) + logrewards_exp_clipped = tfloat( + logrewards_exp_clipped, device=proxy.device, float_type=proxy.float + ) assert all( check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) ) - assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) + assert all( + check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp_clipped) + ) @pytest.mark.parametrize( - "beta, proxy_values, rewards_exp, logrewards_exp", + "beta, proxy_values, rewards_exp, logrewards_exp, logrewards_exp_clipped", [ ( 2, @@ -347,6 +444,19 @@ def test_reward_function_shift__behaves_as_expected( 2.9957, 5.2983, ], + [ + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + -1.6094, 
+ 0.0, + 0.6931, + 2.9957, + 5.2983, + ], ), ( -2, @@ -365,11 +475,29 @@ def test_reward_function_shift__behaves_as_expected( np.nan, np.nan, ], + [ + 5.2983, + 2.9957, + 0.6931, + 0.0, + -1.6094, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + ], ), ], ) def test_reward_function_product__behaves_as_expected( - proxy_product, beta, proxy_values, rewards_exp, logrewards_exp + proxy_product, + beta, + proxy_values, + rewards_exp, + logrewards_exp, + logrewards_exp_clipped, ): proxy = proxy_product proxy_values = tfloat(proxy_values, device=proxy.device, float_type=proxy.float) @@ -379,10 +507,15 @@ def test_reward_function_product__behaves_as_expected( assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) # Log Rewards logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) + logrewards_exp_clipped = tfloat( + logrewards_exp_clipped, device=proxy.device, float_type=proxy.float + ) assert all( check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) ) - assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) + assert all( + check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp_clipped) + ) @pytest.mark.parametrize( @@ -394,9 +527,9 @@ def test_reward_function_product__behaves_as_expected( [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], [-99, -9, 0, 0.5, 0.9, 1.0, 1.1, 1.5, 2, 11, 101], [ - np.nan, - np.nan, - -np.inf, + LOGZERO, + LOGZERO, + LOGZERO, -0.6931, -0.1054, 0.0, @@ -459,7 +592,4 @@ def test_reward_function_callable__behaves_as_expected( assert all(check_proxy2reward(proxy.proxy2reward(proxy_values), rewards_exp)) # Log Rewards logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) - assert all( - check_proxy2reward(proxy._logreward_function(proxy_values), logrewards_exp) - ) assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) From 1862d1e06b7f58d031cff819b1e4828049378bb6 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 17:19:57 -0400 Subject: [PATCH 54/73] Fix test batch --- tests/gflownet/utils/test_batch.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 6d59d68e8..89cf87426 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -728,8 +728,6 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques env_aux = env_ref.copy().reset(idx + batch_size_forward) env_aux = env_aux.set_state(state=x, done=True) env_aux.n_actions = env_aux.get_max_traj_length() - env_aux.proxy = proxy - env_aux.setup_proxy() envs.append(env_aux) states_term_sorted.extend([copy(x) for x in x_batch]) @@ -957,8 +955,6 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): env_aux = env_ref.copy().reset(idx) env_aux = env_aux.set_state(state=x, done=True) env_aux.n_actions = env_aux.get_max_traj_length() - env_aux.proxy = proxy - env_aux.setup_proxy() envs.append(env_aux) states_term_sorted.extend([copy(x) for x in x_batch]) From d96827e9f301e40a3f193bb1ccb63912f273613a Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 18:15:39 -0400 Subject: [PATCH 55/73] Add plot_samples_topk to Tetris and to GFN test() --- gflownet/envs/tetris.py | 86 +++++++++++++++++++++++++++++++++++++++++ gflownet/gflownet.py | 58 ++++++++++++++++----------- 2 files changed, 121 insertions(+), 23 deletions(-) diff --git 
a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index 07dd96918..1314fbb9d 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -7,6 +7,7 @@ import warnings from typing import List, Optional, Tuple, Union +import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt import torch @@ -25,6 +26,17 @@ "Z": [7, [[7, 7, 0], [0, 7, 7]]], } +PIECES_COLORS = { + 0: [255, 255, 255], + 1: [19, 232, 232], + 2: [30, 30, 201], + 3: [240, 110, 2], + 4: [236, 236, 14], + 5: [0, 128, 0], + 6: [125, 5, 126], + 7: [236, 14, 14], +} + class Tetris(GFlowNetEnv): """ @@ -513,3 +525,77 @@ def _get_max_piece_idx( return max_relevant_piece_idx + incr else: return min_idx + + def plot_samples_topk( + self, + samples: List, + rewards: TensorType["batch_size"], + k_top: int = 10, + n_rows: int = 2, + dpi: int = 150, + ): + """ + Plot tetris boards of top K samples. + + Parameters + ---------- + samples : list + List of terminating states sampled from the policy. + rewards : list + List of terminating states. + k_top : int + The number of samples that will be included in the plot. The k_top samples + with the highest reward are selected. + n_rows : int + Number of rows in the plot. The number of columns will be calculated + according the n_rows and k_top. + dpi : int + DPI (dots per inch) of the figure, to determine the resolution. + """ + # Init figure + n_cols = np.ceil(k_top / n_rows).astype(int) + fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, dpi=dpi) + rewards_topk, indices_topk = torch.sort(rewards, descending=True)[:k_top] + indices_topk = indices_topk.tolist() + for idx, ax in zip(indices_topk, axes.flatten()): + self._plot_board(samples[idx], ax) + fig.tight_layout() + return fig + + @staticmethod + def _plot_board(board, ax, cellsize=20, linewidth=2): + """ + Plots a single Tetris board (a state). + + Args + ---- + board : tensor + State to plot. + + ax : matplotlib Axis + The axis in which to plot the board. + + cellsize : int + The size (length) of each board cell, in pixels. + + linewidth : int + The width of the separation between cells, in pixels. 
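+
+        A minimal usage sketch, assuming ``board`` is a (CPU) state tensor of
+        this environment:
+
+        >>> import matplotlib.pyplot as plt
+        >>> fig, ax = plt.subplots()
+        >>> Tetris._plot_board(board, ax, cellsize=20, linewidth=2)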
+ """ + board = board.clone().numpy() + height = board.shape[0] * cellsize + width = board.shape[1] * cellsize + board_img = 128 * np.ones( + (height + linewidth, width + linewidth, 3), dtype=np.uint8 + ) + for row in range(board.shape[0]): + for col in range(board.shape[1]): + row_init = row * cellsize + linewidth + row_end = row_init + cellsize - linewidth + col_init = col * cellsize + linewidth + col_end = col_init + cellsize - linewidth + color_key = int(board[row, col] / 100) + board_img[row_init:row_end, col_init:col_end, :] = PIECES_COLORS[ + color_key + ] + ax.imshow(board_img) + ax.set_axis_off() diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 9383b157d..9c3bdce13 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1003,6 +1003,7 @@ def train(self): "True reward and GFlowNet samples", "GFlowNet KDE Policy", "Reward KDE", + "Samples TopK", ] if self.logger.do_test(it): ( @@ -1253,6 +1254,8 @@ def test(self, **plot_kwargs): density_pred = np.array([hist[tuple(x)] / z_pred for x in x_tt]) log_density_true = np.log(density_true + 1e-8) log_density_pred = np.log(density_pred + 1e-8) + elif self.buffer.test_type == "random": + env_metrics = self.env.test(x_sampled) elif self.continuous and hasattr(self.env, "fit_kde"): batch, _ = self.sample_batch(n_forward=self.logger.test.n, train=False) assert batch.is_valid() @@ -1294,29 +1297,23 @@ def test(self, **plot_kwargs): density_true = np.exp(log_density_true) density_pred = np.exp(log_density_pred) else: - # TODO: refactor - env_metrics = self.env.test(x_sampled) - return ( - self.l1, - self.kl, - self.jsd, - corr_prob_traj_rewards, - var_logrewards_logp, - nll_tt, - mean_logprobs_std, - mean_probs_std, - logprobs_std_nll_ratio, - (None,), - env_metrics, + raise NotImplementedError + + if self.buffer.test_type == "all" or self.continuous: + # L1 error + l1 = np.abs(density_pred - density_true).mean() + # KL divergence + kl = (density_true * (log_density_true - log_density_pred)).mean() + # Jensen-Shannon divergence + log_mean_dens = np.logaddexp(log_density_true, log_density_pred) + np.log( + 0.5 ) - # L1 error - l1 = np.abs(density_pred - density_true).mean() - # KL divergence - kl = (density_true * (log_density_true - log_density_pred)).mean() - # Jensen-Shannon divergence - log_mean_dens = np.logaddexp(log_density_true, log_density_pred) + np.log(0.5) - jsd = 0.5 * np.sum(density_true * (log_density_true - log_mean_dens)) - jsd += 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens)) + jsd = 0.5 * np.sum(density_true * (log_density_true - log_mean_dens)) + jsd += 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens)) + else: + l1 = self.l1 + kl = self.kl + jsd = self.jsd # Plots if hasattr(self.env, "plot_reward_samples"): @@ -1370,6 +1367,21 @@ def test(self, **plot_kwargs): else: fig_kde_pred = None fig_kde_true = None + if hasattr(self.env, "plot_samples_topk"): + # TODO: samples are in environment format because this is what is needed by + # Tetris. We may want to adapt it. 
+ # TODO: this is a pretty bad implementation, but it will be fixed with the + # Evaluator + batch, _ = self.sample_batch(n_forward=self.logger.test.n, train=False) + x_sampled = batch.get_terminating_states() + rewards = self.proxy.rewards(self.env.states2proxy(x_sampled)) + fig_samples_topk = self.env.plot_samples_topk( + x_sampled, + rewards, + **plot_kwargs, + ) + else: + fig_samples_topk = None return ( l1, kl, @@ -1380,7 +1392,7 @@ def test(self, **plot_kwargs): mean_logprobs_std, mean_probs_std, logprobs_std_nll_ratio, - [fig_reward_samples, fig_kde_pred, fig_kde_true], + [fig_reward_samples, fig_kde_pred, fig_kde_true, fig_samples_topk], {}, ) From 8d733c49c8aecfa018e03ae1c73e3a59960056ce Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 22:59:00 -0400 Subject: [PATCH 56/73] Docstring --- gflownet/envs/tetris.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index 1314fbb9d..9cd2f722b 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -11,6 +11,7 @@ import numpy as np import numpy.typing as npt import torch +from matplotlib.axes import Axes from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv @@ -555,6 +556,7 @@ def plot_samples_topk( # Init figure n_cols = np.ceil(k_top / n_rows).astype(int) fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, dpi=dpi) + # Select top-k samples and plot them rewards_topk, indices_topk = torch.sort(rewards, descending=True)[:k_top] indices_topk = indices_topk.tolist() for idx, ax in zip(indices_topk, axes.flatten()): @@ -563,21 +565,18 @@ def plot_samples_topk( return fig @staticmethod - def _plot_board(board, ax, cellsize=20, linewidth=2): + def _plot_board(board, ax: Axes, cellsize: int = 20, linewidth: int = 2): """ Plots a single Tetris board (a state). - Args - ---- + Parameters + ---------- board : tensor State to plot. - - ax : matplotlib Axis - The axis in which to plot the board. - + ax : matplotlib Axes object + A matplotlib Axes object on which the board will be plotted. cellsize : int The size (length) of each board cell, in pixels. - linewidth : int The width of the separation between cells, in pixels. 
""" From 30c6b9cc5cfe6701613adceaefac833637ed9a94 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 22:59:16 -0400 Subject: [PATCH 57/73] Re-enable tests by default for Tetris --- config/env/tetris.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/config/env/tetris.yaml b/config/env/tetris.yaml index 123d6738b..67168f626 100644 --- a/config/env/tetris.yaml +++ b/config/env/tetris.yaml @@ -17,5 +17,9 @@ allow_eos_before_full: False buffer: data_path: null train: null - test: null + test: + type: random + n: 10 + output_csv: tetris_test.csv + output_pkl: tetris_test.pkl From 18f2cdeee436f221452417e20edb6edcd89be290 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 23:08:53 -0400 Subject: [PATCH 58/73] Remove references to reward_norm --- gflownet/envs/base.py | 9 --------- gflownet/gflownet.py | 3 --- 2 files changed, 12 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 83c51a809..55e388dba 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -34,8 +34,6 @@ def __init__( device: str = "cpu", float_precision: int = 32, env_id: Union[int, str] = "env", - reward_norm: float = 1.0, - reward_norm_std_mult: float = 0.0, energies_stats: List[int] = None, fixed_distr_params: Optional[dict] = None, random_distr_params: Optional[dict] = None, @@ -54,10 +52,6 @@ def __init__( self.device = set_device(device) # Float precision self.float = set_float_precision(float_precision) - # Reward settings - self.reward_norm = reward_norm - assert self.reward_norm > 0 - self.reward_norm_std_mult = reward_norm_std_mult self.energies_stats = energies_stats # Flag to skip checking if action is valid (computing mask) before step self.skip_mask_check = skip_mask_check @@ -838,9 +832,6 @@ def isclose(state_x, state_y, atol=1e-8): def set_energies_stats(self, energies_stats): self.energies_stats = energies_stats - def set_reward_norm(self, reward_norm): - self.reward_norm = reward_norm - def get_max_traj_length(self): return 1e3 diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 9c3bdce13..98904c743 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -134,9 +134,6 @@ def __init__( print(f"\tMax score: {energies_stats_tr[1]}") else: energies_stats_tr = None - if self.env.reward_norm_std_mult > 0 and energies_stats_tr is not None: - self.env.reward_norm = self.env.reward_norm_std_mult * energies_stats_tr[3] - self.env.set_reward_norm(self.env.reward_norm) # Test set statistics if self.buffer.test is not None: print("\nTest data") From 365d22c09dc4d184bf186e8522c5bb022d52d7da Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 23:11:08 -0400 Subject: [PATCH 59/73] Remove references to energies_stats --- gflownet/envs/base.py | 5 ----- gflownet/gflownet.py | 1 - 2 files changed, 6 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 55e388dba..27b1fd691 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -34,7 +34,6 @@ def __init__( device: str = "cpu", float_precision: int = 32, env_id: Union[int, str] = "env", - energies_stats: List[int] = None, fixed_distr_params: Optional[dict] = None, random_distr_params: Optional[dict] = None, skip_mask_check: bool = False, @@ -52,7 +51,6 @@ def __init__( self.device = set_device(device) # Float precision self.float = set_float_precision(float_precision) - self.energies_stats = energies_stats # Flag to skip checking if action is valid (computing mask) before step self.skip_mask_check = 
skip_mask_check # Log SoftMax function @@ -829,9 +827,6 @@ def isclose(state_x, state_y, atol=1e-8): else: return np.all(np.isclose(state_x, state_y, atol=atol)) - def set_energies_stats(self, energies_stats): - self.energies_stats = energies_stats - def get_max_traj_length(self): return 1e3 diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 98904c743..9e9484463 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -126,7 +126,6 @@ def __init__( self.buffer.std_tr, self.buffer.max_norm_tr, ] - self.env.set_energies_stats(energies_stats_tr) print("\nTrain data") print(f"\tMean score: {energies_stats_tr[2]}") print(f"\tStd score: {energies_stats_tr[3]}") From d69950e6e6f607b73e3052a5cc53d354316e3ede Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 23:15:41 -0400 Subject: [PATCH 60/73] Update env base config --- config/env/base.yaml | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/config/env/base.yaml b/config/env/base.yaml index ace90b1a7..ca36bd3d0 100644 --- a/config/env/base.yaml +++ b/config/env/base.yaml @@ -1,19 +1,17 @@ _target_: gflownet.envs.base.GFlowNetEnv -# Reward normalization for "power" reward function -reward_norm: 1.0 -# If > 0, reward_norm = reward_norm_std_mult * std(energies) -reward_norm_std_mult: 0.0 +env_id: "env" +# Policy distribution parameters +fixed_distr_params: null +random_distr_params: null # Check if action valid with mask before step skip_mask_check: False # Whether the environment has conditioning variables conditional: False # Whether the environment is continuous continuous: False -# Buffer +# Buffer: no train and test buffers by default buffer: replay_capacity: 0 - train: - path: null - test: - path: null + train: null + test: null From 7f9e13621b15fef9ca5a22cc6c6e15b3534f95eb Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 6 May 2024 23:40:25 -0400 Subject: [PATCH 61/73] Remove reward_beta and reward_func from config files --- config/experiments/ccube/corners.yaml | 1 - config/experiments/ccube/uniform.yaml | 1 - config/experiments/clatticeparams/clatticeparams_owl.yaml | 1 - config/experiments/crystals/albatross.yaml | 2 -- config/experiments/crystals/albatross_sg_first.yaml | 2 -- config/experiments/crystals/lattice_parameters.yaml | 2 -- config/experiments/crystals/pigeon.yaml | 2 -- config/experiments/icml23/ctorus.yaml | 1 - config/experiments/icml23/dtorus.yaml | 1 - config/experiments/icml23/htorus.dryrun.yaml | 1 - config/experiments/icml23/htorus.yaml | 1 - config/experiments/neurips23/crystal-comp-sg-lp.yaml | 2 -- config/experiments/scrabble/jay.yaml | 1 - config/experiments/scrabble/penguin.yaml | 1 - config/experiments/simple_tetris.yaml | 2 -- config/experiments/tree.yaml | 2 -- config/experiments/workshop23/discrete-matbench.yaml | 2 -- 17 files changed, 25 deletions(-) diff --git a/config/experiments/ccube/corners.yaml b/config/experiments/ccube/corners.yaml index ccc207c6f..cde07ed3d 100644 --- a/config/experiments/ccube/corners.yaml +++ b/config/experiments/ccube/corners.yaml @@ -28,7 +28,6 @@ env: beta_beta: 10.0 bernoulli_eos_prob: 0.1 bernoulli_bts_prob: 0.1 - reward_func: identity # GFlowNet hyperparameters gflownet: diff --git a/config/experiments/ccube/uniform.yaml b/config/experiments/ccube/uniform.yaml index a81d58d05..a7b4da059 100644 --- a/config/experiments/ccube/uniform.yaml +++ b/config/experiments/ccube/uniform.yaml @@ -28,7 +28,6 @@ env: beta_beta: 10.0 bernoulli_eos_prob: 0.1 bernoulli_bts_prob: 0.1 - reward_func: identity # 
GFlowNet hyperparameters gflownet: diff --git a/config/experiments/clatticeparams/clatticeparams_owl.yaml b/config/experiments/clatticeparams/clatticeparams_owl.yaml index 30f1c1347..94496f329 100644 --- a/config/experiments/clatticeparams/clatticeparams_owl.yaml +++ b/config/experiments/clatticeparams/clatticeparams_owl.yaml @@ -33,7 +33,6 @@ env: beta_beta: 0.01 bernoulli_source_logit: 1.0 bernoulli_eos_logit: 1.0 - reward_func: identity # GFlowNet hyperparameters gflownet: diff --git a/config/experiments/crystals/albatross.yaml b/config/experiments/crystals/albatross.yaml index db33a21c1..169ad834f 100644 --- a/config/experiments/crystals/albatross.yaml +++ b/config/experiments/crystals/albatross.yaml @@ -42,8 +42,6 @@ env: beta_beta: 10.0 bernoulli_eos_prob: 0.1 bernoulli_bts_prob: 0.1 - reward_func: boltzmann - reward_beta: 8 buffer: replay_capacity: 0 test: diff --git a/config/experiments/crystals/albatross_sg_first.yaml b/config/experiments/crystals/albatross_sg_first.yaml index 431f7a4e8..e1d43dd8f 100644 --- a/config/experiments/crystals/albatross_sg_first.yaml +++ b/config/experiments/crystals/albatross_sg_first.yaml @@ -44,8 +44,6 @@ env: beta_beta: 10.0 bernoulli_eos_prob: 0.1 bernoulli_bts_prob: 0.1 - reward_func: boltzmann - reward_beta: 8 buffer: replay_capacity: 0 test: diff --git a/config/experiments/crystals/lattice_parameters.yaml b/config/experiments/crystals/lattice_parameters.yaml index b224cb4e9..f8cfdb6a5 100644 --- a/config/experiments/crystals/lattice_parameters.yaml +++ b/config/experiments/crystals/lattice_parameters.yaml @@ -9,8 +9,6 @@ defaults: # Environment env: - reward_func: boltzmann - reward_beta: 0.3 buffer: replay_capacity: 1000 diff --git a/config/experiments/crystals/pigeon.yaml b/config/experiments/crystals/pigeon.yaml index 880647ce1..293741f47 100644 --- a/config/experiments/crystals/pigeon.yaml +++ b/config/experiments/crystals/pigeon.yaml @@ -46,8 +46,6 @@ env: beta_beta: 10.0 bernoulli_eos_prob: 0.1 bernoulli_bts_prob: 0.1 - reward_func: boltzmann - reward_beta: 8 buffer: replay_capacity: 0 test: diff --git a/config/experiments/icml23/ctorus.yaml b/config/experiments/icml23/ctorus.yaml index 76e0c0b5a..8f5753da3 100644 --- a/config/experiments/icml23/ctorus.yaml +++ b/config/experiments/icml23/ctorus.yaml @@ -12,7 +12,6 @@ env: n_dim: 2 length_traj: 10 n_comp: 5 - reward_func: identity vonmises_min_concentration: 4 policy_encoding_dim_per_angle: 10 diff --git a/config/experiments/icml23/dtorus.yaml b/config/experiments/icml23/dtorus.yaml index cab31e90f..74be66e1d 100644 --- a/config/experiments/icml23/dtorus.yaml +++ b/config/experiments/icml23/dtorus.yaml @@ -12,7 +12,6 @@ env: n_dim: 2 n_angles: 20 length_traj: 20 - reward_func: identity # Proxy proxy: diff --git a/config/experiments/icml23/htorus.dryrun.yaml b/config/experiments/icml23/htorus.dryrun.yaml index e708cec84..b8d4dfb80 100644 --- a/config/experiments/icml23/htorus.dryrun.yaml +++ b/config/experiments/icml23/htorus.dryrun.yaml @@ -11,7 +11,6 @@ defaults: env: n_dim: 2 length_traj: 5 - reward_func: identity # Proxy proxy: diff --git a/config/experiments/icml23/htorus.yaml b/config/experiments/icml23/htorus.yaml index dd31912c8..51f2793f3 100644 --- a/config/experiments/icml23/htorus.yaml +++ b/config/experiments/icml23/htorus.yaml @@ -11,7 +11,6 @@ defaults: env: n_dim: 2 length_traj: 20 - reward_func: identity # Proxy proxy: diff --git a/config/experiments/neurips23/crystal-comp-sg-lp.yaml b/config/experiments/neurips23/crystal-comp-sg-lp.yaml index 278627769..6a7e16927 100644 
--- a/config/experiments/neurips23/crystal-comp-sg-lp.yaml +++ b/config/experiments/neurips23/crystal-comp-sg-lp.yaml @@ -16,8 +16,6 @@ env: grid_size: 10 composition_kwargs: elements: 89 - reward_func: boltzmann - reward_beta: 1 # GFlowNet hyperparameters gflownet: diff --git a/config/experiments/scrabble/jay.yaml b/config/experiments/scrabble/jay.yaml index 116595238..f6d7d510e 100644 --- a/config/experiments/scrabble/jay.yaml +++ b/config/experiments/scrabble/jay.yaml @@ -20,7 +20,6 @@ env: n: 1000 output_csv: scrabble_test.csv output_pkl: scrabble_test.pkl - reward_func: identity # Proxy proxy: diff --git a/config/experiments/scrabble/penguin.yaml b/config/experiments/scrabble/penguin.yaml index 8110dab92..09a1792b3 100644 --- a/config/experiments/scrabble/penguin.yaml +++ b/config/experiments/scrabble/penguin.yaml @@ -20,7 +20,6 @@ env: n: 1000 output_csv: scrabble_test.csv output_pkl: scrabble_test.pkl - reward_func: identity # Proxy proxy: diff --git a/config/experiments/simple_tetris.yaml b/config/experiments/simple_tetris.yaml index 716169dae..56255eb73 100644 --- a/config/experiments/simple_tetris.yaml +++ b/config/experiments/simple_tetris.yaml @@ -8,8 +8,6 @@ defaults: - override /logger: wandb env: - reward_func: boltzmann - reward_beta: 10.0 width: 4 height: 4 pieces: ["I", "O", "J", "L", "T"] diff --git a/config/experiments/tree.yaml b/config/experiments/tree.yaml index 94451698d..582b5bd06 100644 --- a/config/experiments/tree.yaml +++ b/config/experiments/tree.yaml @@ -14,8 +14,6 @@ env: continuous: False policy_format: mlp threshold_components: 3 - reward_func: boltzmann - reward_beta: 32 test_args: top_k_trees: 100 buffer: diff --git a/config/experiments/workshop23/discrete-matbench.yaml b/config/experiments/workshop23/discrete-matbench.yaml index a68b0c367..d05da18ff 100644 --- a/config/experiments/workshop23/discrete-matbench.yaml +++ b/config/experiments/workshop23/discrete-matbench.yaml @@ -18,8 +18,6 @@ env: grid_size: 10 composition_kwargs: elements: [1,3,4,5,6,7,8,9,11,12,13,14,15,16,17,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,89,90,91,92,93,94] - reward_func: boltzmann - reward_beta: 1 buffer: replay_capacity: 0 From 4267acba800ddba06cd59b02839bd90ba016d541 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 7 May 2024 00:34:14 -0400 Subject: [PATCH 62/73] Update proxies: outputs are not negative by default anymore; Restore proxy2reward and beta values in configs but as proxy config --- config/experiments/crystals/albatross.yaml | 5 +++++ config/experiments/crystals/albatross_sg_first.yaml | 5 +++++ config/experiments/crystals/lattice_parameters.yaml | 5 +++++ config/experiments/crystals/pigeon.yaml | 5 +++++ config/experiments/neurips23/crystal-comp-sg-lp.yaml | 5 +++++ config/experiments/simple_tetris.yaml | 6 +++++- config/experiments/tree.yaml | 5 +++++ config/experiments/workshop23/discrete-matbench.yaml | 5 +++++ config/proxy/base.yaml | 6 +++--- gflownet/proxy/base.py | 2 +- gflownet/proxy/corners.py | 3 +-- gflownet/proxy/scrabble.py | 2 +- gflownet/proxy/tetris.py | 4 ++-- gflownet/proxy/torus.py | 8 ++++---- gflownet/proxy/uniform.py | 4 ++-- 15 files changed, 54 insertions(+), 16 deletions(-) diff --git a/config/experiments/crystals/albatross.yaml b/config/experiments/crystals/albatross.yaml index 169ad834f..db47fc36b 100644 --- a/config/experiments/crystals/albatross.yaml +++ 
b/config/experiments/crystals/albatross.yaml @@ -50,6 +50,11 @@ env: output_csv: ccrystal_val.csv output_pkl: ccrystal_val.pkl +# Proxy +proxy: + reward_function: exponential + beta: 8 + # GFlowNet hyperparameters gflownet: random_action_prob: 0.1 diff --git a/config/experiments/crystals/albatross_sg_first.yaml b/config/experiments/crystals/albatross_sg_first.yaml index e1d43dd8f..d79cb2f10 100644 --- a/config/experiments/crystals/albatross_sg_first.yaml +++ b/config/experiments/crystals/albatross_sg_first.yaml @@ -52,6 +52,11 @@ env: output_csv: ccrystal_val.csv output_pkl: ccrystal_val.pkl +# Proxy +proxy: + reward_function: exponential + beta: 8 + # GFlowNet hyperparameters gflownet: random_action_prob: 0.1 diff --git a/config/experiments/crystals/lattice_parameters.yaml b/config/experiments/crystals/lattice_parameters.yaml index f8cfdb6a5..f569fd7f9 100644 --- a/config/experiments/crystals/lattice_parameters.yaml +++ b/config/experiments/crystals/lattice_parameters.yaml @@ -12,6 +12,11 @@ env: buffer: replay_capacity: 1000 +# Proxy +proxy: + reward_function: exponential + beta: 0.3 + # GFlowNet hyperparameters gflownet: random_action_prob: 0.1 diff --git a/config/experiments/crystals/pigeon.yaml b/config/experiments/crystals/pigeon.yaml index 293741f47..2cdb56581 100644 --- a/config/experiments/crystals/pigeon.yaml +++ b/config/experiments/crystals/pigeon.yaml @@ -54,6 +54,11 @@ env: output_csv: ccrystal_val.csv output_pkl: ccrystal_val.pkl +# Proxy +proxy: + reward_function: exponential + beta: 8 + # GFlowNet hyperparameters gflownet: random_action_prob: 0.1 diff --git a/config/experiments/neurips23/crystal-comp-sg-lp.yaml b/config/experiments/neurips23/crystal-comp-sg-lp.yaml index 6a7e16927..d5f6e7b5a 100644 --- a/config/experiments/neurips23/crystal-comp-sg-lp.yaml +++ b/config/experiments/neurips23/crystal-comp-sg-lp.yaml @@ -17,6 +17,11 @@ env: composition_kwargs: elements: 89 +# Proxy +proxy: + reward_function: exponential + beta: 1 + # GFlowNet hyperparameters gflownet: random_action_prob: 0.1 diff --git a/config/experiments/simple_tetris.yaml b/config/experiments/simple_tetris.yaml index 56255eb73..7c8795d3c 100644 --- a/config/experiments/simple_tetris.yaml +++ b/config/experiments/simple_tetris.yaml @@ -20,6 +20,10 @@ env: output_pkl: simple_tetris_val.pkl n: 100 +proxy: + reward_function: exponential + beta: 10 + gflownet: random_action_prob: 0.3 optimizer: @@ -42,4 +46,4 @@ device: cpu logger: do: online: True - project_name: simple_tetris \ No newline at end of file + project_name: simple_tetris diff --git a/config/experiments/tree.yaml b/config/experiments/tree.yaml index 582b5bd06..9b8feed35 100644 --- a/config/experiments/tree.yaml +++ b/config/experiments/tree.yaml @@ -19,6 +19,11 @@ env: buffer: replay_capacity: 100 +# Proxy +proxy: + reward_function: exponential + beta: 32 + # GFlowNet hyperparameters gflownet: random_action_prob: 0.1 diff --git a/config/experiments/workshop23/discrete-matbench.yaml b/config/experiments/workshop23/discrete-matbench.yaml index d05da18ff..a865bc092 100644 --- a/config/experiments/workshop23/discrete-matbench.yaml +++ b/config/experiments/workshop23/discrete-matbench.yaml @@ -21,6 +21,11 @@ env: buffer: replay_capacity: 0 +# Proxy +proxy: + reward_function: exponential + beta: 1 + # GFlowNet hyperparameters gflownet: random_action_prob: 0.1 diff --git a/config/proxy/base.yaml b/config/proxy/base.yaml index 1dd2b1f5f..d60975823 100644 --- a/config/proxy/base.yaml +++ b/config/proxy/base.yaml @@ -1,14 +1,14 @@ _target_: 
gflownet.proxy.base.Proxy # Reward function: string identifier of the proxy-to-reward function: -# - identity -# - absolute (default) +# - identity (default) +# - absolute # - power # - exponential # - shift # - product # Alternatively, it can be a callable of the function itself. -reward_function: absolute +reward_function: identity # A callable of the proxy-to-logreward function. # None by default, which takes the log of the proxy-to-reward function logreward_function: null diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index d93de928f..031f65c7e 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -24,7 +24,7 @@ def __init__( self, device, float_precision, - reward_function: Optional[Union[Callable, str]] = "absolute", + reward_function: Optional[Union[Callable, str]] = "identity", logreward_function: Optional[Callable] = None, reward_function_kwargs: Optional[dict] = {}, reward_min: float = 0.0, diff --git a/gflownet/proxy/corners.py b/gflownet/proxy/corners.py index e5e4e0a57..f07e275ea 100644 --- a/gflownet/proxy/corners.py +++ b/gflownet/proxy/corners.py @@ -41,8 +41,7 @@ def min(self): def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]: return ( - -1.0 - * self.mulnormal_norm + self.mulnormal_norm * torch.exp( -0.5 * ( diff --git a/gflownet/proxy/scrabble.py b/gflownet/proxy/scrabble.py index 9ac9afbd0..5f08b3a48 100644 --- a/gflownet/proxy/scrabble.py +++ b/gflownet/proxy/scrabble.py @@ -93,7 +93,7 @@ def __call__( ): scores.append(0.0) else: - scores.append(-1.0 * self._sum_scores(sample)) + scores.append(self._sum_scores(sample)) return tfloat(scores, device=self.device, float_type=self.float) else: raise NotImplementedError( diff --git a/gflownet/proxy/tetris.py b/gflownet/proxy/tetris.py index 885091ee0..93f8e3341 100644 --- a/gflownet/proxy/tetris.py +++ b/gflownet/proxy/tetris.py @@ -17,9 +17,9 @@ def setup(self, env=None): @property def norm(self): if self.normalize: - return -(self.height * self.width) + return (self.height * self.width) else: - return -1.0 + return 1.0 def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]: if states.dim() == 2: diff --git a/gflownet/proxy/torus.py b/gflownet/proxy/torus.py index 89d3ba44c..6b78c9aad 100644 --- a/gflownet/proxy/torus.py +++ b/gflownet/proxy/torus.py @@ -19,10 +19,10 @@ def setup(self, env=None): def min(self): if not hasattr(self, "_min"): if self.normalize: - self._min = torch.tensor(-1.0, device=self.device, dtype=self.float) + self._min = torch.tensor(0.0, device=self.device, dtype=self.float) else: self._min = torch.tensor( - -((self.n_dim * 2) ** 3), device=self.device, dtype=self.float + ((self.n_dim * 2) ** 3), device=self.device, dtype=self.float ) return self._min @@ -31,10 +31,10 @@ def norm(self): if not hasattr(self, "_norm"): if self.normalize: self._norm = torch.tensor( - -((self.n_dim * 2) ** 3), device=self.device, dtype=self.float + ((self.n_dim * 2) ** 3), device=self.device, dtype=self.float ) else: - self._norm = torch.tensor(-1.0, device=self.device, dtype=self.float) + self._norm = torch.tensor(1.0, device=self.device, dtype=self.float) return self._norm def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]: diff --git a/gflownet/proxy/uniform.py b/gflownet/proxy/uniform.py index e7a17fd76..d9c6d6b0a 100644 --- a/gflownet/proxy/uniform.py +++ b/gflownet/proxy/uniform.py @@ -13,10 +13,10 @@ def __init__(self, **kwargs): def __call__( self, states: Union[List, TensorType["batch", 
"state_dim"]] ) -> TensorType["batch"]: - return -1.0 * torch.ones(len(states), device=self.device, dtype=self.float) + return torch.ones(len(states), device=self.device, dtype=self.float) @property def min(self): if not hasattr(self, "_min"): - self._min = torch.tensor(-1.0, device=self.device, dtype=self.float) + self._min = torch.tensor(1.0, device=self.device, dtype=self.float) return self._min From 7e9cc44b3220c5162f84efa59ef91c3fc0f26929 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 7 May 2024 00:46:04 -0400 Subject: [PATCH 63/73] black --- gflownet/proxy/corners.py | 21 +++++++++------------ gflownet/proxy/tetris.py | 2 +- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/gflownet/proxy/corners.py b/gflownet/proxy/corners.py index f07e275ea..839d30a09 100644 --- a/gflownet/proxy/corners.py +++ b/gflownet/proxy/corners.py @@ -40,19 +40,16 @@ def min(self): return self._min def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]: - return ( - self.mulnormal_norm - * torch.exp( - -0.5 - * ( - torch.diag( + return self.mulnormal_norm * torch.exp( + -0.5 + * ( + torch.diag( + torch.tensordot( torch.tensordot( - torch.tensordot( - (torch.abs(states) - self.mu_vec), self.cov_inv, dims=1 - ), - (torch.abs(states) - self.mu_vec).T, - dims=1, - ) + (torch.abs(states) - self.mu_vec), self.cov_inv, dims=1 + ), + (torch.abs(states) - self.mu_vec).T, + dims=1, ) ) ) diff --git a/gflownet/proxy/tetris.py b/gflownet/proxy/tetris.py index 93f8e3341..65216bd05 100644 --- a/gflownet/proxy/tetris.py +++ b/gflownet/proxy/tetris.py @@ -17,7 +17,7 @@ def setup(self, env=None): @property def norm(self): if self.normalize: - return (self.height * self.width) + return self.height * self.width else: return 1.0 From 10e5d7830d8fc95998c839e2f9a142d42c898d27 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 7 May 2024 10:46:51 -0400 Subject: [PATCH 64/73] Fix scrabble proxy and tests --- gflownet/proxy/scrabble.py | 8 +++++--- tests/gflownet/proxy/test_scrabble_proxy.py | 6 +----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/gflownet/proxy/scrabble.py b/gflownet/proxy/scrabble.py index 5f08b3a48..9a6d82828 100644 --- a/gflownet/proxy/scrabble.py +++ b/gflownet/proxy/scrabble.py @@ -72,7 +72,7 @@ def __call__( Returns ------- - A vector with the (negative) score of each sequence in the batch. + A vector with the score of each sequence in the batch. 
""" if torch.is_tensor(states): output = torch.zeros(states.shape[0], device=self.device, dtype=self.float) @@ -80,8 +80,10 @@ def __call__( is_in_vocabulary = self._is_in_vocabulary(states) else: is_in_vocabulary = torch.ones_like(output, dtype=torch.bool) - output[is_in_vocabulary] = -1.0 * self.scores[states[is_in_vocabulary]].sum( - dim=1 + output[is_in_vocabulary] = tfloat( + self.scores[states[is_in_vocabulary]].sum(dim=1), + float_type=self.float, + device=self.device, ) return output elif isinstance(states, list): diff --git a/tests/gflownet/proxy/test_scrabble_proxy.py b/tests/gflownet/proxy/test_scrabble_proxy.py index aa00f162f..fcb9af4ff 100644 --- a/tests/gflownet/proxy/test_scrabble_proxy.py +++ b/tests/gflownet/proxy/test_scrabble_proxy.py @@ -32,8 +32,6 @@ def env(): def test__scrabble_scorer__returns_expected_scores_list_input_list_tokens( env, proxy, samples, scores_expected ): - # Make scores expected negative - scores_expected = [-s for s in scores_expected] proxy.setup(env) scores = proxy(samples) assert scores.tolist() == scores_expected @@ -51,8 +49,6 @@ def test__scrabble_scorer__returns_expected_scores_list_input_list_tokens( def test__scrabble_scorer__returns_expected_scores_input_list_strings( env, proxy, samples, scores_expected ): - # Make scores expected negative - scores_expected = [-s for s in scores_expected] proxy.setup(env) scores = proxy(samples) assert scores.tolist() == scores_expected @@ -94,4 +90,4 @@ def test__scrabble_scorer__returns_expected_scores_input_state2proxy( env.set_state(env.readable2state(sample)) sample_proxy = env.state2proxy() score = proxy(sample_proxy) - assert score.tolist() == [-1.0 * score_expected] + assert score.tolist() == [score_expected] From 361cdc0f6ef4268f7f3f1c3046f07428e8539633 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Thu, 16 May 2024 22:32:14 -0400 Subject: [PATCH 65/73] Fix sanity check runs config --- mila/dev/sanity_check_runs.yaml | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mila/dev/sanity_check_runs.yaml b/mila/dev/sanity_check_runs.yaml index 6c70471e4..94b15c5c4 100644 --- a/mila/dev/sanity_check_runs.yaml +++ b/mila/dev/sanity_check_runs.yaml @@ -17,6 +17,7 @@ jobs: __value__: grid length: 10 gflownet: flowmatch + proxy: corners - slurm: job_name: sanity-grid-tb script: @@ -24,6 +25,7 @@ jobs: __value__: grid length: 10 gflownet: trajectorybalance + proxy: corners - slurm: job_name: sanity-grid-db script: @@ -31,6 +33,7 @@ jobs: __value__: grid length: 10 gflownet: detailedbalance + proxy: corners - slurm: job_name: sanity-grid-fl script: @@ -38,6 +41,7 @@ jobs: __value__: grid length: 10 gflownet: forwardlooking + proxy: corners # Tetris - slurm: job_name: sanity-tetris-fm @@ -68,9 +72,9 @@ jobs: pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True proxy: + __value__: tetris reward_function: exponential gflownet: flowmatch - proxy: tetris - slurm: job_name: sanity-mintetris-tb script: @@ -81,9 +85,9 @@ jobs: pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True proxy: + __value__: tetris reward_function: exponential gflownet: trajectorybalance - proxy: tetris - slurm: job_name: sanity-mintetris-fl script: @@ -94,9 +98,9 @@ jobs: pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True proxy: + __value__: tetris reward_function: exponential gflownet: forwardlooking - proxy: tetris # Ctorus - slurm: job_name: sanity-ctorus From 4c1b1de8a8d5d52b387694dacc616a93398e47a3 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Thu, 16 May 2024 22:32:23 -0400 
Subject: [PATCH 66/73] Docstring of base Proxy --- gflownet/proxy/base.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 031f65c7e..019e9f118 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -16,10 +16,6 @@ class Proxy(ABC): - """ - Generic proxy class - """ - def __init__( self, device, @@ -31,6 +27,38 @@ def __init__( do_clip_rewards: bool = False, **kwargs, ): + r""" + Base Proxy class for GFlowNet proxies. + + A proxy is the input to a reward function. Depending on the + ``reward_function``, the reward may be directly the output of the proxy or a + function of it. + + Arguments + --------- + device : str or torch.device + The device to be passed to torch tensors. + float_precision : int or torch.dtype + The floating point precision to be passed to torch tensors. + reward_function : str or Callable + The transformation applied to the proxy outputs to obtain a GFlowNet + reward. See :py:meth:`Proxy._get_reward_functions`. + logreward_function : Callable + The transformation applied to the proxy outputs to obtain a GFlowNet + log reward. See :meth:`Proxy._get_reward_functions`. If None (default), the + log of the reward function is used. The Callable may be used to improve the + numerical stability of the transformation. + reward_function_kwargs : dict + A dictionary of arguments to be passed to the reward function. + reward_min : float + The minimum value allowed for rewards, 0.0 by default, which results in a + minimum log reward of :py:const:`LOGZERO`. Note that certain loss + functions, for example the Forward Looking loss may not work as desired if + the minimum reward is 0.0. It may be set to a small (positive) value close + to zero in order to prevent numerical stability issues. + do_clip_rewards : bool + Whether to clip the rewards according to the minimum value. + """ # Proxy to reward function self.reward_function = reward_function self.logreward_function = logreward_function From c9879de9ca370d93f6d21b333340c7560c98f98c Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 21 May 2024 10:58:05 -0400 Subject: [PATCH 67/73] Set efault number of grid points for reward density to 40000; Fix typo --- config/logger/base.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/logger/base.yaml b/config/logger/base.yaml index 7d3a6cf2e..33977260a 100644 --- a/config/logger/base.yaml +++ b/config/logger/base.yaml @@ -26,8 +26,8 @@ test: logprobs_bootstrap_size: 10000 # Maximum number of test data points to compute log likelihood probs. max_data_logprobs: 1e5 - # Number of points tor obtain a grid to estimate the reward density - n_grid: 40401 + # Number of points to obtain a grid to estimate the reward density + n_grid: 40000 # Oracle metrics oracle: period: 100000 From 54ffaf1b1412c0f4b6368549e8c708156bf65eb5 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 21 May 2024 11:17:26 -0400 Subject: [PATCH 68/73] Documentation about arguments of proxy-to-reward functions --- config/proxy/base.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/config/proxy/base.yaml b/config/proxy/base.yaml index d60975823..c108748a6 100644 --- a/config/proxy/base.yaml +++ b/config/proxy/base.yaml @@ -12,8 +12,11 @@ reward_function: identity # A callable of the proxy-to-logreward function. 
# None by default, which takes the log of the proxy-to-reward function logreward_function: null -# Arguments of the proxy-to-reward function. -# The default functions use an argument with key beta +# Arguments of the proxy-to-reward function (beta): +# - power: R(x) = x ** beta +# - exponential: R(x) = exp(x * beta) +# - shift: R(x) = x + beta +# - product: R(x) = x * beta reward_function_kwargs: {} # Minimum reward. Used to clip the rewards. reward_min: 0.0 From 73a874974ff60e27de9b99a20520c077511b2796 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 21 May 2024 11:29:19 -0400 Subject: [PATCH 69/73] Revert test period in icml torus config file and apply it in sanity checks configs --- config/experiments/icml23/ctorus.yaml | 2 +- mila/dev/sanity_check_runs.md | 6 +++--- mila/dev/sanity_check_runs.yaml | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/config/experiments/icml23/ctorus.yaml b/config/experiments/icml23/ctorus.yaml index 8f5753da3..2aac326b7 100644 --- a/config/experiments/icml23/ctorus.yaml +++ b/config/experiments/icml23/ctorus.yaml @@ -55,7 +55,7 @@ logger: - continuous - ctorus test: - period: 500 + period: 25 n: 1000 checkpoints: period: 500 diff --git a/mila/dev/sanity_check_runs.md b/mila/dev/sanity_check_runs.md index 3d757643a..8b36b4799 100644 --- a/mila/dev/sanity_check_runs.md +++ b/mila/dev/sanity_check_runs.md @@ -116,17 +116,17 @@ python mila/launch.py --conda_env= user=$USER env=tetris proxy=t `salloc`: ```bash -python main.py user=$USER +experiments=icml23/ctorus device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python main.py user=$USER +experiments=icml23/ctorus logger.test.period=500 device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True ``` `sbatch` with `virtualenv`: ```bash -python mila/launch.py --venv= --template=mila/sbatch/template-venv.sh user=$USER +experiments=icml23/ctorus device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python mila/launch.py --venv= --template=mila/sbatch/template-venv.sh user=$USER +experiments=icml23/ctorus logger.test.period=500 device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True ``` `sbatch` with `conda`: ```bash -python mila/launch.py --conda_env= user=$USER +experiments=icml23/ctorus device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True +python mila/launch.py --conda_env= user=$USER +experiments=icml23/ctorus logger.test.period=500 device=cpu logger.project_name=gfn_sanity_checks logger.do.online=True ``` diff --git a/mila/dev/sanity_check_runs.yaml b/mila/dev/sanity_check_runs.yaml index 94b15c5c4..cad19cf90 100644 --- a/mila/dev/sanity_check_runs.yaml +++ b/mila/dev/sanity_check_runs.yaml @@ -106,3 +106,6 @@ jobs: job_name: sanity-ctorus script: +experiments: icml23/ctorus + logger: + test: + period: 500 From fa40c51321b39340137d3a08250d7289a60a3f4a Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 21 May 2024 12:41:07 -0400 Subject: [PATCH 70/73] Extend documentation and use of rewards and sample_rewards in plotting methods --- gflownet/envs/cube.py | 6 +++++- gflownet/envs/grid.py | 13 +++++++++---- gflownet/envs/htorus.py | 16 +++++++++------- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index bb57407c6..9f85f0605 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -1467,7 +1467,11 @@ def plot_reward_samples( reward has been obtained. These samples are used to plot the contour of reward density. 
         rewards : tensor
-            The reward of samples_reward.
+            The rewards of samples_reward. It should be a vector of dimensionality
+            n_per_dim ** 2 and be sorted such that each block at rewards[i *
+            n_per_dim:i * n_per_dim + n_per_dim] corresponds to the rewards at the i-th
+            row of the grid of samples, from top to bottom. The same is assumed for
+            samples_reward.
         alpha : float
             Transparency of the reward contour.
         dpi : int
diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py
index b7426552c..9b1055079 100644
--- a/gflownet/envs/grid.py
+++ b/gflownet/envs/grid.py
@@ -337,6 +337,9 @@ def plot_reward_samples(
         Plots the reward density as a 2D histogram on the grid, alongside a histogram
         representing the samples density.
 
+        It is assumed that the rewards correspond to the entire domain of the grid and
+        are sorted from left to right (first) and top to bottom of the grid of samples.
+
         Parameters
         ----------
         samples : tensor
@@ -344,10 +347,12 @@ def plot_reward_samples(
             will be plotted on top of the reward density.
         samples_reward : tensor
             A batch of samples containing a grid over the sample space, from which the
-            reward has been obtained. These samples are used to plot the contour of
-            reward density.
+            reward has been obtained. Ignored by this method.
         rewards : tensor
-            The reward of samples_reward.
+            The rewards of samples_reward. It should be a vector of dimensionality
+            length ** 2 and be sorted such that each block at rewards[i *
+            length:i * length + length] corresponds to the rewards at the i-th
+            row of the grid of samples, from top to bottom.
         dpi : int
             Dots per inch, indicating the resolution of the plot.
         n_ticks_max : int
@@ -359,8 +364,8 @@ def plot_reward_samples(
         if self.n_dim != 2:
             return None
         samples = torch2np(samples)
-        samples_reward = torch2np(samples_reward)
         rewards = torch2np(rewards)
+        assert rewards.shape[0] == self.length**2
         # Init figure
         fig, axes = plt.subplots(ncols=2, dpi=dpi)
         step_ticks = np.ceil(self.length / n_ticks_max).astype(int)
diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py
index 7b1007da7..f5fd32fbf 100644
--- a/gflownet/envs/htorus.py
+++ b/gflownet/envs/htorus.py
@@ -577,7 +577,8 @@ def plot_reward_samples(
         Plots the reward contour alongside a batch of samples.
 
         The samples are previously augmented in order to visualise the periodic aspect
-        of the sample space. It is assumed that the samples and the rewards are sorted.
+        of the sample space. It is assumed that the rewards are sorted from left to
+        right (first) and top to bottom of the grid of samples.
 
         Parameters
         ----------
@@ -586,10 +587,12 @@ def plot_reward_samples(
             will be plotted on top of the reward density.
         samples_reward : tensor
             A batch of samples containing a grid over the sample space, from which the
-            reward has been obtained. These samples are used to plot the contour of
-            reward density.
+            reward has been obtained. Ignored by this method.
         rewards : tensor
-            The reward of samples_reward.
+            The rewards of samples_reward. It should be a vector of dimensionality
+            n_per_dim ** 2 and be sorted such that each block at rewards[i *
+            n_per_dim:i * n_per_dim + n_per_dim] corresponds to the rewards at the i-th
+            row of the grid of samples, from top to bottom.
         min_domain : float
             Minimum value of the domain to keep in the plot.
max_domain : float @@ -604,10 +607,9 @@ def plot_reward_samples( if self.n_dim != 2: return None samples = torch2np(samples) - samples_reward = torch2np(samples_reward) rewards = torch2np(rewards) - n_per_dim = int(np.sqrt(samples_reward.shape[0])) - assert n_per_dim**2 == samples_reward.shape[0] + n_per_dim = int(np.sqrt(rewards.shape[0])) + assert n_per_dim**2 == rewards.shape[0] # Augment rewards to apply periodic boundary conditions rewards = rewards.reshape((n_per_dim, n_per_dim)) rewards = np.tile(rewards, (3, 3)) From a9ef286c91c8c72656f9e246d0e5aff2f5af091b Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 21 May 2024 12:46:23 -0400 Subject: [PATCH 71/73] Fix fixture in tests of base proxy --- tests/gflownet/proxy/test_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py index 4633adacc..270228ae2 100644 --- a/tests/gflownet/proxy/test_base.py +++ b/tests/gflownet/proxy/test_base.py @@ -15,8 +15,7 @@ def uniform(): @pytest.fixture() def proxy_identity(beta): return Uniform( - reward_function="power", - reward_function_kwargs={"beta": beta}, + reward_function="identity", device="cpu", float_precision=32, ) From abe50cba5e07a3f5e5a9a1d30de270981800fcf5 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 21 May 2024 12:49:07 -0400 Subject: [PATCH 72/73] Extend tests of identity reward_function with another value of beta --- tests/gflownet/proxy/test_base.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py index 270228ae2..e142dca77 100644 --- a/tests/gflownet/proxy/test_base.py +++ b/tests/gflownet/proxy/test_base.py @@ -150,6 +150,37 @@ def check_proxy2reward(rewards_computed, rewards_expected, atol=1e-3): 4.6052, ], ), + ( + 2, + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [-100, -10, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1, 10, 100], + [ + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + -np.inf, + -2.3025, + -0.6931, + 0.0, + 2.3025, + 4.6052, + ], + [ + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + LOGZERO, + -2.3025, + -0.6931, + 0.0, + 2.3025, + 4.6052, + ], + ), ], ) def test_reward_function_identity__behaves_as_expected( From a1462f1be0ba9df84aefe592278bb42292b571c1 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 21 May 2024 15:54:58 -0400 Subject: [PATCH 73/73] Implement functionality to handle max_reward --- gflownet/gflownet.py | 2 +- gflownet/proxy/base.py | 48 ++++++++++++++++++++++++++++++- gflownet/proxy/corners.py | 8 +++--- gflownet/proxy/torus.py | 10 +++---- gflownet/proxy/uniform.py | 7 +---- tests/gflownet/proxy/test_base.py | 26 +++++++++++++++++ 6 files changed, 84 insertions(+), 17 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 9e9484463..42ce57b53 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1538,7 +1538,7 @@ def sample_from_reward( format. """ samples_final = [] - max_reward = self.proxy.proxy2reward(self.proxy.min) + max_reward = self.get_max_reward() while len(samples_final) < n_samples: if proposal_distribution == "uniform": # TODO: sample only the remaining number of samples diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 019e9f118..14f5ac268 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -177,13 +177,59 @@ def get_min_reward(self, log: bool = False) -> float: Returns ------- float - The mimnimum (log) reward. 
+ The minimum (log) reward. """ if log: return self.logreward_min else: return self.reward_min + def get_max_reward(self, log: bool = False) -> float: + """ + Returns the maximum value of the (log) reward, retrieved from self.optimum, in + case it is defined. + + Parameters + ---------- + log : bool + If True, returns the logarithm of the maximum reward. If False (default), + returns the natural maximum reward. + + Returns + ------- + float + The maximum (log) reward. + """ + if log: + return self.proxy2logreward(self.optimum) + else: + return self.proxy2reward(self.optimum) + + @property + def optimum(self): + """ + Returns the optimum value of the proxy. + + Not implemented by default but may be implemented for synthetic proxies or when + the optimum is known. + + The optimum is used, for example, to accelerate rejection sampling, to sample + from the reward function. + """ + if not hasattr(self, "_optimum"): + raise NotImplementedError( + "The optimum value of the proxy needs to be implemented explicitly for " + f"each Proxy and is not available for {self.__class__}." + ) + return self._optimum + + @optimum.setter + def optimum(self, value): + """ + Sets the optimum value of the proxy. + """ + self._optimum = value + def _get_reward_functions( self, reward_function: Union[Callable, str], diff --git a/gflownet/proxy/corners.py b/gflownet/proxy/corners.py index 839d30a09..e1db911f3 100644 --- a/gflownet/proxy/corners.py +++ b/gflownet/proxy/corners.py @@ -31,13 +31,13 @@ def setup(self, env=None): self.mulnormal_norm = 1.0 / ((2 * torch.pi) ** 2 * cov_det) ** 0.5 @property - def min(self): - if not hasattr(self, "_min"): + def optimum(self): + if not hasattr(self, "_optimum"): mode = self.mu * torch.ones( self.n_dim, device=self.device, dtype=self.float ) - self._min = self(torch.unsqueeze(mode, 0))[0] - return self._min + self._optimum = self(torch.unsqueeze(mode, 0))[0] + return self._optimum def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]: return self.mulnormal_norm * torch.exp( diff --git a/gflownet/proxy/torus.py b/gflownet/proxy/torus.py index 6b78c9aad..7dbd5b0e8 100644 --- a/gflownet/proxy/torus.py +++ b/gflownet/proxy/torus.py @@ -16,15 +16,15 @@ def setup(self, env=None): self.n_dim = env.n_dim @property - def min(self): - if not hasattr(self, "_min"): + def optimum(self): + if not hasattr(self, "_optimum"): if self.normalize: - self._min = torch.tensor(0.0, device=self.device, dtype=self.float) + self._optimum = torch.tensor(1.0, device=self.device, dtype=self.float) else: - self._min = torch.tensor( + self._optimum = torch.tensor( ((self.n_dim * 2) ** 3), device=self.device, dtype=self.float ) - return self._min + return self._optimum @property def norm(self): diff --git a/gflownet/proxy/uniform.py b/gflownet/proxy/uniform.py index d9c6d6b0a..659612805 100644 --- a/gflownet/proxy/uniform.py +++ b/gflownet/proxy/uniform.py @@ -9,14 +9,9 @@ class Uniform(Proxy): def __init__(self, **kwargs): super().__init__(**kwargs) + self._optimum = torch.tensor(1.0, device=self.device, dtype=self.float) def __call__( self, states: Union[List, TensorType["batch", "state_dim"]] ) -> TensorType["batch"]: return torch.ones(len(states), device=self.device, dtype=self.float) - - @property - def min(self): - if not hasattr(self, "_min"): - self._min = torch.tensor(1.0, device=self.device, dtype=self.float) - return self._min diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py index e142dca77..812b158ee 100644 --- 
a/tests/gflownet/proxy/test_base.py
+++ b/tests/gflownet/proxy/test_base.py
@@ -623,3 +623,29 @@ def test_reward_function_callable__behaves_as_expected(
     # Log Rewards
     logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float)
     assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp))
+
+
+@pytest.mark.parametrize(
+    "proxy, beta, optimum, reward_max",
+    [
+        ("uniform", None, 1.0, 1.0),
+        ("uniform", None, 2.0, 2.0),
+        ("proxy_power", 1, 2.0, 2.0),
+        ("proxy_power", 2, 2.0, 4.0),
+        ("proxy_exponential", 1, 1.0, np.exp(1.0)),
+        ("proxy_exponential", -1, -1.0, np.exp(1.0)),
+        ("proxy_shift", 5, 10.0, 15.0),
+        ("proxy_shift", -5, 10.0, 5.0),
+        ("proxy_product", 2, 2.0, 4.0),
+        ("proxy_product", -2, -5.0, 10.0),
+    ],
+)
+def test__get_max_reward__returns_expected(
+    proxy, beta, optimum, reward_max, request
+):
+    proxy = request.getfixturevalue(proxy)
+    reward_max = torch.tensor(reward_max, dtype=proxy.float, device=proxy.device)
+    # Forcibly set the optimum for testing purposes, even if the proxy is uniform.
+    proxy.optimum = torch.tensor(optimum)
+    assert torch.isclose(proxy.get_max_reward(log=False), reward_max)
+    assert torch.isclose(proxy.get_max_reward(log=True), torch.log(reward_max))
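
A note on the proxy-to-reward transformations exercised by the patches above: the mapping documented in config/proxy/base.yaml (identity, power, exponential, shift, product, each parameterised by beta) can be summarised with a small standalone sketch. This is an illustration only, not the repository's Proxy._get_reward_functions; the helper name make_reward_function and the clipping constant below are assumptions.

```python
# Illustrative stand-in for the proxy-to-reward mapping selected via
# `reward_function` and `reward_function_kwargs` (beta). Not repository code.
import torch


def make_reward_function(name: str, beta: float = 1.0):
    """Return a callable mapping a tensor of proxy values to rewards."""
    if name == "identity":
        return lambda x: x
    if name == "power":
        return lambda x: x**beta  # R(x) = x ** beta
    if name == "exponential":
        return lambda x: torch.exp(x * beta)  # R(x) = exp(x * beta)
    if name == "shift":
        return lambda x: x + beta  # R(x) = x + beta
    if name == "product":
        return lambda x: x * beta  # R(x) = x * beta
    raise ValueError(f"Unknown reward function: {name}")


proxy_values = torch.tensor([0.1, 0.5, 1.0, 10.0])
rewards = make_reward_function("exponential", beta=1.0)(proxy_values)
# Clipping at a small positive minimum (cf. reward_min) keeps the log reward
# finite; the base Proxy maps zero rewards to LOGZERO instead.
logrewards = torch.log(torch.clamp(rewards, min=1e-8))
```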
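
Similarly, the optimum / get_max_reward machinery introduced in the last patch provides an upper bound on the reward, which is what uniform-proposal rejection sampling (as in GFlowNetAgent.sample_from_reward) relies on. The sketch below is a rough, self-contained illustration under assumed names (rejection_sample, toy_reward) and a toy [0, 1]^2 domain; it is not the repository's implementation.

```python
# Hypothetical sketch: rejection sampling from a reward bounded by its known
# maximum (the reward at the proxy optimum). Not repository code.
import torch


def rejection_sample(reward_fn, max_reward, n_samples, n_dim=2):
    """Rejection sampling with a uniform proposal over [0, 1]^n_dim."""
    samples = []
    while len(samples) < n_samples:
        candidates = torch.rand(n_samples, n_dim)
        rewards = reward_fn(candidates)
        # Accept each candidate with probability reward / max_reward
        accept = torch.rand(n_samples) * max_reward < rewards
        samples.extend(candidates[accept])
    return torch.stack(samples[:n_samples])


def toy_reward(x):
    # Peaks at (1, 1) with maximum value 1.0, which plays the role of the
    # maximum reward obtained from a proxy whose optimum is known.
    return torch.exp(-((x - 1.0) ** 2).sum(dim=1))


samples = rejection_sample(toy_reward, max_reward=1.0, n_samples=128)
```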