Skip to content

Commit

Permalink
Rename variables and create additional methods for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Jul 12, 2024
1 parent 95bd066 commit 0c3d8aa
Showing 1 changed file with 84 additions and 44 deletions.
128 changes: 84 additions & 44 deletions gflownet/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,18 @@ def __init__(
self.states_policy = None
self.parents_policy = None
# Flags for available items
self.parents_available = False
self.parents_policy_available = False
self.parents_all_available = False
self.masks_forward_available = False
self.masks_backward_available = False
self._parents_available = False
self._parents_policy_available = False
self._parents_all_available = False
self._masks_forward_available = False
self._masks_backward_available = False
self._rewards_available = False
self.rewards_parents_available = False
self.rewards_source_available = False
self._rewards_parents_available = False
self._rewards_source_available = False
self._logrewards_available = False
self.logrewards_parents_available = False
self.logrewards_source_available = False
self.proxy_values_available = False
self._logrewards_parents_available = False
self._logrewards_source_available = False
self._proxy_values_available = False

def __len__(self):
    """
    Returns the number of items in the batch, as tracked by self.size.
    """
    return self.size
Expand Down Expand Up @@ -159,6 +159,46 @@ def rewards_available(self, log: bool = False) -> bool:
else:
return self._rewards_available

def rewards_parents_available(self, log: bool = False) -> bool:
    """
    Returns True if the (log)rewards of the parents are available.

    Parameters
    ----------
    log : bool
        If True, check self._logrewards_parents_available. Otherwise (default),
        check self._rewards_parents_available.

    Returns
    -------
    bool
        True if the (log)rewards of the parents are available, False otherwise.
    """
    # Rewards and log-rewards of the parents are tracked by separate flags, so
    # the flag to report depends on which of the two the caller asks about.
    if log:
        return self._logrewards_parents_available
    return self._rewards_parents_available

def rewards_source_available(self, log: bool = False) -> bool:
    """
    Returns True if the (log)rewards of the source are available.

    Parameters
    ----------
    log : bool
        If True, check self._logrewards_source_available. Otherwise (default),
        check self._rewards_source_available.

    Returns
    -------
    bool
        True if the (log)rewards of the source are available, False otherwise.
    """
    # Rewards and log-rewards of the source are tracked by separate flags, so
    # the flag to report depends on which of the two the caller asks about.
    if log:
        return self._logrewards_source_available
    return self._rewards_source_available

def set_env(self, env: GFlowNetEnv):
"""
Sets the generic environment passed as an argument and initializes the
Expand Down Expand Up @@ -272,10 +312,10 @@ def add_to_batch(
# Increment size of batch
self.size += 1
# Other variables are not available after new items were added to the batch
self.masks_forward_available = False
self.masks_backward_available = False
self.parents_policy_available = False
self.parents_all_available = False
self._masks_forward_available = False
self._masks_backward_available = False
self._parents_policy_available = False
self._parents_all_available = False
self._rewards_available = False
self._logrewards_available = False

Expand Down Expand Up @@ -546,10 +586,10 @@ def get_parents(
self.parents or self.parents_policy : torch.tensor
The parent of all states in the batch.
"""
if self.parents_available is False or force_recompute is True:
if self._parents_available is False or force_recompute is True:
self._compute_parents()
if policy:
if self.parents_policy_available is False or force_recompute is True:
if self._parents_policy_available is False or force_recompute is True:
self._compute_parents_policy()
return self.parents_policy
else:
Expand All @@ -568,7 +608,7 @@ def get_parents_indices(self):
self.parents_indices
The indices in self.states of the parents of self.states.
"""
if self.parents_available is False:
if self._parents_available is False:
self._compute_parents()
return self.parents_indices

Expand All @@ -589,7 +629,7 @@ def _compute_parents(self):
parent is not present in self.states (i.e. it is source), the corresponding
index is -1.
self.parents_available is set to True.
self._parents_available is set to True.
"""
self.parents = []
self.parents_indices = []
Expand Down Expand Up @@ -621,7 +661,7 @@ def _compute_parents(self):
[self.parents_indices[indices_dict[idx]] for idx in range(len(self))],
device=self.device,
)
self.parents_available = True
self._parents_available = True

