
[WIP] Refactor of Policy: get rid of config parameter and set arguments directly; isolate random, uniform and fixed policies; docstring, etc.

alexhernandezgarcia committed Jul 10, 2024
1 parent 910a948 commit d85ad5f
Showing 10 changed files with 176 additions and 120 deletions.
2 changes: 1 addition & 1 deletion config/main.yaml
@@ -2,7 +2,7 @@ defaults:
- _self_
- env: grid
- gflownet: trajectorybalance
- policy: mlp_${gflownet}
- policy: mlp
- proxy: uniform
- logger: wandb
- user: default
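For context, a minimal sketch of how the simplified policy group is typically consumed downstream; the instantiate call, the cfg object and the runtime arguments are assumptions for illustration, not part of this diff:

from hydra.utils import instantiate

# cfg.policy now resolves to config/policy/mlp.yaml, whose _target_ points
# at gflownet.policy.mlp.MLPPolicy; runtime-only arguments are passed here.
forward_policy = instantiate(
    cfg.policy,
    env=env,  # assumed: a GFlowNetEnv instance available at runtime
    device=cfg.device,
    float_precision=cfg.float_precision,
)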
5 changes: 5 additions & 0 deletions config/policy/base.yaml
@@ -0,0 +1,5 @@
_target_: gflownet.policy.base.Policy

base: null
shared_weights: False
checkpoint: null
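These three keys mirror the new keyword arguments of Policy.__init__ introduced by this commit. As a sketch (env, device and precision values are placeholders), a concrete subclass such as MLPPolicy can be constructed directly with them:

from gflownet.policy.mlp import MLPPolicy

policy = MLPPolicy(
    env=env,               # a GFlowNetEnv instance
    device="cpu",
    float_precision=32,
    base=None,             # base.yaml: base
    shared_weights=False,  # base.yaml: shared_weights
    checkpoint=None,       # base.yaml: checkpoint
)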
23 changes: 9 additions & 14 deletions config/policy/cnn.yaml
@@ -1,16 +1,11 @@
_target_: gflownet.policy.cnn.CNNPolicy

shared: null
defaults:
- base

forward:
n_layers: 2
channels: [16, 32]
kernel_sizes: [[3, 3], [2, 2]] # Each tuple represents (height, width)
strides: [[1, 1], [1, 1]] # Each tuple represents (vertical_stride, horizontal_stride)
checkpoint: null
reload_ckpt: False
_target_: gflownet.policy.cnn.CNNPolicy

backward:
shared_weights: True
checkpoint: null
reload_ckpt: False
n_layers: 2
channels: [16, 32]
# Kernels: Each tuple represents (height, width)
kernels: [[3, 3], [2, 2]]
# Strides: Each tuple represents (vertical_stride, horizontal_stride)
strides: [[1, 1], [1, 1]]
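Assuming standard Hydra semantics for the new defaults entry, composing this file with config/policy/base.yaml should yield roughly the following merged config, sketched here with OmegaConf for illustration:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "gflownet.policy.cnn.CNNPolicy",
        # Inherited from config/policy/base.yaml
        "base": None,
        "shared_weights": False,
        "checkpoint": None,
        # Defined in this file
        "n_layers": 2,
        "channels": [16, 32],
        "kernels": [[3, 3], [2, 2]],
        "strides": [[1, 1], [1, 1]],
    }
)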
15 changes: 6 additions & 9 deletions config/policy/mlp.yaml
@@ -1,11 +1,8 @@
_target_: gflownet.policy.mlp.MLPPolicy

shared: null
defaults:
- base

forward:
n_hid: 128
n_layers: 2
checkpoint: null
reload_ckpt: False
_target_: gflownet.policy.mlp.MLPPolicy

backward: null
n_hid: 128
n_layers: 2
tail: []
88 changes: 18 additions & 70 deletions gflownet/policy/base.py
@@ -15,19 +15,18 @@
class Policy:
def __init__(
self,
config: Union[dict, DictConfig],
env: GFlowNetEnv,
device: Union[str, torch.device],
float_precision: Union[int, torch.dtype],
base=None,
shared_weights: bool = False,
checkpoint: str = None,
):
"""
Base Policy class for a :class:`GFlowNetAgent`.
Parameters
----------
config : dict or DictConfig
The configuration dictionary to set up the policy model.
env : GFlowNetEnv
The environment used to train the :class:`GFlowNetAgent`, used to extract
needed properties.
@@ -37,8 +36,12 @@ def __init__(
The floating point precision to be passed to torch tensors.
base: Policy (optional)
A base policy to be used as backbone for the backward policy.
shared_weights: bool (optional)
Whether the weights of the backward policy are shared with the (base)
forward policy model. Defaults to False.
checkpoint: str (optional)
The path to a checkpoint file from which the policy model can be reloaded.
"""
config = self._get_config(config)
# Device and float precision
self.device = set_device(device)
self.float = set_float_precision(float_precision)
@@ -49,36 +52,17 @@ def __init__(
self.output_dim = len(self.fixed_output)
# Optional base model
self.base = base
# Policy type, defaults to uniform
self.type = config.get("type", "uniform")
# Shared weights, defaults to False
self.shared_weights = shared_weights
# Checkpoint, defaults to None
self.checkpoint = config.get("checkpoint", None)
self.checkpoint = checkpoint
# Instantiate the model
self.model, self.is_model = self.make_model()

@staticmethod
def _get_config(config: Union[dict, DictConfig]) -> Union[dict, DictConfig]:
"""
Returns a configuration dictionary, even if the input is None.
Parameters
----------
config : dict or DictConfig
The configuration dictionary to set up the policy model. It may be None, in
which case an empty config is created and the defaults are used.
Returns
-------
config : dict or DictConfig
The configuration dictionary to set up the policy model.
"""
if config is None:
config = OmegaConf.create()
return config

@abstractmethod
def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]:
"""
Instantiates the model of the policy.
Instantiates the model or fixed tensor of the policy.
Returns
-------
@@ -88,51 +72,15 @@ def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]:
True if the policy is a model (for example, a neural network) and False if
it is a fixed tensor (for example to make a uniform distribution).
"""
if self.type == "fixed":
return self.fixed_distribution, False
elif self.type == "uniform":
return self.uniform_distribution, False
else:
raise NotImplementedError("Policy model type not defined")

def __call__(self, states):
return self.model(states)

def fixed_distribution(self, states):
"""
Returns the fixed distribution specified by the environment.
Parameters
----------
states : tensor
The states for which the fixed distribution is to be returned
"""
return torch.tile(self.fixed_output, (len(states), 1)).to(
dtype=self.float, device=self.device
)

def random_distribution(self, states):
"""
Returns the random distribution specified by the environment.
Parameters
----------
states : tensor
The states for which the random distribution is to be returned
"""
return torch.tile(self.random_output, (len(states), 1)).to(
dtype=self.float, device=self.device
)
pass

def uniform_distribution(self, states):
def __call__(self, states: torch.Tensor):
"""
Return action logits (log probabilities) from a uniform distribution
Returns the outputs of the policy model on a batch of states.
Parameters
----------
states : tensor
The states for which the uniform distribution is to be returned
states : torch.Tensor
A batch of states in policy format.
"""
return torch.ones(
(len(states), self.output_dim), dtype=self.float, device=self.device
)
return self.model(states)
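A minimal usage sketch of the slimmed-down base class: after make_model(), self.model is either a torch module or a fixed callable, and __call__ applies it to a batch of states in policy format (the conversion helper below is an assumed environment method):

states_policy = env.states2policy(states)  # assumption: env converts states to policy format
outputs = policy(states_policy)  # shape: (batch_size, policy.output_dim)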
58 changes: 44 additions & 14 deletions gflownet/policy/cnn.py
@@ -6,17 +6,47 @@


class CNNPolicy(Policy):
def __init__(self, **kwargs):
config = self._get_config(kwargs["config"])
# Shared weights, defaults to False
self.shared_weights = config.get("shared_weights", False)
# Reload checkpoint, defaults to False
self.reload_ckpt = config.get("reload_ckpt", False)
def __init__(
self,
n_layers: int = 2,
channels: Union[int, List] = [16, 32],
kernels: Union[int, List] = [[3, 3], [2, 2]],
strides: Union[int, List] = [[1, 1], [1, 1]],
**kwargs,
):
"""
CNN Policy class for a :class:`GFlowNetAgent`.
Parameters
----------
n_layers : int
The number of layers in the CNN architecture.
channels : int or list
The number of channels in the convolutional layers or a list of number of
channels for each layer.
kernels : int or list
The kernel size of the convolutions or a list of kernel sizes for each
layer.
strides : int or list
The stride of the convolutions or a list of strides for each layer.
"""
# CNN features: number of layers, number of channels, kernel sizes, strides
self.n_layers = config.get("n_layers", 3)
self.channels = config.get("channels", [16] * self.n_layers)
self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
self.strides = config.get("strides", [(1, 1)] * self.n_layers)
self.n_layers = n_layers
if isinstance(channels, int):
self.channels = [channels] * self.n_layers
else:
# TODO: check if valid
self.channels = channels
if isinstance(kernels, int):
self.kernels = [(kernels, kernels)] * self.n_layers
else:
# TODO: check if valid
self.kernels = kernels
if isinstance(strides, int):
self.strides = [(strides, strides)] * self.n_layers
else:
# TODO: check if valid
self.strides = strides
# Environment
# TODO: rethink whether storing the whole environment is needed
self.env = kwargs["env"]  # env is expected to be passed via **kwargs
@@ -46,10 +76,10 @@ def make_model(self):
current_channels = 1
conv_module = nn.Sequential()

if len(self.kernel_sizes) != self.n_layers:
if len(self.kernels) != self.n_layers:
raise ValueError(
f"Inconsistent dimensions kernel_sizes != n_layers, "
"{len(self.kernel_sizes)} != {self.n_layers}"
f"Inconsistent dimensions kernels != n_layers, "
f"{len(self.kernels)} != {self.n_layers}"
)

for i in range(self.n_layers):
Expand All @@ -58,7 +88,7 @@ def make_model(self):
nn.Conv2d(
in_channels=current_channels,
out_channels=self.channels[i],
kernel_size=tuple(self.kernel_sizes[i]),
kernel_size=tuple(self.kernels[i]),
stride=tuple(self.strides[i]),
padding=0,
padding_mode="zeros", # Constant zero padding
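The new constructor also accepts scalars for channels, kernels and strides, which are broadcast to one entry per layer. A sketch of that behavior (env and the runtime arguments are assumed placeholders):

from gflownet.policy.cnn import CNNPolicy

policy = CNNPolicy(
    env=env,
    device="cpu",
    float_precision=32,
    n_layers=3,
    channels=16,
    kernels=3,
    strides=1,
)
assert policy.channels == [16, 16, 16]
assert policy.kernels == [(3, 3), (3, 3), (3, 3)]
assert policy.strides == [(1, 1), (1, 1), (1, 1)]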
25 changes: 25 additions & 0 deletions gflownet/policy/fixed.py
@@ -0,0 +1,25 @@
import torch

from gflownet.policy.base import Policy


class FixedPolicy(Policy):
def __init__(self, **kwargs):
super().__init__(**kwargs)

    def make_model(self):
        """
        Instantiates the policy "model" as a callable that tiles the values of
        `self.fixed_output`, defined by the environment, over a batch of
        states.
        Returns
        -------
        model : callable
            A function that returns `self.fixed_output` tiled over the batch
            of states.
        is_model : bool
            False, because the policy is not a learnable model.
        """

        def fixed_output_batch(states):
            return torch.tile(self.fixed_output, (len(states), 1)).to(
                dtype=self.float, device=self.device
            )

        return fixed_output_batch, False
30 changes: 18 additions & 12 deletions gflownet/policy/mlp.py
@@ -1,20 +1,29 @@
from typing import List

from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class MLPPolicy(Policy):
def __init__(self, **kwargs):
config = self._get_config(kwargs["config"])
# Shared weights, defaults to False
self.shared_weights = config.get("shared_weights", False)
# Reload checkpoint, defaults to False
self.reload_ckpt = config.get("reload_ckpt", False)
def __init__(self, n_layers: int = 2, n_hid: int = 128, tail: List = [], **kwargs):
"""
MLP Policy class for a :class:`GFlowNetAgent`.
Parameters
----------
n_layers : int
The number of layers in the MLP architecture.
n_hid : int
The number of hidden units per layer.
tail : list
A list of layers that form the top (tail) of the MLP architecture.
"""
# MLP features: number of layers, number of hidden units, tail, etc.
self.n_layers = config.get("n_layers", 2)
self.n_hid = config.get("n_hid", 128)
self.tail = config.get("tail", [])
self.n_layers = n_layers
self.n_hid = n_hid
self.tail = tail
# Base init
super().__init__(**kwargs)

@@ -71,6 +80,3 @@ def make_model(self, activation: nn.Module = nn.LeakyReLU()):
raise ValueError(
"Base Model must be provided when shared_weights is set to True"
)

def __call__(self, states):
return self.model(states)
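For the shared-weights case guarded by the ValueError above, a hedged wiring sketch of how base and shared_weights are meant to interact (argument values are placeholders):

forward_policy = MLPPolicy(env=env, device="cpu", float_precision=32)
backward_policy = MLPPolicy(
    env=env,
    device="cpu",
    float_precision=32,
    shared_weights=True,
    base=forward_policy,  # backbone whose weights are reused
)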
25 changes: 25 additions & 0 deletions gflownet/policy/random.py
@@ -0,0 +1,25 @@
import torch

from gflownet.policy.base import Policy


class RandomPolicy(Policy):
def __init__(self, **kwargs):
super().__init__(**kwargs)

    def make_model(self):
        """
        Instantiates the policy "model" as a callable that tiles the values of
        `self.random_output`, defined by the environment, over a batch of
        states.
        Returns
        -------
        model : callable
            A function that returns `self.random_output` tiled over the batch
            of states.
        is_model : bool
            False, because the policy is not a learnable model.
        """

        def random_output_batch(states):
            return torch.tile(self.random_output, (len(states), 1)).to(
                dtype=self.float, device=self.device
            )

        return random_output_batch, False
25 changes: 25 additions & 0 deletions gflownet/policy/uniform.py
@@ -0,0 +1,25 @@
import torch

from gflownet.policy.base import Policy


class UniformPolicy(Policy):
def __init__(self, **kwargs):
super().__init__(**kwargs)

    def make_model(self):
        """
        Instantiates the policy "model" as a callable that returns a tensor of
        ones, to define a uniform distribution over the action space.
        Returns
        -------
        model : callable
            A function that returns a tensor of ones with shape
            (batch_size, self.output_dim).
        is_model : bool
            False, because the policy is not a learnable model.
        """

        def uniform_output_batch(states):
            return torch.ones(
                (len(states), self.output_dim), dtype=self.float, device=self.device
            )

        return uniform_output_batch, False
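After this refactor, selecting one of the non-trainable policies is a matter of choosing a class rather than setting the old type key in the config. A usage sketch (runtime arguments and the states batch are placeholders):

from gflownet.policy.uniform import UniformPolicy

policy = UniformPolicy(env=env, device="cpu", float_precision=32)
outputs = policy(states_policy)  # all ones, shape (batch_size, policy.output_dim)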
