diff --git a/config/main.yaml b/config/main.yaml
index 22c106d3..6aa52c09 100644
--- a/config/main.yaml
+++ b/config/main.yaml
@@ -2,7 +2,7 @@ defaults:
   - _self_
   - env: grid
   - gflownet: trajectorybalance
-  - policy: mlp_${gflownet}
+  - policy: mlp
   - proxy: uniform
   - logger: wandb
   - user: default
diff --git a/config/policy/base.yaml b/config/policy/base.yaml
new file mode 100644
index 00000000..c2b746f1
--- /dev/null
+++ b/config/policy/base.yaml
@@ -0,0 +1,5 @@
+_target_: gflownet.policy.base.Policy
+
+base: null
+shared_weights: False
+checkpoint: null
diff --git a/config/policy/cnn.yaml b/config/policy/cnn.yaml
index 98818bd8..6874e484 100644
--- a/config/policy/cnn.yaml
+++ b/config/policy/cnn.yaml
@@ -1,16 +1,11 @@
-_target_: gflownet.policy.cnn.CNNPolicy
-
-shared: null
+defaults:
+  - base
 
-forward:
-  n_layers: 2
-  channels: [16, 32]
-  kernel_sizes: [[3, 3], [2, 2]] # Each tuple represents (height, width)
-  strides: [[1, 1], [1, 1]] # Each tuple represents (vertical_stride, horizontal_stride)
-  checkpoint: null
-  reload_ckpt: False
+_target_: gflownet.policy.cnn.CNNPolicy
 
-backward:
-  shared_weights: True
-  checkpoint: null
-  reload_ckpt: False
+n_layers: 2
+channels: [16, 32]
+# Kernels: Each tuple represents (height, width)
+kernels: [[3, 3], [2, 2]]
+# Strides: Each tuple represents (vertical_stride, horizontal_stride)
+strides: [[1, 1], [1, 1]]
diff --git a/config/policy/mlp.yaml b/config/policy/mlp.yaml
index b9e2f9a6..b9ecc99f 100644
--- a/config/policy/mlp.yaml
+++ b/config/policy/mlp.yaml
@@ -1,11 +1,8 @@
-_target_: gflownet.policy.mlp.MLPPolicy
-
-shared: null
+defaults:
+  - base
 
-forward:
-  n_hid: 128
-  n_layers: 2
-  checkpoint: null
-  reload_ckpt: False
+_target_: gflownet.policy.mlp.MLPPolicy
 