# TODO: consider converting directly from self.parents
def _compute_parents_policy(self):
Expand All @@ -636,7 +676,7 @@ def _compute_parents_policy(self):
Shape: [n_states, state_policy_dims]
self.parents_policy is stored as a torch tensor and
self.parents_policy_available is set to True.
self._parents_policy_available is set to True.
"""
self.states_policy = self.get_states(policy=True)
self.parents_policy = torch.zeros_like(self.states_policy)
Expand All @@ -652,7 +692,7 @@ def _compute_parents_policy(self):
self.parents_policy[batch_indices[1:]] = self.states_policy[
batch_indices[:-1]
]
self.parents_policy_available = True
self._parents_policy_available = True

def get_parents_all(
self, policy: bool = False, force_recompute: bool = False
Expand All @@ -664,7 +704,7 @@ def get_parents_all(
"""
Returns the whole set of parents, their corresponding actions and indices of
all states in the batch. If the parents are not available
(self.parents_all_available is False) or if force_recompute is True, then
(self._parents_all_available is False) or if force_recompute is True, then
self._compute_parents_all() is called to compute the required components.
The parents are returned in "policy format" if policy is True, otherwise they
Expand Down Expand Up @@ -696,7 +736,7 @@ def get_parents_all(
"""
if self.continuous:
raise Exception("get_parents() is ill-defined for continuous environments!")
if self.parents_all_available is False or force_recompute is True:
if self._parents_all_available is False or force_recompute is True:
self._compute_parents_all()
if policy:
return (
Expand Down Expand Up @@ -726,7 +766,7 @@ def _compute_parents_all(self):
Shape: [n_parents, state_policy_dims]
All the above components are stored as torch tensors and
self.parents_all_available is set to True.
self._parents_all_available is set to True.
"""
# Iterate over the trajectories to obtain all parents
self.parents_all = []
Expand Down Expand Up @@ -763,7 +803,7 @@ def _compute_parents_all(self):
device=self.device,
)
self.parents_all_policy = torch.cat(self.parents_all_policy)
self.parents_all_available = True
self._parents_all_available = True

