Skip to content

Commit

Permalink
Extend docstring and wrap docstring lines
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Nov 27, 2023
1 parent 993b9de commit 06a6fb5
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions gflownet/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,27 +525,37 @@ def get_parents(

def get_parents_indices(self):
"""
Returns indices of the parents of the states in the batch.
Each index corresponds to the position of the patent in the self.states tensor, if it is peresent there.
If a parent is not present in self.states (i.e. it is source), the corresponding index is -1
Returns the indices of the parents of the states in the batch.
Each item idx in the returned list corresponds to the index in self.states that
contains the parent of self.states[idx], if it is peresent there. If a parent
is not present in self.states (because it is the source), the index is -1.
Returns
-------
self.parents_indices
The indices in self.states of the parents of self.states.
"""
if self.parents_available is False:
self._compute_parents()
return self.parents_indices

def _compute_parents(self):
"""
Obtains the parent (single parent for each state) of all states in the batch and its index.
Obtains the parent (single parent for each state) of all states in the batch
and its index.
The parents are computed, obtaining all necessary components, if they are not
readily available. Missing components and newly computed components are added
to the batch (self.component is set). The following variable is stored:
to the batch (self.component is set). The following variables are stored:
- self.parents: the parent of each state in the batch. It will be the same type
as self.states (list of lists or tensor)
Length: n_states
Shape: [n_states, state_dims]
- self.parents_indices: the position of each parent in self.states tensor.
If a parent is not present in self.states (i.e. it is source), the corresponding index is -1
- self.parents_indices: the position of each parent in self.states tensor. If a
parent is not present in self.states (i.e. it is source), the corresponding
index is -1.
self.parents_available is set to True.
"""
Expand Down Expand Up @@ -887,16 +897,23 @@ def _compute_rewards(self, do_non_terminating: Optional[bool] = False):

def get_rewards_parents(self) -> TensorType["n_states"]:
"""
Returns the rewards of all parents in the batch
Returns the rewards of all parents in the batch.
Returns
-------
self.rewards_parents
A tensor containing the rewards of the parents of self.states.
"""
if not self.rewards_parents_available:
self._compute_rewards_parents()
return self.rewards_parents

def _compute_rewards_parents(self):
"""
Computes rewards of the self.parents by reusing rewards of the states (i.e. self.rewards).
Stores the result in self.rewards_parents
Computes the rewards of self.parents by reusing the rewards of the states
(self.rewards).
Stores the result in self.rewards_parents.
"""
# TODO: this may return zero rewards for all parents if before
# rewards for states were computed with do_non_terminating=False
Expand All @@ -914,15 +931,22 @@ def _compute_rewards_parents(self):
def get_rewards_source(self) -> TensorType["n_states"]:
"""
Returns rewards of the corresponding source states for each state in the batch.
Returns
-------
self.rewards_source
A tensor containing the rewards the source states.
"""
if not self.rewards_source_available:
self._compute_rewards_source()
return self.rewards_source

def _compute_rewards_source(self):
"""
Computes a tensor of length len(self.states) with rewards of the corresponding source states.
Stores the result in self.rewards_source
Computes a tensor of length len(self.states) with the rewards of the
corresponding source states.
Stores the result in self.rewards_source.
"""
# This will not work if source is randomised
if not self.conditional:
Expand Down

0 comments on commit 06a6fb5

Please sign in to comment.