Standardize recalculate_all_logprobs parameter across PF-based GFlowNets #231
Conversation
src/gfn/gflownet/base.py (Outdated)

```diff
@@ -71,17 +71,12 @@ def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
     """Converts trajectories to training samples. The type depends on the GFlowNet."""

     @abstractmethod
-    def loss(self, env: Env, training_objects: Any):
+    def loss(self, env: Env, training_objects: Any, **kwargs: Any):
```
Why add `**kwargs` instead of `recalculate_all_logprobs`?
not needed indeed
src/gfn/gflownet/base.py (Outdated)

```python
env: Env,
training_objects: Any,
recalculate_all_logprobs: bool = False,
**kwargs: Any,
```
Do we need this kwargs?
- Add docstrings
- Add input validation (as proposed by saleml)
- Add PFBasedGFlowNet verification instead of only TBGFNs (needs merge of GFNOrg#231)
lgtm - sorry I thought I already reviewed this one.
This PR standardizes the handling of the `recalculate_all_logprobs` parameter across all policy-based GFlowNet implementations (TB, DB, SubTB).

Changes:
- Add a `recalculate_all_logprobs` parameter to the `PFBasedGFlowNet.loss()` abstract method (default `recalculate_all_logprobs=False`)
- When `recalculate_all_logprobs=True`, force recalculation of logprobs even if they already exist

This change:
- Makes loss computation work for trajectories sampled with `save_logprobs=False` (introduced in Default behavior: recalculate logprobs #230)

Context:
- #230 introduced `save_logprobs=False` as the default during sampling to improve memory efficiency
- Computing losses on such trajectories requires recalculating logprobs, hence the `recalculate_all_logprobs` parameter
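As a rough illustration of the standardized pattern, here is a minimal, self-contained sketch — not the actual torchgfn implementation; the `Trajectories` stand-in, the `_compute_logprobs` helper, and the toy loss are invented for this example:

```python
from typing import List, Optional


class Trajectories:
    """Minimal stand-in for a trajectories container (hypothetical)."""

    def __init__(self, actions: List[int], log_probs: Optional[List[float]] = None):
        self.actions = actions
        # log_probs is None when trajectories were sampled with save_logprobs=False.
        self.log_probs = log_probs


def _compute_logprobs(trajectories: Trajectories) -> List[float]:
    # Placeholder for a real forward-policy evaluation over the trajectories.
    return [0.0 for _ in trajectories.actions]


class PFBasedGFlowNetSketch:
    def loss(
        self,
        trajectories: Trajectories,
        recalculate_all_logprobs: bool = False,
    ) -> float:
        # Recompute logprobs when forced, or when they were never saved.
        if recalculate_all_logprobs or trajectories.log_probs is None:
            log_probs = _compute_logprobs(trajectories)
        else:
            log_probs = trajectories.log_probs
        return -sum(log_probs)  # toy loss for illustration only


# Trajectories sampled with save_logprobs=False carry no stored logprobs,
# but the loss can still be computed because logprobs are recalculated:
traj = Trajectories(actions=[1, 2, 3], log_probs=None)
loss_value = PFBasedGFlowNetSketch().loss(traj)
```

The point of the pattern is that every PF-based loss accepts the same explicit keyword, so callers can uniformly force recalculation (`recalculate_all_logprobs=True`) regardless of which loss (TB, DB, SubTB) they use.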