diff --git a/config/env/ccube.yaml b/config/env/ccube.yaml index 714638524..b9e813c49 100644 --- a/config/env/ccube.yaml +++ b/config/env/ccube.yaml @@ -5,7 +5,6 @@ _target_: gflownet.envs.cube.ContinuousCube id: ccube continuous: True -func: corners # Dimensions of hypercube n_dim: 2 # Constant to restrict interval of test sets diff --git a/config/env/ctorus.yaml b/config/env/ctorus.yaml index fa194956f..f648dd4c4 100644 --- a/config/env/ctorus.yaml +++ b/config/env/ctorus.yaml @@ -4,7 +4,6 @@ defaults: _target_: gflownet.envs.ctorus.ContinuousTorus id: ctorus -func: sincos # Dimensions of hypertorus n_dim: 2 policy_encoding_dim_per_angle: null diff --git a/config/env/grid.yaml b/config/env/grid.yaml index 7e2df40fb..d0c3d7415 100644 --- a/config/env/grid.yaml +++ b/config/env/grid.yaml @@ -4,7 +4,6 @@ defaults: _target_: gflownet.envs.grid.Grid id: grid -func: corners # Dimensions of hypergrid n_dim: 2 # Number of cells per dimension diff --git a/config/env/htorus.yaml b/config/env/htorus.yaml index c53b95acc..1324ca9dd 100644 --- a/config/env/htorus.yaml +++ b/config/env/htorus.yaml @@ -5,7 +5,6 @@ _target_: gflownet.envs.htorus.HybridTorus id: ctorus continuous: True -func: sincos # Dimensions of hypertorus n_dim: 2 policy_encoding_dim_per_angle: null diff --git a/config/env/torus.yaml b/config/env/torus.yaml index a5959b79a..97cfcba48 100644 --- a/config/env/torus.yaml +++ b/config/env/torus.yaml @@ -4,7 +4,6 @@ defaults: _target_: gflownet.envs.torus.Torus id: torus -func: sincos # Dimensions of hypertorus n_dim: 2 # Number of angles per dimension diff --git a/config/experiments/ccube/branin.yaml b/config/experiments/ccube/branin.yaml new file mode 100644 index 000000000..7aeaddc8c --- /dev/null +++ b/config/experiments/ccube/branin.yaml @@ -0,0 +1,75 @@ +# @package _global_ +# A configuration that works well with the Branin proxy. 
+# wandb: https://wandb.ai/alexhg/Branin/runs/xlxfdc6k + +defaults: + - override /env: ccube + - override /gflownet: trajectorybalance + - override /proxy: box/branin + - override /logger: wandb + - override /user: alex + +# Environment +env: + n_dim: 2 + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + min_incr: 0.1 + fixed_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + reward_func: identity + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 100 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 10000 + +# Policy +policy: + forward: + type: mlp + n_hid: 128 + n_layers: 2 + checkpoint: forward + backward: + shared_weights: True + checkpoint: backward + +# WandB +logger: + do: + online: true + lightweight: True + project_name: "branin" + tags: + - gflownet + - continuous + - ccube + - branin + test: + period: 500 + n: 1000 + checkpoints: + period: 500 + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/ccube/branin/${now:%Y-%m-%d_%H-%M-%S} diff --git a/config/experiments/ccube/corners.yaml b/config/experiments/ccube/corners.yaml index ccc207c6f..ebf609422 100644 --- a/config/experiments/ccube/corners.yaml +++ b/config/experiments/ccube/corners.yaml @@ -5,7 +5,7 @@ defaults: - override /env: ccube - override /gflownet: trajectorybalance - - override /proxy: corners + - override /proxy: box/corners - override /logger: wandb - override /user: alex diff --git a/config/experiments/ccube/hartmann.yaml b/config/experiments/ccube/hartmann.yaml new file mode 100644 index 000000000..16c0823d4 --- /dev/null +++ b/config/experiments/ccube/hartmann.yaml @@ -0,0 +1,78 @@ +# @package _global_ +# A configuration that works well with the Hartmann proxy. 
+# wandb: https://wandb.ai/alexhg/Hartmann/runs/9l4ry4gm + +defaults: + - override /env: ccube + - override /gflownet: trajectorybalance + - override /proxy: box/hartmann + - override /logger: wandb + - override /user: alex + +# Environment +env: + n_dim: 6 + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + min_incr: 0.1 + fixed_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + reward_func: identity + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 100 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 10000 + +# Policy +policy: + forward: + type: mlp + n_hid: 256 + n_layers: 3 + checkpoint: forward + backward: + shared_weights: False + type: mlp + n_hid: 256 + n_layers: 3 + checkpoint: backward + +# WandB +logger: + do: + online: true + lightweight: True + project_name: "hartmann" + tags: + - gflownet + - continuous + - ccube + - hartmann + test: + period: 500 + n: 1000 + checkpoints: + period: 500 + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/ccube/hartmann/${now:%Y-%m-%d_%H-%M-%S} diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml index 87e44bfb5..aafd7170d 100644 --- a/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml +++ b/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml @@ -27,7 +27,7 @@ shared: output_csv: ccube_test.csv output_pkl: ccube_test.pkl # Proxy - proxy: corners + proxy: box/corners # GFlowNet config gflownet: __value__: trajectorybalance diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml index 93491e3e9..900ec652a 100644 --- 
a/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml +++ b/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml @@ -27,7 +27,7 @@ shared: output_csv: ccube_test.csv output_pkl: ccube_test.pkl # Proxy - proxy: corners + proxy: box/corners # GFlowNet config gflownet: __value__: trajectorybalance diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml index 7912af9b3..749bf951a 100644 --- a/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml +++ b/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml @@ -27,7 +27,7 @@ shared: output_csv: ccube_test.csv output_pkl: ccube_test.pkl # Proxy - proxy: corners + proxy: box/corners # GFlowNet config gflownet: __value__: trajectorybalance diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml index cc82e322c..84386b876 100644 --- a/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml +++ b/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml @@ -27,7 +27,7 @@ shared: output_csv: ccube_test.csv output_pkl: ccube_test.pkl # Proxy - proxy: corners + proxy: box/corners # GFlowNet config gflownet: __value__: trajectorybalance diff --git a/config/experiments/clatticeparams/clatticeparams_owl.yaml b/config/experiments/clatticeparams/clatticeparams_owl.yaml index 30f1c1347..0ef562de9 100644 --- a/config/experiments/clatticeparams/clatticeparams_owl.yaml +++ b/config/experiments/clatticeparams/clatticeparams_owl.yaml @@ -3,7 +3,7 @@ defaults: - override /env: crystals/clattice_parameters - override /gflownet: trajectorybalance - - override /proxy: corners + - override /proxy: box/corners - override /logger: wandb - override /user: alex diff --git a/config/experiments/grid/branin.yaml b/config/experiments/grid/branin.yaml new file mode 100644 index 000000000..ad966acef --- /dev/null +++ 
b/config/experiments/grid/branin.yaml @@ -0,0 +1,70 @@ +# @package _global_ +# 100x100 grid with a configuration that works well with the Branin proxy. +# wandb: https://wandb.ai/alexhg/Branin/runs/0ujb3hwl + +defaults: + - override /env: grid + - override /gflownet: trajectorybalance + - override /proxy: box/branin + - override /logger: wandb + - override /user: alex + +# Environment +env: + n_dim: 2 + length: 100 + max_increment: 1 + max_dim_per_action: 1 + reward_func: identity + # Buffer + buffer: + train: null + test: + type: uniform + n: 1000 + seed: 0 + output_csv: grid_test.csv + output_pkl: grid_test.pkl + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.01 + optimizer: + batch_size: + forward: 100 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 10000 + +# Policy +policy: + forward: + type: mlp + n_hid: 128 + n_layers: 2 + checkpoint: forward + backward: + shared_weights: True + checkpoint: backward + +# WandB +logger: + do: + online: true + lightweight: True + project_name: "branin" + tags: + - gflownet + - grid + - branin + test: + period: 500 + n: 1000 + checkpoints: + period: 500 + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/grid/branin/${now:%Y-%m-%d_%H-%M-%S} diff --git a/config/experiments/grid/hartmann.yaml b/config/experiments/grid/hartmann.yaml new file mode 100644 index 000000000..f469c5769 --- /dev/null +++ b/config/experiments/grid/hartmann.yaml @@ -0,0 +1,73 @@ +# @package _global_ +# 10^6 grid with a configuration that works well with the Hartmann proxy. 
+# wandb: https://wandb.ai/alexhg/Hartmann/runs/1l1y5xwb + +defaults: + - override /env: grid + - override /gflownet: trajectorybalance + - override /proxy: box/hartmann + - override /logger: wandb + - override /user: alex + +# Environment +env: + n_dim: 6 + length: 10 + max_increment: 1 + max_dim_per_action: 1 + reward_func: identity + # Buffer + buffer: + train: null + test: + type: uniform + n: 1000 + seed: 0 + output_csv: grid_test.csv + output_pkl: grid_test.pkl + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.01 + optimizer: + batch_size: + forward: 100 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 10000 + +# Policy +policy: + forward: + type: mlp + n_hid: 256 + n_layers: 3 + checkpoint: forward + backward: + shared_weights: False + type: mlp + n_hid: 256 + n_layers: 3 + checkpoint: backward + +# WandB +logger: + do: + online: true + lightweight: True + project_name: "hartmann" + tags: + - gflownet + - grid + - hartmann + test: + period: 500 + n: 1000 + checkpoints: + period: 500 + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/grid/hartmann/${now:%Y-%m-%d_%H-%M-%S} diff --git a/config/main.yaml b/config/main.yaml index dd5d98edf..e2133d3a3 100644 --- a/config/main.yaml +++ b/config/main.yaml @@ -3,7 +3,7 @@ defaults: - env: grid - gflownet: flowmatch - policy: mlp_${gflownet} - - proxy: corners + - proxy: box/corners - logger: wandb - user: default diff --git a/config/proxy/box/branin.yaml b/config/proxy/box/branin.yaml new file mode 100644 index 000000000..693a16c07 --- /dev/null +++ b/config/proxy/box/branin.yaml @@ -0,0 +1,4 @@ +_target_: gflownet.proxy.box.branin.Branin + +fidelity: 1.0 +do_domain_map: True diff --git a/config/proxy/box/corners.yaml b/config/proxy/box/corners.yaml new file mode 100644 index 000000000..ea20f8d0e --- /dev/null +++ b/config/proxy/box/corners.yaml @@ -0,0 +1,4 @@ +_target_: gflownet.proxy.box.corners.Corners + +mu: 0.75 +sigma: 0.05 diff --git a/config/proxy/box/hartmann.yaml 
b/config/proxy/box/hartmann.yaml new file mode 100644 index 000000000..ddef73072 --- /dev/null +++ b/config/proxy/box/hartmann.yaml @@ -0,0 +1,3 @@ +_target_: gflownet.proxy.box.hartmann.Hartmann + +fidelity: 1.0 diff --git a/config/proxy/corners.yaml b/config/proxy/corners.yaml deleted file mode 100644 index 081004490..000000000 --- a/config/proxy/corners.yaml +++ /dev/null @@ -1,5 +0,0 @@ -_target_: gflownet.proxy.corners.Corners - -mu: 0.75 -sigma: 0.05 -higher_is_better: False diff --git a/config/proxy/length.yaml b/config/proxy/length.yaml index f63ead4c4..ed8e6dd51 100644 --- a/config/proxy/length.yaml +++ b/config/proxy/length.yaml @@ -2,4 +2,3 @@ _target_: gflownet.proxy.aptamers.Aptamers oracle_id: length norm: True -higher_is_better: False diff --git a/config/proxy/tetris.yaml b/config/proxy/tetris.yaml index 9774e1dbb..940e5c56b 100644 --- a/config/proxy/tetris.yaml +++ b/config/proxy/tetris.yaml @@ -1,5 +1,4 @@ _target_: gflownet.proxy.tetris.Tetris normalize: True -higher_is_better: False diff --git a/config/proxy/torus.yaml b/config/proxy/torus.yaml index 9739c85b0..0006c3e9e 100644 --- a/config/proxy/torus.yaml +++ b/config/proxy/torus.yaml @@ -3,4 +3,3 @@ _target_: gflownet.proxy.torus.Torus normalize: True alpha: 1.0 beta: 1.0 -higher_is_better: False diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 3a9f6579c..2348c9476 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -74,7 +74,6 @@ def __init__( # Proxy self.proxy = proxy self.setup_proxy() - self.proxy_factor = -1.0 self.proxy_state_format = proxy_state_format # Flag to skip checking if action is valid (computing mask) before step self.skip_mask_check = skip_mask_check @@ -796,12 +795,12 @@ def reward_batch(self, states: List[List], done=None): def proxy2reward(self, proxy_vals): """ - Prepares the output of a proxy for GFlowNet: the inputs proxy_vals is expected - to be a negative value (energy), unless self.denorm_proxy is True. 
If the - latter, the proxy values are first de-normalized according to the mean and - standard deviation in self.energies_stats. The output of the function is a - strictly positive reward - provided self.reward_norm and self.reward_beta are - positive - and larger than self.min_reward. + Prepares the output of an oracle for GFlowNet: the inputs proxy_vals is + expected to be a negative value (energy), unless self.denorm_proxy is True. If + the latter, the proxy values are first de-normalized according to the mean and + standard deviation in self.energies_stats, then made negative. The output of + the function is a strictly positive reward - self.reward_norm and + self.reward_beta must be positive - and larger than or equal to self.min_reward. """ if self.denorm_proxy: # TODO: do with torch @@ -811,21 +810,22 @@ def proxy2reward(self, proxy_vals): + self.energies_stats[0] ) # proxy_vals = proxy_vals * self.energies_stats[3] + self.energies_stats[2] + proxy_vals = -1.0 * proxy_vals if self.reward_func == "power": return torch.clamp( - (self.proxy_factor * proxy_vals / self.reward_norm) ** self.reward_beta, + (proxy_vals / self.reward_norm) ** self.reward_beta, min=self.min_reward, max=None, ) elif self.reward_func == "boltzmann": return torch.clamp( - torch.exp(self.proxy_factor * self.reward_beta * proxy_vals), + torch.exp(self.reward_beta * proxy_vals), min=self.min_reward, max=None, ) elif self.reward_func == "identity": return torch.clamp( - self.proxy_factor * proxy_vals, + proxy_vals, min=self.min_reward, max=None, ) @@ -844,7 +844,7 @@ def reward2proxy(self, reward): a proxy. 
""" if self.reward_func == "power": - return self.proxy_factor * torch.exp( + return -1.0 * torch.exp( ( torch.log(reward) + self.reward_beta * torch.log(torch.as_tensor(self.reward_norm)) @@ -852,11 +852,11 @@ def reward2proxy(self, reward): / self.reward_beta ) elif self.reward_func == "boltzmann": - return self.proxy_factor * torch.log(reward) / self.reward_beta + return -1.0 * torch.log(reward) / self.reward_beta elif self.reward_func == "identity": - return self.proxy_factor * reward + return -1.0 * reward elif self.reward_func == "shift": - return self.proxy_factor * (reward - self.reward_beta) + return -1.0 * (reward - self.reward_beta) else: raise NotImplementedError diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py index 045ed7841..422646483 100644 --- a/gflownet/proxy/base.py +++ b/gflownet/proxy/base.py @@ -3,9 +3,11 @@ """ from abc import ABC, abstractmethod +from typing import List, Union import numpy as np import numpy.typing as npt +from torchtyping import TensorType from gflownet.utils.common import set_device, set_float_precision @@ -15,29 +17,44 @@ class Proxy(ABC): Generic proxy class """ - def __init__(self, device, float_precision, higher_is_better=False, **kwargs): + def __init__(self, device, float_precision, **kwargs): # Device self.device = set_device(device) # Float precision self.float = set_float_precision(float_precision) - # Reward2Proxy multiplicative factor (1 or -1) - self.higher_is_better = higher_is_better def setup(self, env=None): pass @abstractmethod - def __call__(self, states: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: + def __call__( + self, + states: Union[TensorType["batch", "state_dim"], npt.NDArray[np.float32], List], + ) -> TensorType["batch"]: """ - Implement this function to call the get_reward method of the appropriate Proxy - Class (EI, UCB, Proxy, Oracle etc). + Computes the values of the proxy for a batch of states. 
Parameters ---------- - states: ndarray + states: torch.tensor, ndarray, list + A batch of states in proxy format. + + Returns + ------- + torch.tensor + The proxy value for each state in the input batch. """ pass + @staticmethod + def map_to_standard_range(values: TensorType["batch"]) -> TensorType["batch"]: + """ + Maps a batch of proxy values back onto the standard range of the proxy or + oracle. By default, it returns the values as they are, so this method may be + overridden when needed. + """ + return values + def infer_on_train_set(self): """ Implement this method in specific proxies. diff --git a/gflownet/proxy/box/branin.py b/gflownet/proxy/box/branin.py new file mode 100644 index 000000000..ba5bc9467 --- /dev/null +++ b/gflownet/proxy/box/branin.py @@ -0,0 +1,109 @@ +""" +Branin objective function, relying on the botorch implementation. + +This code is based on the implementation by Nikita Saxena (nikita-0209) in +https://github.com/alexhernandezgarcia/activelearning + +The implementation assumes by default that the inputs will be on [0, 1] x [0, 1] and +will be mapped to the standard domain of the Branin function (see X1_DOMAIN and +X2_DOMAIN). Setting do_domain_map to False will prevent the mapping. + +Branin function is typically used as a minimization problem, with the minima around +zero but positive. In order to map the range into the conventional negative range, an +upper bound of Branin in the standard domain (UPPER_BOUND_IN_DOMAIN) is subtracted. +""" + +import torch +from botorch.test_functions.multi_fidelity import AugmentedBranin +from torchtyping import TensorType + +from gflownet.proxy.base import Proxy + +X1_DOMAIN = [-5, 10] +X1_LENGTH = X1_DOMAIN[1] - X1_DOMAIN[0] +X2_DOMAIN = [0, 15] +X2_LENGTH = X2_DOMAIN[1] - X2_DOMAIN[0] +UPPER_BOUND_IN_DOMAIN = 309 + + +class Branin(Proxy): + def __init__(self, fidelity=1.0, do_domain_map=True, **kwargs): + """ + fidelity : float + Fidelity of the Branin oracle. 1.0 corresponds to the original Branin.
+ Smaller values (up to 0.0) reduce the fidelity of the oracle. + + See: https://botorch.org/api/test_functions.html + """ + super().__init__(**kwargs) + self.fidelity = fidelity + self.do_domain_map = do_domain_map + self.function_mf_botorch = AugmentedBranin(negate=False) + # Modes and extremum compatible with 100x100 grid + self.modes = [ + [12.4, 81.833], + [54.266, 15.16], + [94.98, 16.5], + ] + self.extremum = 0.397887 + + def __call__(self, states: TensorType["batch", "2"]) -> TensorType["batch"]: + if states.shape[1] != 2: + raise ValueError( + f""" + Inputs to the Branin function must be 2-dimensional, but inputs with + {states.shape[1]} dimensions were passed. + """ + ) + if self.do_domain_map: + states = Branin.map_to_standard_domain(states) + # Append fidelity as a new dimension of states + states = torch.cat( + [ + states, + self.fidelity + * torch.ones( + states.shape[0], device=self.device, dtype=self.float + ).unsqueeze(-1), + ], + dim=1, + ) + return Branin.map_to_negative_range(self.function_mf_botorch(states)) + + @property + def min(self): + if not hasattr(self, "_min"): + self._min = torch.tensor( + -UPPER_BOUND_IN_DOMAIN, device=self.device, dtype=self.float + ) + return self._min + + @staticmethod + def map_to_standard_domain( + states: TensorType["batch", "2"] + ) -> TensorType["batch", "2"]: + """ + Maps a batch of input states onto the domain typically used to evaluate the + Branin function. See X1_DOMAIN and X2_DOMAIN. It assumes that the inputs are on + [0, 1] x [0, 1]. A new tensor is returned; the input tensor is not modified. + """ + lower = states.new_tensor([X1_DOMAIN[0], X2_DOMAIN[0]]) + length = states.new_tensor([X1_LENGTH, X2_LENGTH]) + return lower + states * length + + @staticmethod + def map_to_negative_range(values: TensorType["batch"]) -> TensorType["batch"]: + """ + Maps a batch of function values onto a negative range by subtracting an upper + bound of the Branin function in the standard domain (UPPER_BOUND_IN_DOMAIN).
+ """ + return values - UPPER_BOUND_IN_DOMAIN + + @staticmethod + def map_to_standard_range(values: TensorType["batch"]) -> TensorType["batch"]: + """ + Maps a batch of function values in a negative range back onto the standard + range by adding an upper bound of the Branin function in the standard domain + (UPPER_BOUND_IN_DOMAIN). + """ + return values + UPPER_BOUND_IN_DOMAIN diff --git a/gflownet/proxy/corners.py b/gflownet/proxy/box/corners.py similarity index 100% rename from gflownet/proxy/corners.py rename to gflownet/proxy/box/corners.py diff --git a/gflownet/proxy/box/hartmann.py b/gflownet/proxy/box/hartmann.py new file mode 100644 index 000000000..7d457bac0 --- /dev/null +++ b/gflownet/proxy/box/hartmann.py @@ -0,0 +1,67 @@ +""" +Hartmann objective function, relying on the botorch implementation. + +See: +https://botorch.org/api/test_functions.html#botorch.test_functions.synthetic.Hartmann + +This code is based on the implementation by Nikita Saxena (nikita-0209) in +https://github.com/alexhernandezgarcia/activelearning + +The implementation assumes that the inputs will be on [0, 1]^6 as is typical in the +uses of the Hartmann function. The original range is negative, which is the convention +for other proxy classes, and negate=False is used in the call to the BoTorch method in +order to keep the range. 
+""" + +import numpy as np +import torch +from botorch.test_functions.multi_fidelity import AugmentedHartmann +from torchtyping import TensorType + +from gflownet.proxy.base import Proxy + + +class Hartmann(Proxy): + def __init__(self, fidelity=1.0, **kwargs): + super().__init__(**kwargs) + self.fidelity = fidelity + self.function_mf_botorch = AugmentedHartmann(negate=False) + # This is just a rough estimate of modes + self.modes = [ + [0.2, 0.2, 0.5, 0.3, 0.3, 0.7], + [0.4, 0.9, 0.9, 0.6, 0.1, 0.0], + [0.3, 0.1, 0.4, 0.3, 0.3, 0.7], + [0.4, 0.9, 0.4, 0.6, 0.0, 0.0], + [0.4, 0.9, 0.6, 0.6, 0.3, 0.0], + ] + # Global optimum, according to BoTorch + self.extremum = -3.32237 + + def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]: + if states.shape[1] != 6: + raise ValueError( + f""" + Inputs to the Hartmann function must be 6-dimensional, but inputs with + {states.shape[1]} dimensions were passed. + """ + ) + # Append fidelity as a new dimension of states + states = torch.cat( + [ + states, + self.fidelity + * torch.ones( + states.shape[0], device=self.device, dtype=self.float + ).unsqueeze(-1), + ], + dim=1, + ) + return self.function_mf_botorch(states) + + @property + def min(self): + if not hasattr(self, "_min"): + self._min = torch.tensor( + self.extremum, device=self.device, dtype=self.float + ) + return self._min diff --git a/pyproject.toml b/pyproject.toml index 8743e587f..1f027dfc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ priority = "primary" [tool.poetry.dependencies] # Base dependencies.
+botorch = ">=0.10.0" hydra-core = ">=1.3.2" matplotlib = "*" numpy = "*" diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 803704ecc..2d2a766a8 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -5,7 +5,7 @@ from gflownet.envs.ctorus import ContinuousTorus from gflownet.envs.grid import Grid from gflownet.envs.tetris import Tetris -from gflownet.proxy.corners import Corners +from gflownet.proxy.box.corners import Corners from gflownet.proxy.tetris import Tetris as TetrisScore from gflownet.utils.batch import Batch from gflownet.utils.common import (