From 95bd066cde9a492dfb83f10cfd561ef4b5ff2014 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 12 Jul 2024 13:23:41 +0200 Subject: [PATCH] Fix bug in checking whether (log)rewards are available --- gflownet/gflownet.py | 12 ++++++------ gflownet/utils/batch.py | 40 ++++++++++++++++++++++++++++++---------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index d464e6c7..e7adcc6d 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1178,17 +1178,17 @@ def train(self): # log-rewards, the latter are computed by taking the log of the rewards. # Numerical issues are not critical in this case, since the derived values # are only used for reporting purposes. - if batch.rewards_available: + if batch.rewards_available(log=False): rewards = batch.get_terminating_rewards(sort_by="trajectory") - if batch.logrewards_available: + if batch.rewards_available(log=True): logrewards = batch.get_terminating_rewards( sort_by="trajectory", log=True ) - if not batch.rewards_available: - assert batch.logrewards_available + if not batch.rewards_available(log=False): + assert batch.rewards_available(log=True) rewards = torch.exp(logrewards) - if not batch.logrewards_available: - assert batch.rewards_available + if not batch.rewards_available(log=True): + assert batch.rewards_available(log=False) logrewards = torch.log(rewards) rewards = rewards.tolist() logrewards = logrewards.tolist() diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 8d756623..bd319ba7 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -99,10 +99,10 @@ def __init__( self.parents_all_available = False self.masks_forward_available = False self.masks_backward_available = False - self.rewards_available = False + self._rewards_available = False self.rewards_parents_available = False self.rewards_source_available = False - self.logrewards_available = False + self._logrewards_available = False self.logrewards_parents_available = False self.logrewards_source_available = False self.proxy_values_available = False @@ -139,6 +139,26 @@ def traj_idx_action_idx_to_batch_idx( def idx2state_idx(self, idx: int): return self.trajectories[self.traj_indices[idx]].index(idx) + def rewards_available(self, log: bool = False) -> bool: + """ + Returns True if the (log)rewards are available. + + Parameters + ---------- + log : bool + If True, check self._logrewards_available. Otherwise (default), check + self._rewards_available. + + Returns + ------- + bool + True if the (log)rewards are available, False otherwise. + """ + if log: + return self._logrewards_available + else: + return self._rewards_available + def set_env(self, env: GFlowNetEnv): """ Sets the generic environment passed as an argument and initializes the @@ -256,8 +276,8 @@ def add_to_batch( self.masks_backward_available = False self.parents_policy_available = False self.parents_all_available = False - self.rewards_available = False - self.logrewards_available = False + self._rewards_available = False + self._logrewards_available = False def get_n_trajectories(self) -> int: """ @@ -885,7 +905,7 @@ def get_rewards( If True, return the actual rewards of the non-terminating states. If False, non-terminating states will be assigned reward 0. """ - if self.rewards_available is False or force_recompute is True: + if self.rewards_available(log) is False or force_recompute is True: self._compute_rewards(log, do_non_terminating) if log: return self.logrewards @@ -948,10 +968,10 @@ def _compute_rewards( self.proxy_values_available = True if log: self.logrewards = rewards - self.logrewards_available = True + self._logrewards_available = True else: self.rewards = rewards - self.rewards_available = True + self._rewards_available = True def get_rewards_parents(self, log: bool = False) -> TensorType["n_states"]: """ @@ -1146,7 +1166,7 @@ def get_terminating_rewards( indices = np.argsort(self.traj_indices) else: raise ValueError("sort_by must be either insert[ion] or traj[ectory]") - if self.rewards_available is False or force_recompute is True: + if self.rewards_available(log) is False or force_recompute is True: self._compute_rewards(log, do_non_terminating=False) done = self.get_done()[indices] if log: @@ -1305,11 +1325,11 @@ def merge(self, batches: List): self.parents_all = extend(self.parents_all, batch.parents_all) else: self.parents_all = None - if self.rewards_available and batch.rewards_available: + if self._rewards_available and batch._rewards_available: self.rewards = extend(self.rewards, batch.rewards) else: self.rewards = None - if self.logrewards_available and batch.logrewards_available: + if self._logrewards_available and batch._logrewards_available: self.logrewards = extend(self.logrewards, batch.logrewards) else: self.logrewards = None