Merge pull request #335 from alexhernandezgarcia/ahg/293-flexible-policy-definition

[WIP, Policy] Docstring and refactoring on top of PR 327
josephdviviano authored Sep 18, 2024
2 parents 7e02200 + 910a948 commit c9ec03f
Showing 3 changed files with 134 additions and 79 deletions.
82 changes: 69 additions & 13 deletions gflownet/policy/base.py
@@ -1,13 +1,44 @@
"""
Base Policy class for GFlowNet policy models.
"""

from typing import Tuple, Union

import torch
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import set_device, set_float_precision


class Policy:
    def __init__(
        self,
        config: Union[dict, DictConfig],
        env: GFlowNetEnv,
        device: Union[str, torch.device],
        float_precision: Union[int, torch.dtype],
        base=None,
    ):
        """
        Base Policy class for a :class:`GFlowNetAgent`.

        Parameters
        ----------
        config : dict or DictConfig
            The configuration dictionary to set up the policy model.
        env : GFlowNetEnv
            The environment used to train the :class:`GFlowNetAgent`, used to
            extract needed properties.
        device : str or torch.device
            The device to be passed to torch tensors.
        float_precision : int or torch.dtype
            The floating point precision to be passed to torch tensors.
        base : Policy (optional)
            A base policy to be used as backbone for the backward policy.
        """
        config = self._get_config(config)
        # Device and float precision
        self.device = set_device(device)
        self.float = set_float_precision(float_precision)
@@ -18,24 +49,49 @@ def __init__(self, config, env, device, float_precision, base=None):
        self.output_dim = len(self.fixed_output)
        # Optional base model
        self.base = base
        # Policy type, defaults to uniform
        self.type = config.get("type", "uniform")
        # Checkpoint, defaults to None
        self.checkpoint = config.get("checkpoint", None)
        # Instantiate the model
        self.model, self.is_model = self.make_model()

    @staticmethod
    def _get_config(config: Union[dict, DictConfig]) -> Union[dict, DictConfig]:
        """
        Returns a configuration dictionary, even if the input is None.

        Parameters
        ----------
        config : dict or DictConfig
            The configuration dictionary to set up the policy model. It may be
            None, in which case an empty config is created and the defaults
            will be used.

        Returns
        -------
        config : dict or DictConfig
            The configuration dictionary to set up the policy model.
        """
        if config is None:
            config = OmegaConf.create()
        return config

    def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]:
        """
        Instantiates the model of the policy.

        Returns
        -------
        model : torch.tensor or torch.nn.Module
            A tensor representing the output of the policy or a torch model.
        is_model : bool
            True if the policy is a model (for example, a neural network) and
            False if it is a fixed tensor (for example, to make a uniform
            distribution).
        """
        if self.type == "fixed":
            return self.fixed_distribution, False
        elif self.type == "uniform":
            return self.uniform_distribution, False
        else:
            raise ValueError("Policy model type not defined")
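For context, the refactored flow can be sketched as follows. This is illustrative only, not part of the diff; it assumes the repository's Grid environment (gflownet.envs.grid.Grid) builds with default arguments. A None config is normalized by _get_config, and make_model returns a (model, is_model) pair that __init__ unpacks.

# Illustrative sketch only, not from the commit. Assumes the repository's
# Grid environment can be constructed with its defaults.
from omegaconf import OmegaConf

from gflownet.envs.grid import Grid
from gflownet.policy.base import Policy

env = Grid()

# A None config is replaced by an empty one in _get_config, so the policy
# defaults to type "uniform": make_model returns the uniform distribution
# callable and is_model is False.
policy = Policy(config=None, env=env, device="cpu", float_precision=32)
assert policy.type == "uniform"
assert policy.is_model is False

# An explicit "fixed" type makes make_model return the fixed distribution
# instead; any other type raises a ValueError.
fixed = Policy(
    config=OmegaConf.create({"type": "fixed"}),
    env=env,
    device="cpu",
    float_precision=32,
)
assert fixed.is_model is False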
63 changes: 32 additions & 31 deletions gflownet/policy/cnn.py
@@ -6,19 +6,33 @@


class CNNPolicy(Policy):
    def __init__(self, **kwargs):
        config = self._get_config(kwargs["config"])
        # Shared weights, defaults to False
        self.shared_weights = config.get("shared_weights", False)
        # Reload checkpoint, defaults to False
        self.reload_ckpt = config.get("reload_ckpt", False)
        # CNN features: number of layers, number of channels, kernel sizes, strides
        self.n_layers = config.get("n_layers", 3)
        self.channels = config.get("channels", [16] * self.n_layers)
        self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
        self.strides = config.get("strides", [(1, 1)] * self.n_layers)
        # Environment
        # TODO: rethink whether storing the whole environment is needed
        self.env = kwargs["env"]
        # Base init
        super().__init__(**kwargs)

    def make_model(self):
        """
        Instantiates a CNN with no top layer activation.

        Returns
        -------
        model : torch.nn.Module
            A torch model containing the CNN.
        is_model : bool
            True because a CNN is a model.
        """
        if self.shared_weights and self.base is not None:
            layers = list(self.base.model.children())[:-1]
@@ -27,14 +41,15 @@ def make_cnn(self):
            )

            model = nn.Sequential(*layers, last_layer).to(self.device)
            return model, True

        current_channels = 1
        conv_module = nn.Sequential()

        if len(self.kernel_sizes) != self.n_layers:
            raise ValueError(
                "Inconsistent dimensions kernel_sizes != n_layers, "
                f"{len(self.kernel_sizes)} != {self.n_layers}"
            )

        for i in range(self.n_layers):
@@ -59,33 +74,19 @@ def make_cnn(self):
            in_channels = conv_module(dummy_input).numel()
            if in_channels >= 500_000:  # TODO: this could better be handled
                raise RuntimeWarning(
                    "Input channels for the dense layer are too big, this "
                    "will increase the number of parameters"
                )
        except RuntimeError as e:
            raise RuntimeError(
                "Failed during convolution operation. Ensure that the kernel "
                "sizes and strides are appropriate for the input dimensions."
            ) from e

        model = nn.Sequential(
            conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim)
        )
        return model.to(self.device), True

    def __call__(self, states):
        states = states.unsqueeze(1)  # (batch_size, channels, height, width)
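As a hedged usage sketch (hypothetical values, not from the commit), a config overriding the CNN defaults parsed above might look like this; note that only kernel_sizes is explicitly validated against n_layers in make_model.

# Hypothetical configuration for CNNPolicy, not from the commit. The keys
# mirror the defaults parsed in __init__ above.
from omegaconf import OmegaConf

cnn_config = OmegaConf.create(
    {
        "n_layers": 2,
        "channels": [8, 16],  # one entry per layer
        "kernel_sizes": [[3, 3], [2, 2]],  # must have n_layers entries
        "strides": [[1, 1], [1, 1]],
    }
)
# CNNPolicy(config=cnn_config, env=env, device="cpu", float_precision=32)
# would stack two conv blocks, flatten, and append a Linear layer mapping
# to the environment's policy output dimension.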
68 changes: 33 additions & 35 deletions gflownet/policy/mlp.py
@@ -5,35 +5,48 @@


class MLPPolicy(Policy):
    def __init__(self, **kwargs):
        config = self._get_config(kwargs["config"])
        # Shared weights, defaults to False
        self.shared_weights = config.get("shared_weights", False)
        # Reload checkpoint, defaults to False
        self.reload_ckpt = config.get("reload_ckpt", False)
        # MLP features: number of layers, number of hidden units, tail, etc.
        self.n_layers = config.get("n_layers", 2)
        self.n_hid = config.get("n_hid", 128)
        self.tail = config.get("tail", [])
        # Base init
        super().__init__(**kwargs)

    def make_model(self, activation: nn.Module = nn.LeakyReLU()):
        """
        Instantiates an MLP with no top layer activation as the policy model.

        If self.shared_weights is True, the base model with which weights are
        to be shared must be provided.

        Parameters
        ----------
        activation : nn.Module
            Activation function of the MLP layers.

        Returns
        -------
        model : torch.tensor or torch.nn.Module
            A torch model containing the MLP.
        is_model : bool
            True because an MLP is a model.
        """
        activation.to(self.device)

        if self.shared_weights and self.base is not None:
            mlp = nn.Sequential(
                self.base.model[:-1],
                nn.Linear(
                    self.base.model[-1].in_features, self.base.model[-1].out_features
                ),
            )
            return mlp, True
        elif not self.shared_weights:
            layers_dim = (
                [self.state_dim] + [self.n_hid] * self.n_layers + [self.output_dim]
@@ -53,26 +66,11 @@ def make_mlp(self, activation):
                    + self.tail
                )
            )
            return mlp, True
        else:
            raise ValueError(
                "Base model must be provided when shared_weights is set to True"
            )

    def __call__(self, states):
        return self.model(states)

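To illustrate the shared_weights path (a sketch under the same Grid-environment assumption as above, not part of the diff): the base argument together with shared_weights lets a backward policy reuse the trunk of the forward policy, sharing all but a fresh top layer.

# Illustrative sketch of weight sharing, not from the commit. Assumes the
# repository's Grid environment builds with defaults.
from omegaconf import OmegaConf

from gflownet.envs.grid import Grid
from gflownet.policy.mlp import MLPPolicy

env = Grid()

forward = MLPPolicy(
    config=OmegaConf.create({"n_layers": 2, "n_hid": 64}),
    env=env,
    device="cpu",
    float_precision=32,
)
# With shared_weights=True and base=forward, make_model reuses all layers
# of the forward policy except a freshly initialised top Linear layer.
backward = MLPPolicy(
    config=OmegaConf.create({"shared_weights": True}),
    env=env,
    device="cpu",
    float_precision=32,
    base=forward,
)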