From e6a14efa1ef546f39f07b55bbcf68a606d93f304 Mon Sep 17 00:00:00 2001
From: Alex
Date: Wed, 10 Jul 2024 18:01:12 +0200
Subject: [PATCH 1/8] Add docstring and typing to __init__ of policy base.

---
 gflownet/policy/base.py | 32 ++++++++++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py
index eb98e9bd..e6bc5d6b 100644
--- a/gflownet/policy/base.py
+++ b/gflownet/policy/base.py
@@ -1,13 +1,41 @@
-from abc import ABC, abstractmethod
+"""
+Base Policy class for GFlowNet policy models.
+"""
+
+from typing import Union

 import torch
 from omegaconf import OmegaConf
+from omegaconf.dictconfig import DictConfig

+from gflownet.envs.base import GFlowNetEnv
 from gflownet.utils.common import set_device, set_float_precision


 class Policy:
-    def __init__(self, config, env, device, float_precision, base=None):
+    def __init__(
+        self,
+        config: Union[dict, DictConfig],
+        env: GFlowNetEnv,
+        device: Union[str, torch.device],
+        float_precision: Union[int, torch.dtype],
+        base=None,
+    ):
+        """
+        Base Policy class for a :class:`GFlowNetAgent`.
+
+        Parameters
+        ----------
+        config : dict or DictConfig
+            The configuration dictionary to set up the policy model.
+        env : GFlowNetEnv
+            The environment used to train the :class:`GFlowNetAgent`, used to extract
+            needed properties.
+        device : str or torch.device
+            The device to be passed to torch tensors.
+        float_precision : int or torch.dtype
+            The floating point precision to be passed to torch tensors.
+        """
         # Device and float precision
         self.device = set_device(device)
         self.float = set_float_precision(float_precision)

From 57b0b136bcc936e61b60f3b715dec86bd71fdf6f Mon Sep 17 00:00:00 2001
From: Alex
Date: Wed, 10 Jul 2024 18:04:47 +0200
Subject: [PATCH 2/8] Use kwargs instead of listing parameters explicitly

---
 gflownet/policy/cnn.py | 12 +++---------
 gflownet/policy/mlp.py | 10 ++--------
 2 files changed, 5 insertions(+), 17 deletions(-)

diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py
index f8343d88..e501ab68 100644
--- a/gflownet/policy/cnn.py
+++ b/gflownet/policy/cnn.py
@@ -6,15 +6,9 @@


 class CNNPolicy(Policy):
-    def __init__(self, config, env, device, float_precision, base=None):
-        self.env = env
-        super().__init__(
-            config=config,
-            env=env,
-            device=device,
-            float_precision=float_precision,
-            base=base,
-        )
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.env = kwargs["env"]

     def make_cnn(self):
         """
diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py
index b90f6e52..aa332a5d 100644
--- a/gflownet/policy/mlp.py
+++ b/gflownet/policy/mlp.py
@@ -5,14 +5,8 @@


 class MLPPolicy(Policy):
-    def __init__(self, config, env, device, float_precision, base=None):
-        super().__init__(
-            config=config,
-            env=env,
-            device=device,
-            float_precision=float_precision,
-            base=base,
-        )
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)

     def make_mlp(self, activation):
         """
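With the keyword-only construction introduced in PATCH 2/8, policies are meant to be instantiated with explicit keyword arguments, which Policy.__init__ then receives through **kwargs. A minimal usage sketch, assuming a GFlowNetEnv instance `env` and a plain-dict config (both illustrative stand-ins, not part of the patch):

    from gflownet.policy.mlp import MLPPolicy

    policy = MLPPolicy(
        config={"n_hid": 64},  # assumed config keys; any dict or DictConfig works
        env=env,               # a GFlowNetEnv instance, defined elsewhere
        device="cpu",
        float_precision=32,
    )
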
From 178a08e57606334219a0a89cc84d0b980617613d Mon Sep 17 00:00:00 2001
From: Alex
Date: Wed, 10 Jul 2024 18:39:05 +0200
Subject: [PATCH 3/8] Policy MLP: docstring and typing.

---
 gflownet/policy/mlp.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py
index aa332a5d..a1d3dc32 100644
--- a/gflownet/policy/mlp.py
+++ b/gflownet/policy/mlp.py
@@ -8,17 +8,17 @@ class MLPPolicy(Policy):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def make_mlp(self, activation):
+    def make_mlp(self, activation: nn.Module):
         """
         Defines an MLP with no top layer activation
-        If share_weight == True,
-        baseModel (the model with which weights are to be shared) must be provided
-        Args
-        ----
-        layers_dim : list
-            Dimensionality of each layer
-        activation : Activation
-            Activation function
+
+        If config.share_weights is True, the base model with which weights are to be
+        shared must be provided.
+
+        Parameters
+        ----------
+        activation : nn.Module
+            Activation function of the MLP layers
         """
         if self.shared_weights == True and self.base is not None:
             mlp = nn.Sequential(
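The shared-weights branch documented above reuses the base model's layers by reference rather than by copy, so both policies train the same trunk parameters. A standalone sketch of the same pattern with a toy MLP (illustration only, not code from this series):

    import torch.nn as nn

    base = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(), nn.Linear(8, 3))
    # Everything but the top layer is shared by object identity; only the
    # final Linear is a fresh, independently trained layer.
    shared = nn.Sequential(
        base[:-1],
        nn.Linear(base[-1].in_features, base[-1].out_features),
    )
    assert shared[0][0] is base[0]  # same module object, not a copy
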
From ace8a28f3e2176428ac260e5b4519dc7b6f70f87 Mon Sep 17 00:00:00 2001
From: Alex
Date: Wed, 10 Jul 2024 19:05:25 +0200
Subject: [PATCH 4/8] Get rid of parse_config and include its content in __init__

---
 gflownet/policy/base.py | 15 +++++++--------
 gflownet/policy/cnn.py  | 26 +++++++++++++-------------
 gflownet/policy/mlp.py  | 22 ++++++++++------------
 3 files changed, 30 insertions(+), 33 deletions(-)

diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py
index e6bc5d6b..68fb489f 100644
--- a/gflownet/policy/base.py
+++ b/gflownet/policy/base.py
@@ -36,6 +36,9 @@ def __init__(
         float_precision : int or torch.dtype
             The floating point precision to be passed to torch tensors.
         """
+        # If config is None, instantiate an empty config (defaults will be used)
+        if config is None:
+            config = OmegaConf.create()
         # Device and float precision
         self.device = set_device(device)
         self.float = set_float_precision(float_precision)
@@ -46,16 +49,12 @@ def __init__(
         self.output_dim = len(self.fixed_output)
         # Optional base model
         self.base = base
-
-        self.parse_config(config)
-        self.instantiate()
-
-    def parse_config(self, config):
-        # If config is null, default to uniform
-        if config is None:
-            config = OmegaConf.create()
+        # Policy type, defaults to uniform
         self.type = config.get("type", "uniform")
+        # Checkpoint, defaults to None
         self.checkpoint = config.get("checkpoint", None)
+        # Instantiate the model
+        self.instantiate()

     def instantiate(self):
         if self.type == "fixed":
diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py
index e501ab68..7d80d127 100644
--- a/gflownet/policy/cnn.py
+++ b/gflownet/policy/cnn.py
@@ -7,8 +7,20 @@

 class CNNPolicy(Policy):
     def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+        # Shared weights, defaults to False
+        self.shared_weights = config.get("shared_weights", False)
+        # Reload checkpoint, defaults to False
+        self.reload_ckpt = config.get("reload_ckpt", False)
+        # CNN features: number of layers, number of channels, kernel sizes, strides
+        self.n_layers = config.get("n_layers", 3)
+        self.channels = config.get("channels", [16] * self.n_layers)
+        self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
+        self.strides = config.get("strides", [(1, 1)] * self.n_layers)
+        # Environment
+        # TODO: rethink whether storing the whole environment is needed
         self.env = kwargs["env"]
+        # Base init
+        super().__init__(**kwargs)

     def make_cnn(self):
         """
@@ -65,18 +77,6 @@ def make_cnn(self):
         )
         return model.to(self.device)

-    def parse_config(self, config):
-        super().parse_config(config)
-        if config is None:
-            config = OmegaConf.create()
-        self.checkpoint = config.get("checkpoint", None)
-        self.shared_weights = config.get("shared_weights", False)
-        self.reload_ckpt = config.get("reload_ckpt", False)
-        self.n_layers = config.get("n_layers", 3)
-        self.channels = config.get("channels", [16] * self.n_layers)
-        self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
-        self.strides = config.get("strides", [(1, 1)] * self.n_layers)
-
     def instantiate(self):
         self.model = self.make_cnn()
         self.is_model = True
diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py
index a1d3dc32..bf012e1b 100644
--- a/gflownet/policy/mlp.py
+++ b/gflownet/policy/mlp.py
@@ -6,13 +6,22 @@

 class MLPPolicy(Policy):
     def __init__(self, **kwargs):
+        # Shared weights, defaults to False
+        self.shared_weights = config.get("shared_weights", False)
+        # Reload checkpoint, defaults to False
+        self.reload_ckpt = config.get("reload_ckpt", False)
+        # MLP features: number of layers, number of hidden units, tail, etc.
+        self.n_layers = config.get("n_layers", 2)
+        self.n_hid = config.get("n_hid", 128)
+        self.tail = config.get("tail", [])
+        # Base init
         super().__init__(**kwargs)

     def make_mlp(self, activation: nn.Module):
         """
         Defines an MLP with no top layer activation

-        If config.share_weights is True, the base model with which weights are to be
+        If self.shared_weights is True, the base model with which weights are to be
         shared must be provided.

         Parameters
@@ -53,17 +62,6 @@ def make_mlp(self, activation: nn.Module):
                 "Base Model must be provided when shared_weights is set to True"
             )

-    def parse_config(self, config):
-        super().parse_config(config)
-        if config is None:
-            config = OmegaConf.create()
-        self.checkpoint = config.get("checkpoint", None)
-        self.shared_weights = config.get("shared_weights", False)
-        self.n_hid = config.get("n_hid", 128)
-        self.n_layers = config.get("n_layers", 2)
-        self.tail = config.get("tail", [])
-        self.reload_ckpt = config.get("reload_ckpt", False)
-
     def instantiate(self):
         self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
         self.is_model = True
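The config.get(key, default) calls moved into __init__ above behave identically on a plain dict and on an OmegaConf DictConfig, which is what makes the None-to-empty-config fallback sufficient for applying defaults. A quick sketch of that behaviour:

    from omegaconf import OmegaConf

    config = OmegaConf.create()  # empty config, as created when config is None
    assert config.get("type", "uniform") == "uniform"
    assert config.get("checkpoint", None) is None

    config = OmegaConf.create({"type": "fixed"})
    assert config.get("type", "uniform") == "fixed"
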
+ """ if self.type == "fixed": - self.model = self.fixed_distribution - self.is_model = False + return self.fixed_distribution, False elif self.type == "uniform": - self.model = self.uniform_distribution - self.is_model = False + return self.uniform_distribution, False else: raise "Policy model type not defined" diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py index 7d80d127..1a7d191b 100644 --- a/gflownet/policy/cnn.py +++ b/gflownet/policy/cnn.py @@ -22,9 +22,16 @@ def __init__(self, **kwargs): # Base init super().__init__(**kwargs) - def make_cnn(self): + def make_model(self): """ - Defines an CNN with no top layer activation + Instantiates a CNN with no top layer activation. + + Returns + ------- + model : torch.nn.Module + A torch model containing the CNN. + is_model : bool + True because a CNN is a model. """ if self.shared_weights and self.base is not None: layers = list(self.base.model.children())[:-1] @@ -33,14 +40,15 @@ def make_cnn(self): ) model = nn.Sequential(*layers, last_layer).to(self.device) - return model + return model, True current_channels = 1 conv_module = nn.Sequential() if len(self.kernel_sizes) != self.n_layers: raise ValueError( - f"Inconsistent dimensions kernel_sizes != n_layers, {len(self.kernel_sizes)} != {self.n_layers}" + f"Inconsistent dimensions kernel_sizes != n_layers, " + "{len(self.kernel_sizes)} != {self.n_layers}" ) for i in range(self.n_layers): @@ -65,21 +73,19 @@ def make_cnn(self): in_channels = conv_module(dummy_input).numel() if in_channels >= 500_000: # TODO: this could better be handled raise RuntimeWarning( - "Input channels for the dense layer are too big, this will increase number of parameters" + "Input channels for the dense layer are too big, this will " + "increase number of parameters" ) except RuntimeError as e: raise RuntimeError( - "Failed during convolution operation. Ensure that the kernel sizes and strides are appropriate for the input dimensions." + "Failed during convolution operation. Ensure that the kernel sizes " + "and strides are appropriate for the input dimensions." ) from e model = nn.Sequential( conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim) ) - return model.to(self.device) - - def instantiate(self): - self.model = self.make_cnn() - self.is_model = True + return model.to(self.device), True def __call__(self, states): states = states.unsqueeze(1) # (batch_size, channels, height, width) diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index bf012e1b..534ba369 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -17,9 +17,9 @@ def __init__(self, **kwargs): # Base init super().__init__(**kwargs) - def make_mlp(self, activation: nn.Module): + def make_model(self, activation: nn.Module = nn.LeakyReLU()): """ - Defines an MLP with no top layer activation + Instantiates an MLP with no top layer activation as the policy model. If self.shared_weights is True, the base model with which weights are to be shared must be provided. @@ -28,7 +28,16 @@ def make_mlp(self, activation: nn.Module): ---------- activation : nn.Module Activation function of the MLP layers + + Returns + ------- + model : torch.tensor or torch.nn.Module + A torch model containing the MLP. + is_model : bool + True because an MLP is a model. 
""" + activation.to(self.device) + if self.shared_weights == True and self.base is not None: mlp = nn.Sequential( self.base.model[:-1], @@ -36,7 +45,7 @@ def make_mlp(self, activation: nn.Module): self.base.model[-1].in_features, self.base.model[-1].out_features ), ) - return mlp + return mlp, True elif self.shared_weights == False: layers_dim = ( [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] @@ -56,15 +65,11 @@ def make_mlp(self, activation: nn.Module): + self.tail ) ) - return mlp + return mlp, True else: raise ValueError( "Base Model must be provided when shared_weights is set to True" ) - def instantiate(self): - self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) - self.is_model = True - def __call__(self, states): return self.model(states) From 774c411a69fcc733b94b7cae2e1a29669cd5640d Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 20:02:45 +0200 Subject: [PATCH 6/8] Missing import --- gflownet/policy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index fb5e8615..81f0d4b3 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -2,7 +2,7 @@ Base Policy class for GFlowNet policy models. """ -from typing import Union +from typing import Tuple, Union import torch from omegaconf import OmegaConf From 9520315e1609b77cd333ea59a1977a7601f3e16d Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 10 Jul 2024 20:13:56 +0200 Subject: [PATCH 7/8] Fix config issue by implementing _get_config() --- gflownet/policy/base.py | 24 +++++++++++++++++++++--- gflownet/policy/cnn.py | 1 + gflownet/policy/mlp.py | 1 + 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 81f0d4b3..73329382 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -36,9 +36,7 @@ def __init__( float_precision : int or torch.dtype The floating point precision to be passed to torch tensors. """ - # If config is None, instantiate an empty config (defaults will be used) - if config is None: - config = OmegaConf.create() + config = self._get_config(config) # Device and float precision self.device = set_device(device) self.float = set_float_precision(float_precision) @@ -56,6 +54,26 @@ def __init__( # Instantiate the model self.model, self.is_model = self.make_model() + @staticmethod + def _get_config(config: Union[dict, DictConfig]) -> Union[dict, DictConfig]: + """ + Returns a configuration dictionary, even if the input is None. + + Parameters + ---------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. It may be None, in + which an empty config is created and the defaults will be used. + + Returns + ------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. + """ + if config is None: + config = OmegaConf.create() + return config + def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]: """ Instantiates the model of the policy. 
From 910a9480f69730b392d1edc9f0483dc4d03936bf Mon Sep 17 00:00:00 2001
From: Alex
Date: Wed, 10 Jul 2024 20:22:45 +0200
Subject: [PATCH 8/8] Docstring for base argument

---
 gflownet/policy/base.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py
index 73329382..5d26e0c1 100644
--- a/gflownet/policy/base.py
+++ b/gflownet/policy/base.py
@@ -35,6 +35,8 @@ def __init__(
             The device to be passed to torch tensors.
         float_precision : int or torch.dtype
             The floating point precision to be passed to torch tensors.
+        base : Policy (optional)
+            A base policy to be used as backbone for the backward policy.
         """
         config = self._get_config(config)
         # Device and float precision
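Taken together with the shared-weights mechanism, the base argument documented in PATCH 8/8 is what enables a backward policy to reuse the forward policy's trunk. An end-to-end sketch, assuming a GFlowNetEnv instance `env` and illustrative config values (not part of the series):

    from gflownet.policy.mlp import MLPPolicy

    # Forward policy: a plain MLP.
    forward_policy = MLPPolicy(
        config={"n_layers": 2, "n_hid": 128},
        env=env,  # a GFlowNetEnv instance, defined elsewhere
        device="cpu",
        float_precision=32,
    )

    # Backward policy: reuses the forward policy's trunk via base=.
    backward_policy = MLPPolicy(
        config={"shared_weights": True},
        env=env,
        device="cpu",
        float_precision=32,
        base=forward_policy,
    )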