diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py
index bd319ba7..83fc83aa 100644
--- a/gflownet/utils/batch.py
+++ b/gflownet/utils/batch.py
@@ -94,18 +94,18 @@ def __init__(
         self.states_policy = None
         self.parents_policy = None
         # Flags for available items
-        self.parents_available = False
-        self.parents_policy_available = False
-        self.parents_all_available = False
-        self.masks_forward_available = False
-        self.masks_backward_available = False
+        self._parents_available = False
+        self._parents_policy_available = False
+        self._parents_all_available = False
+        self._masks_forward_available = False
+        self._masks_backward_available = False
         self._rewards_available = False
-        self.rewards_parents_available = False
-        self.rewards_source_available = False
+        self._rewards_parents_available = False
+        self._rewards_source_available = False
         self._logrewards_available = False
-        self.logrewards_parents_available = False
-        self.logrewards_source_available = False
-        self.proxy_values_available = False
+        self._logrewards_parents_available = False
+        self._logrewards_source_available = False
+        self._proxy_values_available = False
 
     def __len__(self):
         return self.size
@@ -159,6 +159,46 @@ def rewards_available(self, log: bool = False) -> bool:
         else:
             return self._rewards_available
 
+    def rewards_parents_available(self, log: bool = False) -> bool:
+        """
+        Returns True if the (log)rewards of the parents are available.
+
+        Parameters
+        ----------
+        log : bool
+            If True, check self._logrewards_parents_available. Otherwise (default),
+            check self._rewards_parents_available.
+
+        Returns
+        -------
+        bool
+            True if the (log)rewards of the parents are available, False otherwise.
+        """
+        if log:
+            return self._logrewards_parents_available
+        else:
+            return self._rewards_parents_available
+
+    def rewards_source_available(self, log: bool = False) -> bool:
+        """
+        Returns True if the (log)rewards of the source are available.
+
+        Parameters
+        ----------
+        log : bool
+            If True, check self._logrewards_source_available. Otherwise (default),
+            check self._rewards_source_available.
+
+        Returns
+        -------
+        bool
+            True if the (log)rewards of the source are available, False otherwise.
+        """
+        if log:
+            return self._logrewards_source_available
+        else:
+            return self._rewards_source_available
+
     def set_env(self, env: GFlowNetEnv):
         """
         Sets the generic environment passed as an argument and initializes the
@@ -272,10 +312,10 @@ def add_to_batch(
         # Increment size of batch
         self.size += 1
         # Other variables are not available after new items were added to the batch
-        self.masks_forward_available = False
-        self.masks_backward_available = False
-        self.parents_policy_available = False
-        self.parents_all_available = False
+        self._masks_forward_available = False
+        self._masks_backward_available = False
+        self._parents_policy_available = False
+        self._parents_all_available = False
         self._rewards_available = False
         self._logrewards_available = False
 
@@ -546,10 +586,10 @@ def get_parents(
         self.parents or self.parents_policy : torch.tensor
             The parent of all states in the batch.
         """
-        if self.parents_available is False or force_recompute is True:
+        if self._parents_available is False or force_recompute is True:
             self._compute_parents()
         if policy:
-            if self.parents_policy_available is False or force_recompute is True:
+            if self._parents_policy_available is False or force_recompute is True:
                 self._compute_parents_policy()
             return self.parents_policy
         else:
@@ -568,7 +608,7 @@ def get_parents_indices(self):
         self.parents_indices
             The indices in self.states of the parents of self.states.
         """
-        if self.parents_available is False:
+        if self._parents_available is False:
             self._compute_parents()
         return self.parents_indices
 
@@ -589,7 +629,7 @@ def _compute_parents(self):
         parent is not present in self.states (i.e. it is source), the
         corresponding index is -1.
 
-        self.parents_available is set to True.
+        self._parents_available is set to True.
         """
         self.parents = []
         self.parents_indices = []
@@ -621,7 +661,7 @@ def _compute_parents(self):
             [self.parents_indices[indices_dict[idx]] for idx in range(len(self))],
             device=self.device,
         )
-        self.parents_available = True
+        self._parents_available = True
 
     # TODO: consider converting directly from self.parents
     def _compute_parents_policy(self):
