Skip to content

Commit

Permalink
Fix config issue by implementing _get_config()
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Jul 10, 2024
1 parent 774c411 commit 9520315
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
24 changes: 21 additions & 3 deletions gflownet/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def __init__(
float_precision : int or torch.dtype
The floating point precision to be passed to torch tensors.
"""
# If config is None, instantiate an empty config (defaults will be used)
if config is None:
config = OmegaConf.create()
config = self._get_config(config)
# Device and float precision
self.device = set_device(device)
self.float = set_float_precision(float_precision)
Expand All @@ -56,6 +54,26 @@ def __init__(
# Instantiate the model
self.model, self.is_model = self.make_model()

@staticmethod
def _get_config(config: Union[dict, DictConfig]) -> Union[dict, DictConfig]:
"""
Returns a configuration dictionary, even if the input is None.
Parameters
----------
config : dict or DictConfig
The configuration dictionary to set up the policy model. It may be None, in
which an empty config is created and the defaults will be used.
Returns
-------
config : dict or DictConfig
The configuration dictionary to set up the policy model.
"""
if config is None:
config = OmegaConf.create()
return config

def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]:
"""
Instantiates the model of the policy.
Expand Down
1 change: 1 addition & 0 deletions gflownet/policy/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

class CNNPolicy(Policy):
def __init__(self, **kwargs):
config = self._get_config(kwargs["config"])
# Shared weights, defaults to False
self.shared_weights = config.get("shared_weights", False)
# Reload checkpoint, defaults to False
Expand Down
1 change: 1 addition & 0 deletions gflownet/policy/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

class MLPPolicy(Policy):
def __init__(self, **kwargs):
config = self._get_config(kwargs["config"])
# Shared weights, defaults to False
self.shared_weights = config.get("shared_weights", False)
# Reload checkpoint, defaults to False
Expand Down

0 comments on commit 9520315

Please sign in to comment.