Scalar estimators allow for the reduction over many output values (i.… #215
…e., the output of the nn.Module does not need to be a scalar, because the Estimator will apply a reduction to the final output if required).
Looks good! I added a comment for possible refactoring.
src/gfn/modules.py
Outdated
reduction_fxns = {
    "mean": torch.mean,
    "sum": torch.sum,
    "prod": torch.prod,
}
here you can use the global constant, if you follow the previous comment
Thanks - I've done this
src/gfn/modules.py
Outdated
reduction_fxns = {
    "mean": torch.mean,
    "sum": torch.sum,
    "prod": torch.prod,
}
This is a constant; by convention it should be defined outside the class, at module level, with an upper-case name.
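For illustration, here is a minimal sketch of what that refactor could look like. The constant name `REDUCTION_FXNS` and the helper `get_reduction` are assumptions made for this example, not names taken from the PR.

```python
import torch

# Module-level constant: defined once, upper-case by convention.
# (The name REDUCTION_FXNS is an assumption for this sketch.)
REDUCTION_FXNS = {
    "mean": torch.mean,
    "sum": torch.sum,
    "prod": torch.prod,
}


def get_reduction(name: str):
    """Hypothetical helper: look up a reduction by name and fail loudly on typos."""
    if name not in REDUCTION_FXNS:
        raise ValueError(
            f"reduction must be one of {sorted(REDUCTION_FXNS)}, got {name!r}"
        )
    return REDUCTION_FXNS[name]


# Reduce a (batch, output_dim) tensor to (batch, 1) over the last dimension.
out = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
reduced = get_reduction("mean")(out, dim=-1, keepdim=True)
```

Any class that needs a reduction can then share the same table instead of rebuilding the dictionary in each `__init__`.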
Looks good! I left a few questions and comments below.
src/gfn/modules.py
Outdated
Attributes:
    preprocessor: Preprocessor object that transforms raw States objects to tensors
        that can be used as input to the module. Optional, defaults to
        `IdentityPreprocessor`.
    module: The module to use. If the module is a Tabular module (from
        `gfn.utils.modules`), then the environment preprocessor needs to be an
        `EnumPreprocessor`.
    preprocessor: Preprocessor from the environment.
    _output_dim_is_checked: Flag for tracking whether the output dimensions of
        the states (after being preprocessed and transformed by the modules) have
        been verified.
This needs to be updated accordingly, e.g., add `is_backward` and `reduction`, and remove `_output_dim_is_checked`.
src/gfn/modules.py
Outdated
Attributes:
    preprocessor: Preprocessor object that transforms raw States objects to tensors
        that can be used as input to the module. Optional, defaults to
        `IdentityPreprocessor`.
    module: The module to use. If the module is a Tabular module (from
        `gfn.utils.modules`), then the environment preprocessor needs to be an
        `EnumPreprocessor`.
    preprocessor: Preprocessor from the environment.
    reduction_fxn: the selected torch reduction operation.
    _output_dim_is_checked: Flag for tracking whether the output dimensions of
        the states (after being preprocessed and transformed by the modules) have
        been verified.
"""
DITTO
thanks :)
src/gfn/modules.py
Outdated
@@ -134,9 +134,71 @@ def to_probability_distribution(

class ScalarEstimator(GFNModule):
    r"""Class for estimating scalars such as LogZ.
Note that `logZ` for unconditional TB is usually modeled with a single learnable parameter (`nn.Parameter`).
Should we consider modifying `ScalarEstimator` to support this kind of behavior?
The GFNs themselves support this directly (you do not need to pass an estimator at all; instead you just pass a float for `Z`).
This comment is because of `such as LogZ` in the docstring!
I'm not entirely sure what would be most clear here but I'm open to suggestions.
Why not just state flow functions of DB/SubTB??
src/gfn/samplers.py
Outdated
) -> Tuple[
    Actions,
    torch.Tensor | None,
    torch.Tensor | None,
]:
Removing the last `,` will make this one line.
def expected_output_dim(self) -> int:
    return 1

def forward(self, input: States | torch.Tensor) -> torch.Tensor:
In which case is the input `torch.Tensor`?
Yes, I was looking at this and am not entirely sure. It might be in the case of conditioning, where we currently don't have any sort of container, so conditioning is done with a raw `Tensor`.
Oh, it should be conditioning (e.g., conditional log Z(c)).
However, it might be a bit confusing whether to use `ConditionalScalarEstimator` or `ScalarEstimator` to model log Z(c).
Well of note, ScalarEstimators are used for more than just logZ, but in this case, I see it like this:
- LogZ can be a single parameter.
- LogZ can be estimated using a neural network - in this case, the output of the network can actually be multiple items that are averaged over.
- LogZ can be conditionally estimated using a neural network - in this case, the output of the network can actually be multiple items that are averaged over.
From an optimization POV, sometimes having logZ only be estimated by a single parameter can cause problems (i.e., the gradients push the number around a lot), so using a neural network helps.
I agree we could make it clearer though -- I am open to suggestions.
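To make the options above concrete, here is a rough sketch in plain PyTorch, not the library's actual classes. `ReducedLogZ` and its layer sizes are hypothetical; it only illustrates a network whose multi-valued output is averaged down to one scalar per input.

```python
import torch
from torch import nn

# Option 1: logZ as a single learnable scalar.
log_z_param = nn.Parameter(torch.tensor(0.0))


class ReducedLogZ(nn.Module):
    """Hypothetical sketch: a network emits several values, and a mean
    reduction turns them into one logZ estimate per input row."""

    def __init__(self, input_dim: int, hidden_dim: int = 16, n_outputs: int = 8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_outputs),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, n_outputs) -> (batch, 1): average over the output values.
        return self.net(x).mean(dim=-1, keepdim=True)


# Conditional case: the input is a raw conditioning tensor.
conditioning = torch.randn(4, 3)
log_z = ReducedLogZ(input_dim=3)(conditioning)  # one logZ estimate per row
```

In the network case the gradient is spread over many weights rather than a single number, which is the intuition behind the optimization remark above.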
`ConditionalScalarEstimator` is used to take in both the State and the Conditioning, i.e., it's a two-headed estimator. I think this is the normal conditioning case.
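As a rough illustration of the two-headed idea (not the actual `ConditionalScalarEstimator` implementation), the sketch below encodes states and conditioning separately, concatenates the two embeddings, and reduces the final output to a scalar. The class name `TwoInputScalarSketch` and all layer sizes are assumptions.

```python
import torch
from torch import nn


class TwoInputScalarSketch(nn.Module):
    """Hypothetical sketch of a two-headed scalar estimator."""

    def __init__(self, state_dim: int, cond_dim: int,
                 hidden_dim: int = 16, n_outputs: int = 4):
        super().__init__()
        self.state_net = nn.Linear(state_dim, hidden_dim)  # head for states
        self.cond_net = nn.Linear(cond_dim, hidden_dim)    # head for conditioning
        self.final_net = nn.Sequential(
            nn.ReLU(),
            nn.Linear(2 * hidden_dim, n_outputs),
        )

    def forward(self, states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        # Concatenate the two embeddings along the feature dimension.
        joint = torch.cat(
            [self.state_net(states), self.cond_net(conditioning)], dim=-1
        )
        # Reduce the multi-valued output to one scalar per batch element.
        return self.final_net(joint).mean(dim=-1, keepdim=True)


estimator = TwoInputScalarSketch(state_dim=6, cond_dim=3)
scalar_out = estimator(torch.randn(5, 6), torch.randn(5, 3))
```

By contrast, a single-input estimator fed only the conditioning tensor would cover the log Z(c) case discussed earlier in this thread.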
Approved!
I am not in love with the current code organization of `modules.py` -- there is some duplication, but I am thinking that a bigger refactoring effort might be en route and perhaps we should wait to optimize. Open to feedback on this!