-backward: null
+n_hid: 128
+n_layers: 2
+tail: []
diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py
index 5d26e0c1..a6ea4971 100644
--- a/gflownet/policy/base.py
+++ b/gflownet/policy/base.py
@@ -15,19 +15,18 @@
 class Policy:
     def __init__(
         self,
-        config: Union[dict, DictConfig],
         env: GFlowNetEnv,
         device: Union[str, torch.device],
         float_precision: [int, torch.dtype],
         base=None,
+        shared_weights: bool = False,
+        checkpoint: str = None,
     ):
         """
         Base Policy class for a :class:`GFlowNetAgent`.
 
         Parameters
         ----------
-        config : dict or DictConfig
-            The configuration dictionary to set up the policy model.
         env : GFlowNetEnv
             The environment used to train the :class:`GFlowNetAgent`, used to extract
             needed properties.
@@ -37,8 +36,12 @@
             The floating point precision to be passed to torch tensors.
         base: Policy (optional)
             A base policy to be used as backbone for the backward policy.
+        shared_weights: bool (optional)
+            Whether the weights of the backward policy are shared with the (base)
+            forward policy model. Defaults to False.
+        checkpoint: str (optional)
+            The path to a checkpoint file to be reloaded as the policy model.
""" - config = self._get_config(config) # Device and float precision self.device = set_device(device) self.float = set_float_precision(float_precision) @@ -49,36 +52,17 @@ def __init__( self.output_dim = len(self.fixed_output) # Optional base model self.base = base - # Policy type, defaults to uniform - self.type = config.get("type", "uniform") + # Shared weights, defaults to False + self.shared_weights = shared_weights # Checkpoint, defaults to None - self.checkpoint = config.get("checkpoint", None) + self.checkpoint = checkpoint # Instantiate the model self.model, self.is_model = self.make_model() - @staticmethod - def _get_config(config: Union[dict, DictConfig]) -> Union[dict, DictConfig]: - """ - Returns a configuration dictionary, even if the input is None. - - Parameters - ---------- - config : dict or DictConfig - The configuration dictionary to set up the policy model. It may be None, in - which an empty config is created and the defaults will be used. - - Returns - ------- - config : dict or DictConfig - The configuration dictionary to set up the policy model. - """ - if config is None: - config = OmegaConf.create() - return config - + @abstractmethod def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]: """ - Instantiates the model of the policy. + Instantiates the model or fixed tensor of the policy. Returns ------- @@ -88,51 +72,15 @@ def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]: True if the policy is a model (for example, a neural network) and False if it is a fixed tensor (for example to make a uniform distribution). """ - if self.type == "fixed": - return self.fixed_distribution, False - elif self.type == "uniform": - return self.uniform_distribution, False - else: - raise "Policy model type not defined" - - def __call__(self, states): - return self.model(states) - - def fixed_distribution(self, states): - """ - Returns the fixed distribution specified by the environment. - - Parameters - ---------- - states : tensor - The states for which the fixed distribution is to be returned - """ - return torch.tile(self.fixed_output, (len(states), 1)).to( - dtype=self.float, device=self.device - ) - - def random_distribution(self, states): - """ - Returns the random distribution specified by the environment. - - Parameters - ---------- - states : tensor - The states for which the random distribution is to be returned - """ - return torch.tile(self.random_output, (len(states), 1)).to( - dtype=self.float, device=self.device - ) + pass - def uniform_distribution(self, states): + def __call__(self, states: torch.Tensor): """ - Return action logits (log probabilities) from a uniform distribution + Returns the outputs of the policy model on a batch of states. Parameters ---------- - states : tensor - The states for which the uniform distribution is to be returned + states : torch.Tensor + A batch of states in policy format. 
""" - return torch.ones( - (len(states), self.output_dim), dtype=self.float, device=self.device - ) + return self.model(states) diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py index 52693fc3..d9ce142e 100644 --- a/gflownet/policy/cnn.py +++ b/gflownet/policy/cnn.py @@ -6,17 +6,47 @@ class CNNPolicy(Policy): - def __init__(self, **kwargs): - config = self._get_config(kwargs["config"]) - # Shared weights, defaults to False - self.shared_weights = config.get("shared_weights", False) - # Reload checkpoint, defaults to False - self.reload_ckpt = config.get("reload_ckpt", False) + def __init__( + self, + n_layers: int = 2, + channels: Union[int, List] = [16, 32], + kernels: Union[int, List] = [[3, 3], [2, 2]], + strides: Union[int, List] = [[1, 1], [1, 1]], + **kwargs, + ): + """ + CNN Policy class for a :class:`GFlowNetAgent`. + + Parameters + ---------- + n_layers : int + The number of layers in the CNN architecture. + channels : int or list + The number of channels in the convolutional layers or a list of number of + channels for each layer. + kernels : int or list + The kernel size of the convolutions or a list of kernel sizes for each + layer. + strides : int or list + The stride of the convolutions or a list of strides for each layer. + """ # CNN features: number of layers, number of channels, kernel sizes, strides - self.n_layers = config.get("n_layers", 3) - self.channels = config.get("channels", [16] * self.n_layers) - self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers) - self.strides = config.get("strides", [(1, 1)] * self.n_layers) + self.n_layers = n_layers + if isinstance(channels, int): + self.channels = [channels] * self.n_layers + else: + # TODO: check if valid + self.channels = channels + if isinstance(kernels, int): + self.kernels = [(kernels, kernels)] * self.n_layers + else: + # TODO: check if valid + self.kernels = kernels + if isinstance(strides, int): + self.strides = [(stride, stride)] * self.n_layers + else: + # TODO: check if valid + self.strides = strides # Environment # TODO: rethink whether storing the whole environment is needed self.env = env @@ -46,10 +76,10 @@ def make_model(self): current_channels = 1 conv_module = nn.Sequential() - if len(self.kernel_sizes) != self.n_layers: + if len(self.kernels) != self.n_layers: raise ValueError( - f"Inconsistent dimensions kernel_sizes != n_layers, " - "{len(self.kernel_sizes)} != {self.n_layers}" + f"Inconsistent dimensions kernels != n_layers, " + "{len(self.kernels)} != {self.n_layers}" ) for i in range(self.n_layers): @@ -58,7 +88,7 @@ def make_model(self): nn.Conv2d( in_channels=current_channels, out_channels=self.channels[i], - kernel_size=tuple(self.kernel_sizes[i]), + kernel_size=tuple(self.kernels[i]), stride=tuple(self.strides[i]), padding=0, padding_mode="zeros", # Constant zero padding diff --git a/gflownet/policy/fixed.py b/gflownet/policy/fixed.py new file mode 100644 index 00000000..015aff5b --- /dev/null +++ b/gflownet/policy/fixed.py @@ -0,0 +1,25 @@ +from gflownet.policy.base import Policy + + +class FixedPolicy(Policy): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def make_model(self): + """ + Instantiates the policy model as a fixed tensor with the values of + `self.fixed_output` defined by the environment. + + Returns + ------- + model : torch.tensor + The tensor `self.fixed_output` defined by the environment. + is_model : bool + False, because the policy is not a model. 
+ """ + return ( + torch.tile(self.fixed_output, (len(states), 1)).to( + dtype=self.float, device=self.device + ), + False, + ) diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index 8f4fbc80..e94215c8 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -1,3 +1,5 @@ +from typing import List + from omegaconf import OmegaConf from torch import nn @@ -5,16 +7,23 @@ class MLPPolicy(Policy): - def __init__(self, **kwargs): - config = self._get_config(kwargs["config"]) - # Shared weights, defaults to False - self.shared_weights = config.get("shared_weights", False) - # Reload checkpoint, defaults to False - self.reload_ckpt = config.get("reload_ckpt", False) + def __init__(self, n_layers: int = 2, n_hid: int = 128, tail: List = [], **kwargs): + """ + MLP Policy class for a :class:`GFlowNetAgent`. + + Parameters + ---------- + n_layers : int + The number of layers in the MLP architecture. + n_hid : int + The number of hidden units per layer. + tail : list + A list of layers to conform the top (tail) of the MLP architecture. + """ # MLP features: number of layers, number of hidden units, tail, etc. - self.n_layers = config.get("n_layers", 2) - self.n_hid = config.get("n_hid", 128) - self.tail = config.get("tail", []) + self.n_layers = n_layers + self.n_hid = n_hid + self.tail = tail # Base init super().__init__(**kwargs) @@ -71,6 +80,3 @@ def make_model(self, activation: nn.Module = nn.LeakyReLU()): raise ValueError( "Base Model must be provided when shared_weights is set to True" ) - - def __call__(self, states): - return self.model(states) diff --git a/gflownet/policy/random.py b/gflownet/policy/random.py new file mode 100644 index 00000000..e87dcbc0 --- /dev/null +++ b/gflownet/policy/random.py @@ -0,0 +1,25 @@ +from gflownet.policy.base import Policy + + +class RandomPolicy(Policy): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def make_model(self): + """ + Instantiates the policy model as a fixed tensor with the values of + `self.random_output` defined by the environment. + + Returns + ------- + model : torch.tensor + The tensor `self.random_output` defined by the environment. + is_model : bool + False, because the policy is not a model. + """ + return ( + torch.tile(self.random_output, (len(states), 1)).to( + dtype=self.float, device=self.device + ), + False, + ) diff --git a/gflownet/policy/uniform.py b/gflownet/policy/uniform.py new file mode 100644 index 00000000..f901905b --- /dev/null +++ b/gflownet/policy/uniform.py @@ -0,0 +1,25 @@ +from gflownet.policy.base import Policy + + +class UniformPolicy(Policy): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def make_model(self): + """ + Instantiates the policy model as a fixed tensor of ones, to define a uniform + distribution over the action space. + + Returns + ------- + model : torch.tensor + A tensor of `self.output_dim` ones. + is_model : bool + False, because the policy is not a model. + """ + return ( + torch.ones( + (len(states), self.output_dim), dtype=self.float, device=self.device + ), + False, + )