@@ -636,7 +676,7 @@ def _compute_parents_policy(self):
             Shape: [n_states, state_policy_dims]
 
         self.parents_policy is stored as a torch tensor and
-        self.parents_policy_available is set to True.
+        self._parents_policy_available is set to True.
         """
         self.states_policy = self.get_states(policy=True)
         self.parents_policy = torch.zeros_like(self.states_policy)
@@ -652,7 +692,7 @@ def _compute_parents_policy(self):
             self.parents_policy[batch_indices[1:]] = self.states_policy[
                 batch_indices[:-1]
             ]
-        self.parents_policy_available = True
+        self._parents_policy_available = True
 
     def get_parents_all(
         self, policy: bool = False, force_recompute: bool = False
@@ -664,7 +704,7 @@ def get_parents_all(
         """
         Returns the whole set of parents, their corresponding actions and indices of
        all states in the batch. If the parents are not available
-        (self.parents_all_available is False) or if force_recompute is True, then
+        (self._parents_all_available is False) or if force_recompute is True, then
         self._compute_parents_all() is called to compute the required components.
 
         The parents are returned in "policy format" if policy is True, otherwise they
@@ -696,7 +736,7 @@ def get_parents_all(
         """
         if self.continuous:
            raise Exception("get_parents() is ill-defined for continuous environments!")
-        if self.parents_all_available is False or force_recompute is True:
+        if self._parents_all_available is False or force_recompute is True:
            self._compute_parents_all()
        if policy:
            return (
@@ -726,7 +766,7 @@ def _compute_parents_all(self):
             Shape: [n_parents, state_policy_dims]
 
         All the above components are stored as torch tensors and
-        self.parents_all_available is set to True.
+        self._parents_all_available is set to True.
         """
         # Iterate over the trajectories to obtain all parents
         self.parents_all = []
@@ -763,7 +803,7 @@ def _compute_parents_all(self):
             device=self.device,
         )
         self.parents_all_policy = torch.cat(self.parents_all_policy)
-        self.parents_all_available = True
+        self._parents_all_available = True
 
     # TODO: opportunity to improve efficiency by caching.
     def get_masks_forward(
@@ -791,7 +831,7 @@ def get_masks_forward(
         self.masks_invalid_actions_forward : torch.tensor
             The forward mask of all states in the batch.
         """
-        if self.masks_forward_available is False or force_recompute is True:
+        if self._masks_forward_available is False or force_recompute is True:
             self._compute_masks_forward()
         # Make tensor
         masks_invalid_actions_forward = tbool(
@@ -826,8 +866,8 @@ def get_masks_forward(
     def _compute_masks_forward(self):
         """
         Computes the forward mask of invalid actions of all states in the batch, by
-        calling env.get_mask_invalid_actions_forward(). self.masks_forward_available is
-        set to True.
+        calling env.get_mask_invalid_actions_forward(). self._masks_forward_available
+        is set to True.
         """
         # Iterate over the trajectories to compute all forward masks
         for idx, mask in enumerate(self.masks_invalid_actions_forward):
@@ -839,7 +879,7 @@ def _compute_masks_forward(self):
             self.masks_invalid_actions_forward[idx] = self.envs[
                 traj_idx
             ].get_mask_invalid_actions_forward(state, done)
-        self.masks_forward_available = True
+        self._masks_forward_available = True
 
     # TODO: opportunity to improve efficiency by caching. Note that
     # env.get_masks_invalid_actions_backward() may be expensive because it calls
@@ -863,14 +903,14 @@ def get_masks_backward(
         self.masks_invalid_actions_backward : torch.tensor
             The backward mask of all states in the batch.
         """
-        if self.masks_backward_available is False or force_recompute is True:
+        if self._masks_backward_available is False or force_recompute is True:
             self._compute_masks_backward()
         return tbool(self.masks_invalid_actions_backward, device=self.device)
 
     def _compute_masks_backward(self):
         """
         Computes the backward mask of invalid actions of all states in the batch, by
-        calling env.get_mask_invalid_actions_backward(). self.masks_backward_available
+        calling env.get_mask_invalid_actions_backward(). self._masks_backward_available
         is set to True.
         """
         # Iterate over the trajectories to compute all backward masks
@@ -883,7 +923,7 @@ def _compute_masks_backward(self):
             self.masks_invalid_actions_backward[idx] = self.envs[
                 traj_idx
             ].get_mask_invalid_actions_backward(state, done)
-        self.masks_backward_available = True
+        self._masks_backward_available = True
 
     # TODO: better handling of availability of rewards, logrewards, proxy_values.
     def get_rewards(
@@ -928,7 +968,7 @@ def get_proxy_values(
             If True, return the actual proxy values of the non-terminating states. If
             False, non-terminating states will be assigned value inf.
         """
-        if self.proxy_values_available is False or force_recompute is True:
+        if self._proxy_values_available is False or force_recompute is True:
             self._compute_rewards(do_non_terminating=do_non_terminating)
 
         return self.proxy_values
@@ -965,7 +1005,7 @@ def _compute_rewards(
         )
 
         self.proxy_values = proxy_values
-        self.proxy_values_available = True
+        self._proxy_values_available = True
         if log:
             self.logrewards = rewards
             self._logrewards_available = True
@@ -987,7 +1027,7 @@ def get_rewards_parents(self, log: bool = False) -> TensorType["n_states"]:
         self.rewards_parents or self.logrewards_parents
             A tensor containing the rewards of the parents of self.states.
         """
-        if not self.rewards_parents_available:
+        if not self.rewards_parents_available(log):
             self._compute_rewards_parents(log)
         if log:
             return self.logrewards_parents
@@ -1019,10 +1059,10 @@ def _compute_rewards_parents(self, log: bool = False):
         rewards_parents[parent_is_source] = rewards_source[parent_is_source]
         if log:
             self.logrewards_parents = rewards_parents
-            self.logrewards_parents_available = True
+            self._logrewards_parents_available = True
         else:
             self.rewards_parents = rewards_parents
-            self.rewards_parents_available = True
+            self._rewards_parents_available = True
 
     def get_rewards_source(self, log: bool = False) -> TensorType["n_states"]:
         """
@@ -1038,7 +1078,7 @@ def get_rewards_source(self, log: bool = False) -> TensorType["n_states"]:
         self.rewards_source or self.logrewards_source
             A tensor containing the rewards the source states.
         """
-        if not self.rewards_source_available:
+        if not self.rewards_source_available(log):
             self._compute_rewards_source(log)
         if log:
             return self.logrewards_source
@@ -1066,10 +1106,10 @@ def _compute_rewards_source(self, log: bool = False):
             raise NotImplementedError
         if log:
             self.logrewards_source = rewards_source
-            self.logrewards_source_available = True
+            self._logrewards_source_available = True
         else:
             self.rewards_source = rewards_source
-            self.rewards_source_available = True
+            self._rewards_source_available = True
 
     def get_terminating_states(
         self,
@@ -1202,7 +1242,7 @@ def get_terminating_proxy_values(
             indices = np.argsort(self.traj_indices)
         else:
             raise ValueError("sort_by must be either insert[ion] or traj[ectory]")
-        if self.proxy_values_available is False or force_recompute is True:
+        if self._proxy_values_available is False or force_recompute is True:
             self._compute_rewards(log, do_non_terminating=False)
         done = self.get_done()[indices]
         return self.proxy_values[indices][done]
@@ -1313,15 +1353,15 @@ def merge(self, batches: List):
             self.states_policy = extend(self.states_policy, batch.states_policy)
         else:
             self.states_policy = None
-        if self.parents_available and batch.parents_available:
+        if self._parents_available and batch._parents_available:
             self.parents = extend(self.parents, batch.parents)
         else:
             self.parents = None
-        if self.parents_policy_available and batch.parents_policy_available:
+        if self._parents_policy_available and batch._parents_policy_available:
             self.parents_policy = extend(self.parents_policy, batch.parents_policy)
         else:
             self.parents_policy = None
-        if self.parents_all_available and batch.parents_all_available:
+        if self._parents_all_available and batch._parents_all_available:
            self.parents_all = extend(self.parents_all, batch.parents_all)
         else:
             self.parents_all = None