Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize recalculate_all_logprobs parameter across PF-based GFlowNets #231

Merged
merged 10 commits into from
Jan 24, 2025

Conversation

saleml
Copy link
Collaborator

@saleml saleml commented Jan 22, 2025

This PR standardizes the handling of recalculate_all_logprobs parameter across all policy-based GFlowNet implementations (TB, DB, SubTB).

Changes:

  • Added recalculate_all_logprobs parameter to PFBasedGFlowNet.loss() abstract method
  • Added the parameter to all implementing classes (DB, TB, SubTB)
  • Default behavior: use existing logprobs if available (when recalculate_all_logprobs=False)
  • When recalculate_all_logprobs=True, force recalculation of logprobs even if they exist

This change:

Context:

@@ -71,17 +71,12 @@ def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
"""Converts trajectories to training samples. The type depends on the GFlowNet."""

@abstractmethod
def loss(self, env: Env, training_objects: Any):
def loss(self, env: Env, training_objects: Any, **kwargs: Any):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why adding kwargs instead of recalculate_all_logprobs?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed indeed

env: Env,
training_objects: Any,
recalculate_all_logprobs: bool = False,
**kwargs: Any,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this kwargs?

Quoding added a commit to Quoding/torchgfn that referenced this pull request Jan 22, 2025
Add doscrings
Add input validation (as proposed by saleml)
Add PFBasedGFlowNet verification instead of only TBGFNs (needs merge
of GFNOrg#231)
Copy link
Collaborator

@josephdviviano josephdviviano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm - sorry I thought I already reviewed this one.

@saleml saleml merged commit fb17e2b into master Jan 24, 2025
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants