Scalar estimators allow for the reduction over many output values (i.… #215

Merged: 7 commits, Nov 15, 2024
Changes from all commits
145 changes: 134 additions & 11 deletions src/gfn/modules.py
@@ -9,6 +9,12 @@
from gfn.states import DiscreteStates, States
from gfn.utils.distributions import UnsqueezedCategorical

REDUCTION_FXNS = {
"mean": torch.mean,
"sum": torch.sum,
"prod": torch.prod,
}
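For reference, each of these reductions collapses the last (output) dimension to one value per batch element; a quick standalone illustration:

```python
import torch

out = torch.tensor([[1.0, 2.0, 3.0]])  # a module output of shape (1, 3)
torch.mean(out, -1)  # tensor([2.])
torch.sum(out, -1)   # tensor([6.])
torch.prod(out, -1)  # tensor([6.])
```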


class GFNModule(ABC, nn.Module):
r"""Base class for modules mapping states distributions.
@@ -41,9 +47,11 @@ class GFNModule(ABC, nn.Module):
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor from the environment.
_output_dim_is_checked: Flag for tracking whether the output dimenions of
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
_is_backward: Flag for tracking whether this estimator is used for predicting
probability distributions over parents.
"""

def __init__(
@@ -52,7 +60,7 @@ def __init__(
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
) -> None:
"""Initalize the FunctionEstimator with an environment and a module.
"""Initialize the GFNModule with nn.Module and a preprocessor.
Args:
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
@@ -134,9 +142,82 @@ def to_probability_distribution(


class ScalarEstimator(GFNModule):
r"""Class for estimating scalars such as LogZ or state flow functions of DB/SubTB.

Training a GFlowNet sometimes requires the estimation of precise scalar
values, such as the partition function of flows on the DAG. This Estimator is
designed for those cases.

The function approximator used for `module` need not directly output a scalar. If
it does not, `reduction` will be used to aggregate the outputs of the module into
a single scalar.

Attributes:
preprocessor: Preprocessor object that transforms raw States objects to tensors
that can be used as input to the module. Optional, defaults to
`IdentityPreprocessor`.
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
reduction_fxn: the selected torch reduction operation.
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
_is_backward: Flag for tracking whether this estimator is used for predicting
probability distributions over parents.
"""

def __init__(
self,
module: nn.Module,
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
reduction: str = "mean",
):
"""Initialize the GFNModule with a scalar output.
Args:
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor object.
is_backward: Flags estimators of probability distributions over parents.
reduction: str name of one of the REDUCTION_FXNS keys: {}
""".format(
list(REDUCTION_FXNS.keys())
)
super().__init__(module, preprocessor, is_backward)
assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
REDUCTION_FXNS.keys()
)
self.reduction_fxn = REDUCTION_FXNS[reduction]

def expected_output_dim(self) -> int:
return 1

def forward(self, input: States | torch.Tensor) -> torch.Tensor:
Collaborator: In which case is the input torch.Tensor?

Collaborator (Author): Yes, I was looking at this and I'm not entirely sure. It might be in the case of conditioning: we currently don't have any sort of container for conditioning, so it is done with a raw Tensor.

Collaborator: Oh, it should be conditioning (e.g., conditional log Z(c)).
However, it might be a bit confusing whether to use ConditionalScalarEstimator or ScalarEstimator to model log Z(c).

Collaborator (Author): Of note, ScalarEstimators are used for more than just logZ, but in this case I see it like this (sketched in code below):

  • LogZ can be a single parameter.
  • LogZ can be estimated using a neural network - in this case, the output of the network can actually be multiple items that are averaged over.
  • LogZ can be conditionally estimated using a neural network - in this case, the output of the network can actually be multiple items that are averaged over.

From an optimization POV, sometimes having logZ only be estimated by a single parameter can cause problems (i.e., the gradients push the number around a lot), so using a neural network helps.

I agree we could make it clearer though -- I am open to suggestions.
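A minimal sketch of the three options above (all sizes and variable names are hypothetical; assumes `MLP` from `gfn.utils.modules` and that the conditional trunk concatenates the state and conditioning embeddings):

```python
import torch
import torch.nn as nn
from gfn.modules import ScalarEstimator, ConditionalScalarEstimator
from gfn.utils.modules import MLP

# 1) logZ as a single learnable parameter.
logZ = nn.Parameter(torch.zeros(1))

# 2) logZ estimated by a network: the 8 outputs are averaged into a scalar.
logZ_net = ScalarEstimator(module=MLP(input_dim=4, output_dim=8), reduction="mean")

# 3) Conditional logZ(c): two trunks feed a final head, whose outputs are
#    likewise reduced to a scalar.
logZ_cond = ConditionalScalarEstimator(
    state_module=MLP(input_dim=4, output_dim=16),
    conditioning_module=MLP(input_dim=2, output_dim=16),
    final_module=MLP(input_dim=32, output_dim=8),  # 16 + 16, concatenated
    reduction="mean",
)
```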

Collaborator (Author): ConditionalScalarEstimator is used to take in both the State and the Conditioning, i.e., it's a two-headed estimator. I think this is the normal conditioning case.

"""Forward pass of the module.

Args:
input: The input to the module, as states or a tensor.

Returns the output of the module, as a tensor of shape (*batch_shape, output_dim).
"""
if isinstance(input, States):
input = self.preprocessor(input)

out = self.module(input)

# Ensures estimator outputs are always scalar.
if out.shape[-1] != 1:
out = self.reduction_fxn(out, -1)

if not self._output_dim_is_checked:
self.check_output_dim(out)
self._output_dim_is_checked = True

return out
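As a usage sketch of this forward pass (hypothetical sizes; assumes `MLP` from `gfn.utils.modules`; a raw tensor stands in for a preprocessed state, per the conversation above):

```python
import torch
from gfn.modules import ScalarEstimator
from gfn.utils.modules import MLP

# A head whose output_dim is 8 rather than 1; the estimator's `reduction`
# ("mean") aggregates the eight outputs into a single scalar.
net = MLP(input_dim=4, output_dim=8)
estimator = ScalarEstimator(module=net, reduction="mean")

x = torch.randn(1, 4)  # stands in for one preprocessed state
out = estimator(x)     # the 8 outputs are averaged over the last dimension
```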


class DiscretePolicyEstimator(GFNModule):
r"""Container for forward and backward policy estimators for discrete environments.
@@ -290,14 +371,57 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:


class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
r"""Class for conditionally estimating scalars (LogZ, DB/SubTB state logF).

Training a GFlowNet sometimes requires the estimation of precise scalar
values, such as the partition function of flows on the DAG. In the case of a
conditional GFN, the logZ or logF estimate is also conditional. This Estimator is
designed for those cases.

The function approximator used for `final_module` need not directly output a scalar.
If it does not, `reduction` will be used to aggregate the outputs of the module into
a single scalar.

Attributes:
preprocessor: Preprocessor object that transforms raw States objects to tensors
that can be used as input to the module. Optional, defaults to
`IdentityPreprocessor`.
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
reduction_fxn: the selected torch reduction operation.
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
_is_backward: Flag for tracking whether this estimator is used for predicting
probability distributions over parents.
"""

def __init__(
self,
state_module: nn.Module,
conditioning_module: nn.Module,
final_module: nn.Module,
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
reduction: str = "mean",
):
"""Initialize a conditional GFNModule with a scalar output.
Args:
state_module: The module to use for state representations. If the module is
a Tabular module (from `gfn.utils.modules`), then the environment
preprocessor needs to be an `EnumPreprocessor`.
conditioning_module: The module to use for conditioning representations.
final_module: The module to use for computing the final output.
preprocessor: Preprocessor object.
is_backward: Flags estimators of probability distributions over parents.
reduction: str name of one of the REDUCTION_FXNS keys: {}
""".format(
list(REDUCTION_FXNS.keys())
)

super().__init__(
state_module,
conditioning_module,
@@ -306,6 +430,10 @@ def __init__(
preprocessor=preprocessor,
is_backward=is_backward,
)
assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
REDUCTION_FXNS.keys()
)
self.reduction_fxn = REDUCTION_FXNS[reduction]

def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
"""Forward pass of the module.
@@ -318,6 +446,10 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
"""
out = self._forward_trunk(states, conditioning)

# Ensures estimator outputs are always scalar.
if out.shape[-1] != 1:
out = self.reduction_fxn(out, -1)

if not self._output_dim_is_checked:
self.check_output_dim(out)
self._output_dim_is_checked = True
@@ -333,13 +465,4 @@ def to_probability_distribution(
module_output: torch.Tensor,
**policy_kwargs: Any,
) -> Distribution:
"""Transform the output of the module into a probability distribution.

Args:
states: The states to use.
module_output: The output of the module as a tensor of shape (*batch_shape, output_dim).
**policy_kwargs: Keyword arguments to modify the distribution.

Returns a distribution object.
"""
raise NotImplementedError
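The "two-headed" structure mentioned in the thread, sketched standalone (a guess at the trunk wiring, with hypothetical sizes: the final head consumes the concatenated state and conditioning embeddings, and the reduction collapses its outputs):

```python
import torch
import torch.nn as nn

state_trunk = nn.Linear(4, 16)  # embeds (preprocessed) states
cond_trunk = nn.Linear(2, 16)   # embeds the raw conditioning tensor
final_head = nn.Linear(32, 8)   # consumes the concatenation, outputs 8

states_emb = state_trunk(torch.randn(5, 4))
cond_emb = cond_trunk(torch.randn(5, 2))
out = final_head(torch.cat((states_emb, cond_emb), dim=-1))
out = torch.mean(out, dim=-1)   # reduction="mean" collapses 8 -> 1 per row
print(out.shape)                # torch.Size([5])
```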
2 changes: 1 addition & 1 deletion src/gfn/samplers.py
@@ -35,7 +35,7 @@ def sample_actions(
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
**policy_kwargs: Any,
) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]:
) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]:
"""Samples actions from the given states.

Args: