diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py
index 9e9484463..42ce57b53 100644
--- a/gflownet/gflownet.py
+++ b/gflownet/gflownet.py
@@ -1538,7 +1538,7 @@ def sample_from_reward(
         format.
         """
         samples_final = []
-        max_reward = self.proxy.proxy2reward(self.proxy.min)
+        max_reward = self.proxy.get_max_reward()
         while len(samples_final) < n_samples:
             if proposal_distribution == "uniform":
                 # TODO: sample only the remaining number of samples
diff --git a/gflownet/proxy/base.py b/gflownet/proxy/base.py
index 019e9f118..14f5ac268 100644
--- a/gflownet/proxy/base.py
+++ b/gflownet/proxy/base.py
@@ -177,13 +177,59 @@ def get_min_reward(self, log: bool = False) -> float:
         Returns
         -------
         float
-            The mimnimum (log) reward.
+            The minimum (log) reward.
         """
         if log:
             return self.logreward_min
         else:
             return self.reward_min
 
+    def get_max_reward(self, log: bool = False) -> float:
+        """
+        Returns the maximum value of the (log) reward, computed from self.optimum,
+        provided it is defined.
+
+        Parameters
+        ----------
+        log : bool
+            If True, returns the logarithm of the maximum reward. If False (default),
+            returns the natural maximum reward.
+
+        Returns
+        -------
+        float
+            The maximum (log) reward.
+        """
+        if log:
+            return self.proxy2logreward(self.optimum)
+        else:
+            return self.proxy2reward(self.optimum)
+
+    @property
+    def optimum(self):
+        """
+        Returns the optimum value of the proxy.
+
+        Not implemented by default but may be implemented for synthetic proxies or
+        when the optimum is known.
+
+        The optimum is used, for example, to accelerate rejection sampling when
+        sampling from the reward function.
+        """
+        if not hasattr(self, "_optimum"):
+            raise NotImplementedError(
+                "The optimum value of the proxy needs to be implemented explicitly "
+                f"for each Proxy and is not available for {self.__class__}."
+            )
+        return self._optimum
+
+    @optimum.setter
+    def optimum(self, value):
+        """
+        Sets the optimum value of the proxy.
+ """ + self._optimum = value + def _get_reward_functions( self, reward_function: Union[Callable, str], diff --git a/gflownet/proxy/corners.py b/gflownet/proxy/corners.py index 839d30a09..e1db911f3 100644 --- a/gflownet/proxy/corners.py +++ b/gflownet/proxy/corners.py @@ -31,13 +31,13 @@ def setup(self, env=None): self.mulnormal_norm = 1.0 / ((2 * torch.pi) ** 2 * cov_det) ** 0.5 @property - def min(self): - if not hasattr(self, "_min"): + def optimum(self): + if not hasattr(self, "_optimum"): mode = self.mu * torch.ones( self.n_dim, device=self.device, dtype=self.float ) - self._min = self(torch.unsqueeze(mode, 0))[0] - return self._min + self._optimum = self(torch.unsqueeze(mode, 0))[0] + return self._optimum def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]: return self.mulnormal_norm * torch.exp( diff --git a/gflownet/proxy/torus.py b/gflownet/proxy/torus.py index 6b78c9aad..7dbd5b0e8 100644 --- a/gflownet/proxy/torus.py +++ b/gflownet/proxy/torus.py @@ -16,15 +16,15 @@ def setup(self, env=None): self.n_dim = env.n_dim @property - def min(self): - if not hasattr(self, "_min"): + def optimum(self): + if not hasattr(self, "_optimum"): if self.normalize: - self._min = torch.tensor(0.0, device=self.device, dtype=self.float) + self._optimum = torch.tensor(1.0, device=self.device, dtype=self.float) else: - self._min = torch.tensor( + self._optimum = torch.tensor( ((self.n_dim * 2) ** 3), device=self.device, dtype=self.float ) - return self._min + return self._optimum @property def norm(self): diff --git a/gflownet/proxy/uniform.py b/gflownet/proxy/uniform.py index d9c6d6b0a..659612805 100644 --- a/gflownet/proxy/uniform.py +++ b/gflownet/proxy/uniform.py @@ -9,14 +9,9 @@ class Uniform(Proxy): def __init__(self, **kwargs): super().__init__(**kwargs) + self._optimum = torch.tensor(1.0, device=self.device, dtype=self.float) def __call__( self, states: Union[List, TensorType["batch", "state_dim"]] ) -> TensorType["batch"]: return torch.ones(len(states), device=self.device, dtype=self.float) - - @property - def min(self): - if not hasattr(self, "_min"): - self._min = torch.tensor(1.0, device=self.device, dtype=self.float) - return self._min diff --git a/tests/gflownet/proxy/test_base.py b/tests/gflownet/proxy/test_base.py index e142dca77..812b158ee 100644 --- a/tests/gflownet/proxy/test_base.py +++ b/tests/gflownet/proxy/test_base.py @@ -623,3 +623,29 @@ def test_reward_function_callable__behaves_as_expected( # Log Rewards logrewards_exp = tfloat(logrewards_exp, device=proxy.device, float_type=proxy.float) assert all(check_proxy2reward(proxy.proxy2logreward(proxy_values), logrewards_exp)) + + +@pytest.mark.parametrize( + "proxy, beta, optimum, reward_max", + [ + ("uniform", None, 1.0, 1.0), + ("uniform", None, 2.0, 2.0), + ("proxy_power", 1, 2.0, 2.0), + ("proxy_power", 2, 2.0, 4.0), + ("proxy_exponential", 1, 1.0, np.exp(1.0)), + ("proxy_exponential", -1, -1.0, np.exp(1.0)), + ("proxy_shift", 5, 10.0, 15.0), + ("proxy_shift", -5, 10.0, 5.0), + ("proxy_product", 2, 2.0, 4.0), + ("proxy_product", -2, -5.0, 10.0), + ], +) +def test__uniform_proxy_initializes_without_errors( + proxy, beta, optimum, reward_max, request +): + proxy = request.getfixturevalue(proxy) + reward_max = torch.tensor(reward_max, dtype=proxy.float, device=proxy.device) + # Forcibly set the optimum for testing purposes, even if the proxy is uniform. 
+    proxy.optimum = torch.tensor(optimum)
+    assert torch.isclose(proxy.get_max_reward(log=False), reward_max)
+    assert torch.isclose(proxy.get_max_reward(log=True), torch.log(reward_max))
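
Usage note (not part of the patch): a minimal sketch of how the new `optimum` property and `get_max_reward` fit together, using the `Uniform` proxy changed in this diff. The `device` and `float_precision` constructor arguments are assumed from the rest of the codebase; `get_max_reward` is what `sample_from_reward` now calls to bound rejection sampling.

```python
import torch

from gflownet.proxy.uniform import Uniform

# Assumed constructor signature: Proxy subclasses take device and float_precision.
proxy = Uniform(device="cpu", float_precision=32)

# Uniform sets self._optimum = 1.0 in __init__, so the getter works out of the box.
# With the default identity reward, proxy2reward(optimum) is just the optimum.
print(proxy.get_max_reward())  # -> tensor(1.)

# The new setter allows overriding the optimum, e.g. for testing (as in test_base.py).
proxy.optimum = torch.tensor(5.0)
print(proxy.get_max_reward())          # proxy2reward(optimum)
print(proxy.get_max_reward(log=True))  # proxy2logreward(optimum), i.e. log(5.0)
```

The rejection-sampling loop in `sample_from_reward` only needs an upper bound on the reward, which is why a known proxy optimum is enough and why proxies without a known optimum raise `NotImplementedError` instead of guessing.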