Skip to content

Commit

Permalink
Fix bug in checking whether (log)rewards are available
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Jul 12, 2024
1 parent 19b1f27 commit 95bd066
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
12 changes: 6 additions & 6 deletions gflownet/gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,17 +1178,17 @@ def train(self):
# log-rewards, the latter are computed by taking the log of the rewards.
# Numerical issues are not critical in this case, since the derived values
# are only used for reporting purposes.
if batch.rewards_available:
if batch.rewards_available(log=False):
rewards = batch.get_terminating_rewards(sort_by="trajectory")
if batch.logrewards_available:
if batch.rewards_available(log=True):
logrewards = batch.get_terminating_rewards(
sort_by="trajectory", log=True
)
if not batch.rewards_available:
assert batch.logrewards_available
if not batch.rewards_available(log=False):
assert batch.rewards_available(log=True)
rewards = torch.exp(logrewards)
if not batch.logrewards_available:
assert batch.rewards_available
if not batch.rewards_available(log=True):
assert batch.rewards_available(log=False)
logrewards = torch.log(rewards)
rewards = rewards.tolist()
logrewards = logrewards.tolist()
Expand Down
40 changes: 30 additions & 10 deletions gflownet/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def __init__(
self.parents_all_available = False
self.masks_forward_available = False
self.masks_backward_available = False
self.rewards_available = False
self._rewards_available = False
self.rewards_parents_available = False
self.rewards_source_available = False
self.logrewards_available = False
self._logrewards_available = False
self.logrewards_parents_available = False
self.logrewards_source_available = False
self.proxy_values_available = False
Expand Down Expand Up @@ -139,6 +139,26 @@ def traj_idx_action_idx_to_batch_idx(
def idx2state_idx(self, idx: int):
    """
    Returns the position of idx within the trajectory it is assigned to.

    The trajectory is found via self.traj_indices[idx]; the result is whatever
    .index(idx) returns on the corresponding entry of self.trajectories.

    Parameters
    ----------
    idx : int
        Index used both to look up self.traj_indices and as the value searched
        for in the selected trajectory.

    Returns
    -------
    The position of idx in self.trajectories[self.traj_indices[idx]].
    """
    traj_idx = self.traj_indices[idx]
    return self.trajectories[traj_idx].index(idx)

def rewards_available(self, log: bool = False) -> bool:
    """
    Returns True if the (log)rewards are available.

    Parameters
    ----------
    log : bool
        If True, check self._logrewards_available. Otherwise (default), check
        self._rewards_available.

    Returns
    -------
    bool
        True if the (log)rewards are available, False otherwise.
    """
    # The two flags are tracked independently: one for rewards, one for
    # log-rewards. The value of the selected flag is returned unchanged.
    if log:
        return self._logrewards_available
    return self._rewards_available

def set_env(self, env: GFlowNetEnv):
"""
Sets the generic environment passed as an argument and initializes the
Expand Down Expand Up @@ -256,8 +276,8 @@ def add_to_batch(
self.masks_backward_available = False
self.parents_policy_available = False
self.parents_all_available = False
self.rewards_available = False
self.logrewards_available = False
self._rewards_available = False
self._logrewards_available = False

def get_n_trajectories(self) -> int:
"""
Expand Down Expand Up @@ -885,7 +905,7 @@ def get_rewards(
If True, return the actual rewards of the non-terminating states. If
False, non-terminating states will be assigned reward 0.
"""
if self.rewards_available is False or force_recompute is True:
if self.rewards_available(log) is False or force_recompute is True:
self._compute_rewards(log, do_non_terminating)
if log:
return self.logrewards
Expand Down Expand Up @@ -948,10 +968,10 @@ def _compute_rewards(
self.proxy_values_available = True
if log:
self.logrewards = rewards
self.logrewards_available = True
self._logrewards_available = True
else:
self.rewards = rewards
self.rewards_available = True
self._rewards_available = True

def get_rewards_parents(self, log: bool = False) -> TensorType["n_states"]:
"""
Expand Down Expand Up @@ -1146,7 +1166,7 @@ def get_terminating_rewards(
indices = np.argsort(self.traj_indices)
else:
raise ValueError("sort_by must be either insert[ion] or traj[ectory]")
if self.rewards_available is False or force_recompute is True:
if self.rewards_available(log) is False or force_recompute is True:
self._compute_rewards(log, do_non_terminating=False)
done = self.get_done()[indices]
if log:
Expand Down Expand Up @@ -1305,11 +1325,11 @@ def merge(self, batches: List):
self.parents_all = extend(self.parents_all, batch.parents_all)
else:
self.parents_all = None
if self.rewards_available and batch.rewards_available:
if self._rewards_available and batch._rewards_available:
self.rewards = extend(self.rewards, batch.rewards)
else:
self.rewards = None
if self.logrewards_available and batch.logrewards_available:
if self._logrewards_available and batch._logrewards_available:
self.logrewards = extend(self.logrewards, batch.logrewards)
else:
self.logrewards = None
Expand Down

0 comments on commit 95bd066

Please sign in to comment.