Skip to content

Commit

Permalink
Merge pull request #231 from GFNOrg/recalculate_vs_save
Browse files Browse the repository at this point in the history
Standardize recalculate_all_logprobs parameter across PF-based GFlowNets
  • Loading branch information
saleml authored Jan 24, 2025
2 parents 8b52c4a + 752b937 commit fb17e2b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
23 changes: 16 additions & 7 deletions src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,8 @@ def loss(self, env: Env, training_objects: Any):
"""Computes the loss given the training objects."""


class PFBasedGFlowNet(GFlowNet[TrainingSampleType]):
r"""Base class for gflownets that explicitly uses $P_F$.
Attributes:
pf: GFNModule
pb: GFNModule
"""
class PFBasedGFlowNet(GFlowNet[TrainingSampleType], ABC):
"""A GFlowNet that uses forward (PF) and backward (PB) policy networks."""

def __init__(self, pf: GFNModule, pb: GFNModule):
super().__init__()
Expand Down Expand Up @@ -116,6 +111,20 @@ def pf_pb_named_parameters(self):
def pf_pb_parameters(self):
return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k]

@abstractmethod
def loss(
self, env: Env, training_objects: Any, recalculate_all_logprobs: bool = False
):
"""Computes the loss given the training objects.

Declared abstract: this declaration carries no implementation of its own,
so concrete subclasses are expected to provide one.

Args:
env: The environment to compute the loss for.
training_objects: The objects to compute the loss on.
recalculate_all_logprobs: If True, always recalculate logprobs even if
they exist. If False, use existing logprobs when available.
"""


class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
def get_pfs_and_pbs(
Expand Down
8 changes: 6 additions & 2 deletions src/gfn/gflownet/detailed_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,13 @@ def get_scores(

return scores

def loss(self, env: Env, transitions: Transitions) -> torch.Tensor:
def loss(
self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> torch.Tensor:
"""Calculates the modified detailed balance loss."""
scores = self.get_scores(transitions)
scores = self.get_scores(
transitions, recalculate_all_logprobs=recalculate_all_logprobs
)
return torch.mean(scores**2)

def to_training_samples(self, trajectories: Trajectories) -> Transitions:
Expand Down
20 changes: 16 additions & 4 deletions src/gfn/gflownet/sub_trajectory_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,10 @@ def calculate_masks(
return full_mask, sink_states_mask, is_terminal_mask

def get_scores(
self, env: Env, trajectories: Trajectories
self,
env: Env,
trajectories: Trajectories,
recalculate_all_logprobs: bool = False,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Scores all submitted trajectories.
Expand All @@ -290,7 +293,9 @@ def get_scores(
True if the corresponding sub-trajectory does not exist.
"""
log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
trajectories, fill_value=-float("inf")
trajectories,
fill_value=-float("inf"),
recalculate_all_logprobs=recalculate_all_logprobs,
)

log_pf_trajectories_cum = self.cumulative_logprobs(
Expand Down Expand Up @@ -493,9 +498,16 @@ def get_geometric_within_contributions(

return contributions

def loss(self, env: Env, trajectories: Trajectories) -> torch.Tensor:
def loss(
self,
env: Env,
trajectories: Trajectories,
recalculate_all_logprobs: bool = False,
) -> torch.Tensor:
# Get all scores and masks from the trajectories.
scores, flattening_masks = self.get_scores(env, trajectories)
scores, flattening_masks = self.get_scores(
env, trajectories, recalculate_all_logprobs=recalculate_all_logprobs
)
flattening_mask = torch.cat(flattening_masks)
all_scores = torch.cat(scores, 0)

Expand Down

0 comments on commit fb17e2b

Please sign in to comment.