# TODO: opportunity to improve efficiency by caching.
def get_masks_forward(
Expand Down Expand Up @@ -791,7 +831,7 @@ def get_masks_forward(
self.masks_invalid_actions_forward : torch.tensor
The forward mask of all states in the batch.
"""
if self.masks_forward_available is False or force_recompute is True:
if self._masks_forward_available is False or force_recompute is True:
self._compute_masks_forward()
# Make tensor
masks_invalid_actions_forward = tbool(
Expand Down Expand Up @@ -826,8 +866,8 @@ def get_masks_forward(
def _compute_masks_forward(self):
"""
Computes the forward mask of invalid actions of all states in the batch, by
calling env.get_mask_invalid_actions_forward(). self.masks_forward_available is
set to True.
calling env.get_mask_invalid_actions_forward(). self._masks_forward_available
is set to True.
"""
# Iterate over the trajectories to compute all forward masks
for idx, mask in enumerate(self.masks_invalid_actions_forward):
Expand All @@ -839,7 +879,7 @@ def _compute_masks_forward(self):
self.masks_invalid_actions_forward[idx] = self.envs[
traj_idx
].get_mask_invalid_actions_forward(state, done)
self.masks_forward_available = True
self._masks_forward_available = True

# TODO: opportunity to improve efficiency by caching. Note that
# env.get_masks_invalid_actions_backward() may be expensive because it calls
Expand All @@ -863,14 +903,14 @@ def get_masks_backward(
self.masks_invalid_actions_backward : torch.tensor
The backward mask of all states in the batch.
"""
if self.masks_backward_available is False or force_recompute is True:
if self._masks_backward_available is False or force_recompute is True:
self._compute_masks_backward()
return tbool(self.masks_invalid_actions_backward, device=self.device)

def _compute_masks_backward(self):
"""
Computes the backward mask of invalid actions of all states in the batch, by
calling env.get_mask_invalid_actions_backward(). self.masks_backward_available
calling env.get_mask_invalid_actions_backward(). self._masks_backward_available
is set to True.
"""
# Iterate over the trajectories to compute all backward masks
Expand All @@ -883,7 +923,7 @@ def _compute_masks_backward(self):
self.masks_invalid_actions_backward[idx] = self.envs[
traj_idx
].get_mask_invalid_actions_backward(state, done)
self.masks_backward_available = True
self._masks_backward_available = True

# TODO: better handling of availability of rewards, logrewards, proxy_values.
def get_rewards(
Expand Down Expand Up @@ -928,7 +968,7 @@ def get_proxy_values(
If True, return the actual proxy values of the non-terminating states. If
False, non-terminating states will be assigned value inf.
"""
if self.proxy_values_available is False or force_recompute is True:
if self._proxy_values_available is False or force_recompute is True:
self._compute_rewards(do_non_terminating=do_non_terminating)
return self.proxy_values

Expand Down Expand Up @@ -965,7 +1005,7 @@ def _compute_rewards(
)

self.proxy_values = proxy_values
self.proxy_values_available = True
self._proxy_values_available = True
if log:
self.logrewards = rewards
self._logrewards_available = True
Expand All @@ -987,7 +1027,7 @@ def get_rewards_parents(self, log: bool = False) -> TensorType["n_states"]:
self.rewards_parents or self.logrewards_parents
A tensor containing the rewards of the parents of self.states.
"""
if not self.rewards_parents_available:
if not self.rewards_parents_available(log):
self._compute_rewards_parents(log)
if log:
return self.logrewards_parents
Expand Down Expand Up @@ -1019,10 +1059,10 @@ def _compute_rewards_parents(self, log: bool = False):
rewards_parents[parent_is_source] = rewards_source[parent_is_source]
if log:
self.logrewards_parents = rewards_parents
self.logrewards_parents_available = True
self._logrewards_parents_available = True
else:
self.rewards_parents = rewards_parents
self.rewards_parents_available = True
self._rewards_parents_available = True

def get_rewards_source(self, log: bool = False) -> TensorType["n_states"]:
"""
Expand All @@ -1038,7 +1078,7 @@ def get_rewards_source(self, log: bool = False) -> TensorType["n_states"]:
self.rewards_source or self.logrewards_source
A tensor containing the rewards the source states.
"""
if not self.rewards_source_available:
if not self.rewards_source_available(log):
self._compute_rewards_source(log)
if log:
return self.logrewards_source
Expand Down Expand Up @@ -1066,10 +1106,10 @@ def _compute_rewards_source(self, log: bool = False):
raise NotImplementedError
if log:
self.logrewards_source = rewards_source
self.logrewards_source_available = True
self._logrewards_source_available = True
else:
self.rewards_source = rewards_source
self.rewards_source_available = True
self._rewards_source_available = True

def get_terminating_states(
self,
Expand Down Expand Up @@ -1202,7 +1242,7 @@ def get_terminating_proxy_values(
indices = np.argsort(self.traj_indices)
else:
raise ValueError("sort_by must be either insert[ion] or traj[ectory]")
if self.proxy_values_available is False or force_recompute is True:
if self._proxy_values_available is False or force_recompute is True:
self._compute_rewards(log, do_non_terminating=False)
done = self.get_done()[indices]
return self.proxy_values[indices][done]
Expand Down Expand Up @@ -1313,15 +1353,15 @@ def merge(self, batches: List):
self.states_policy = extend(self.states_policy, batch.states_policy)
else:
self.states_policy = None
if self.parents_available and batch.parents_available:
if self._parents_available and batch._parents_available:
self.parents = extend(self.parents, batch.parents)
else:
self.parents = None
if self.parents_policy_available and batch.parents_policy_available:
if self._parents_policy_available and batch._parents_policy_available:
self.parents_policy = extend(self.parents_policy, batch.parents_policy)
else:
self.parents_policy = None
if self.parents_all_available and batch.parents_all_available:
if self._parents_all_available and batch._parents_all_available:
self.parents_all = extend(self.parents_all, batch.parents_all)
else:
self.parents_all = None
Expand Down

0 comments on commit 0c3d8aa

Please sign in to comment.