From 1e9b41056f2a72296e69511d31d760ebb28e6900 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Wed, 2 Oct 2024 16:08:48 +0200 Subject: [PATCH 01/15] fix: import path for moved utils file --- retinal_rl/rl/analysis/statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/retinal_rl/rl/analysis/statistics.py b/retinal_rl/rl/analysis/statistics.py index 5f0f2b46..86bfec4e 100644 --- a/retinal_rl/rl/analysis/statistics.py +++ b/retinal_rl/rl/analysis/statistics.py @@ -11,7 +11,7 @@ from retinal_rl.models.brain import Brain from retinal_rl.models.circuits.convolutional import ConvolutionalEncoder -from retinal_rl.models.util import encoder_out_size, rf_size_and_start +from retinal_rl.util import encoder_out_size, rf_size_and_start from tqdm import tqdm import math From 8682dcde3809ba64dc67b14e7e3e6c13f72a3e11 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Wed, 2 Oct 2024 16:56:01 +0200 Subject: [PATCH 02/15] fix: adjust config paths and fix load brain methods --- retinal_rl/rl/sample_factory/sf_framework.py | 23 ++++++++++---------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/retinal_rl/rl/sample_factory/sf_framework.py b/retinal_rl/rl/sample_factory/sf_framework.py index 24a9ea81..79e55ff7 100644 --- a/retinal_rl/rl/sample_factory/sf_framework.py +++ b/retinal_rl/rl/sample_factory/sf_framework.py @@ -67,12 +67,13 @@ def load_brain_from_checkpoint( with open(os.path.join(path, "config.json")) as f: config = Namespace(**json.load(f)) checkpoint_dict, config = SFFramework.get_checkpoint(config) + config = DictConfig(config) model_dict: Dict[str, Any] = checkpoint_dict["model"] brain_dict: Dict[str, Any] = {} for key in model_dict.keys(): if "brain" in key: brain_dict[key[6:]] = model_dict[key] - brain = Brain(**config["brain"]) + brain = Brain(**config.brain) if load_weights: brain.load_state_dict(brain_dict) brain.to(device) @@ -83,34 +84,34 @@ def load_brain_and_config( config_path: str, weights_path: str, device: Optional[torch.device] = None ) -> Brain: with open(os.path.join(config_path, "config.json")) as f: - config = json.load(f) + config = DictConfig(json.load(f)) checkpoint_dict = torch.load(weights_path) model_dict = checkpoint_dict["model"] brain_dict = {} for key in model_dict.keys(): if "brain" in key: brain_dict[key[6:]] = model_dict[key] - brain = Brain(**config["brain"]) + brain = Brain(**config.brain) brain.load_state_dict(brain_dict) brain.to(device) return brain def to_sf_cfg(self, cfg: DictConfig) -> Config: - sf_cfg = self._get_default_cfg(cfg.experiment.rl.env_name) # Load Defaults + sf_cfg = self._get_default_cfg(cfg.rl.env_name) # Load Defaults # overwrite default values with those set in cfg # TODO: which other parameters need to be set_ - self._set_cfg_cli_argument(sf_cfg, "learning_rate", cfg.experiment.training.learning_rate) + self._set_cfg_cli_argument(sf_cfg, "learning_rate", cfg.training.learning_rate) # Using this function is necessary to make sure that the parameters are not overwritten when sample_factory loads a checkpoint - self._set_cfg_cli_argument(sf_cfg, "res_h", cfg.experiment.rl.viewport_height) - self._set_cfg_cli_argument(sf_cfg, "res_w", cfg.experiment.rl.viewport_width) - self._set_cfg_cli_argument(sf_cfg, "env", cfg.experiment.rl.env_name) - self._set_cfg_cli_argument(sf_cfg, "input_satiety", cfg.experiment.rl.input_satiety) + self._set_cfg_cli_argument(sf_cfg, "res_h", cfg.rl.viewport_height) + self._set_cfg_cli_argument(sf_cfg, "res_w", cfg.rl.viewport_width) + self._set_cfg_cli_argument(sf_cfg, "env", cfg.rl.env_name) + self._set_cfg_cli_argument(sf_cfg, "input_satiety", cfg.rl.input_satiety) self._set_cfg_cli_argument(sf_cfg, "device", cfg.system.device) - self._set_cfg_cli_argument(sf_cfg, "optimizer", cfg.experiment.training.optimizer) + self._set_cfg_cli_argument(sf_cfg, "optimizer", cfg.training.optimizer) - self._set_cfg_cli_argument(sf_cfg, "brain", OmegaConf.to_object(cfg.experiment.brain)) + self._set_cfg_cli_argument(sf_cfg, "brain", OmegaConf.to_object(cfg.brain)) return sf_cfg def analyze( From fc0bdd7afd61e859c48ee8cab0e486ee3fc29bf1 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Fri, 25 Oct 2024 13:46:55 +0200 Subject: [PATCH 03/15] add example rl configs --- main.py | 5 +++- .../user/brain/feedforward.yaml | 23 +++++++++++++++++++ .../user/dataset/rl-apples.yaml | 6 +++++ .../user/experiment/gathering-apples.yaml | 17 ++++++++++++++ .../user/optimizer/rl-base.yaml | 7 ++++++ retinal_rl/rl/sample_factory/sf_framework.py | 12 +++++----- 6 files changed, 63 insertions(+), 7 deletions(-) create mode 100644 resources/config_templates/user/brain/feedforward.yaml create mode 100644 resources/config_templates/user/dataset/rl-apples.yaml create mode 100644 resources/config_templates/user/experiment/gathering-apples.yaml create mode 100644 resources/config_templates/user/optimizer/rl-base.yaml diff --git a/main.py b/main.py index d13ad7c2..1b7bb593 100644 --- a/main.py +++ b/main.py @@ -51,7 +51,10 @@ def _program(cfg: DictConfig): if hasattr(cfg, "optimizer"): optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters()) - objective = instantiate(cfg.optimizer.objective, brain=brain) + if hasattr(cfg.optimizer, "objective"): + objective = instantiate(cfg.optimizer.objective, brain=brain) + else: + warnings.warn("No objective specified, is that wanted?") else: warnings.warn("No optimizer config specified, is that wanted?") diff --git a/resources/config_templates/user/brain/feedforward.yaml b/resources/config_templates/user/brain/feedforward.yaml new file mode 100644 index 00000000..b51a7082 --- /dev/null +++ b/resources/config_templates/user/brain/feedforward.yaml @@ -0,0 +1,23 @@ +name: feedforward +sensors: + vision: + - 3 + - ${dataset.vision_height} + - ${dataset.vision_width} +connections: + - ["vision", "encoder"] + - ["encoder", "action_decoder"] +circuits: + encoder: + _target_: retinal_rl.models.circuits.convolutional.ConvolutionalEncoder + num_layers: 3 + num_channels: [4,8,16] + kernel_size: 6 + stride: 2 + activation: ${activation} + action_decoder: + _target_: retinal_rl.models.circuits.fully_connected.FullyConnectedDecoder + output_shape: ${action_decoder_out} + hidden_units: ${latent_dimension} + activation: ${activation} + diff --git a/resources/config_templates/user/dataset/rl-apples.yaml b/resources/config_templates/user/dataset/rl-apples.yaml new file mode 100644 index 00000000..38818701 --- /dev/null +++ b/resources/config_templates/user/dataset/rl-apples.yaml @@ -0,0 +1,6 @@ +name: rl-apples + +env_name: gathering-apples +vision_width: 160 +vision_height: 120 +input_satiety: true \ No newline at end of file diff --git a/resources/config_templates/user/experiment/gathering-apples.yaml b/resources/config_templates/user/experiment/gathering-apples.yaml new file mode 100644 index 00000000..a81f0c6b --- /dev/null +++ b/resources/config_templates/user/experiment/gathering-apples.yaml @@ -0,0 +1,17 @@ +# @package _global_ +defaults: + - _self_ + - override /dataset: rl-apples + - override /brain: feedforward + - override /optimizer: rl-base + +framework: rl + +### Interpolation Parameters ### + +# This is a free list of parameters that can be interpolated by the subconfigs +# in sweep, dataset, brain, and optimizer. A major use for this is interpolating +# values in the subconfigs, and then looping over them in a sweep. +activation: "elu" +latent_dimension: [2048,1024] +action_decoder_out: [512] diff --git a/resources/config_templates/user/optimizer/rl-base.yaml b/resources/config_templates/user/optimizer/rl-base.yaml new file mode 100644 index 00000000..13ad4ec7 --- /dev/null +++ b/resources/config_templates/user/optimizer/rl-base.yaml @@ -0,0 +1,7 @@ +# The optimizer to use +optimizer: # torch.optim Class and parameters + _target_: torch.optim.Adam + lr: 0.0003 + +# The objective function +# TODO: Implement in RL and update config diff --git a/retinal_rl/rl/sample_factory/sf_framework.py b/retinal_rl/rl/sample_factory/sf_framework.py index 79e55ff7..07841e35 100644 --- a/retinal_rl/rl/sample_factory/sf_framework.py +++ b/retinal_rl/rl/sample_factory/sf_framework.py @@ -88,7 +88,7 @@ def load_brain_and_config( checkpoint_dict = torch.load(weights_path) model_dict = checkpoint_dict["model"] brain_dict = {} - for key in model_dict.keys(): + for key in model_dict: if "brain" in key: brain_dict[key[6:]] = model_dict[key] brain = Brain(**config.brain) @@ -97,17 +97,17 @@ def load_brain_and_config( return brain def to_sf_cfg(self, cfg: DictConfig) -> Config: - sf_cfg = self._get_default_cfg(cfg.rl.env_name) # Load Defaults + sf_cfg = self._get_default_cfg(cfg.dataset.env_name) # Load Defaults # overwrite default values with those set in cfg # TODO: which other parameters need to be set_ self._set_cfg_cli_argument(sf_cfg, "learning_rate", cfg.training.learning_rate) # Using this function is necessary to make sure that the parameters are not overwritten when sample_factory loads a checkpoint - self._set_cfg_cli_argument(sf_cfg, "res_h", cfg.rl.viewport_height) - self._set_cfg_cli_argument(sf_cfg, "res_w", cfg.rl.viewport_width) - self._set_cfg_cli_argument(sf_cfg, "env", cfg.rl.env_name) - self._set_cfg_cli_argument(sf_cfg, "input_satiety", cfg.rl.input_satiety) + self._set_cfg_cli_argument(sf_cfg, "res_h", cfg.dataset.vision_width) + self._set_cfg_cli_argument(sf_cfg, "res_w", cfg.dataset.vision_height) + self._set_cfg_cli_argument(sf_cfg, "env", cfg.dataset.env_name) + self._set_cfg_cli_argument(sf_cfg, "input_satiety", cfg.dataset.input_satiety) self._set_cfg_cli_argument(sf_cfg, "device", cfg.system.device) self._set_cfg_cli_argument(sf_cfg, "optimizer", cfg.training.optimizer) From 90e9391073448ed8e2dcb12a527ac993cd9660bf Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Fri, 25 Oct 2024 14:23:50 +0200 Subject: [PATCH 04/15] fix: change optimizer referencing in to_sf_cfg --- retinal_rl/rl/sample_factory/sf_framework.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/retinal_rl/rl/sample_factory/sf_framework.py b/retinal_rl/rl/sample_factory/sf_framework.py index 07841e35..79ed53c6 100644 --- a/retinal_rl/rl/sample_factory/sf_framework.py +++ b/retinal_rl/rl/sample_factory/sf_framework.py @@ -101,7 +101,7 @@ def to_sf_cfg(self, cfg: DictConfig) -> Config: # overwrite default values with those set in cfg # TODO: which other parameters need to be set_ - self._set_cfg_cli_argument(sf_cfg, "learning_rate", cfg.training.learning_rate) + self._set_cfg_cli_argument(sf_cfg, "learning_rate", cfg.optimizer.optimizer.lr) # Using this function is necessary to make sure that the parameters are not overwritten when sample_factory loads a checkpoint self._set_cfg_cli_argument(sf_cfg, "res_h", cfg.dataset.vision_width) @@ -109,7 +109,8 @@ def to_sf_cfg(self, cfg: DictConfig) -> Config: self._set_cfg_cli_argument(sf_cfg, "env", cfg.dataset.env_name) self._set_cfg_cli_argument(sf_cfg, "input_satiety", cfg.dataset.input_satiety) self._set_cfg_cli_argument(sf_cfg, "device", cfg.system.device) - self._set_cfg_cli_argument(sf_cfg, "optimizer", cfg.training.optimizer) + optimizer_name = str.split(cfg.optimizer.optimizer._target_, sep='.')[-1] + self._set_cfg_cli_argument(sf_cfg, "optimizer", optimizer_name) self._set_cfg_cli_argument(sf_cfg, "brain", OmegaConf.to_object(cfg.brain)) return sf_cfg From 86088a4a8537bcd9d1df3c0dc633e86f7672ae4e Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Fri, 25 Oct 2024 14:51:51 +0200 Subject: [PATCH 05/15] fix: add brain factory method as interface --- main.py | 16 ++------------ retinal_rl/rl/sample_factory/models.py | 29 +++++++++++++------------- runner/util.py | 15 +++++++++++++ 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/main.py b/main.py index 1b7bb593..71d68d97 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,6 @@ import os import sys import warnings -from typing import Dict, List, cast import hydra import torch @@ -11,14 +10,13 @@ from omegaconf import DictConfig, OmegaConf from retinal_rl.framework_interface import TrainingFramework -from retinal_rl.models.brain import Brain from retinal_rl.rl.sample_factory.sf_framework import SFFramework from runner.analyze import analyze from runner.dataset import get_datasets from runner.initialize import initialize from runner.sweep import launch_sweep from runner.train import train -from runner.util import assemble_neural_circuits, delete_results +from runner.util import create_brain, delete_results # Load the eval resolver for OmegaConf OmegaConf.register_new_resolver("eval", eval) @@ -37,17 +35,7 @@ def _program(cfg: DictConfig): device = torch.device(cfg.system.device) - sensors = OmegaConf.to_container(cfg.brain.sensors, resolve=True) - sensors = cast(Dict[str, List[int]], sensors) - - connections = OmegaConf.to_container(cfg.brain.connections, resolve=True) - connections = cast(List[List[str]], connections) - - connectome, circuits = assemble_neural_circuits( - cfg.brain.circuits, sensors, connections - ) - - brain = Brain(circuits, sensors, connectome).to(device) + brain = create_brain(cfg.brain).to(device) if hasattr(cfg, "optimizer"): optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters()) diff --git a/retinal_rl/rl/sample_factory/models.py b/retinal_rl/rl/sample_factory/models.py index 453bce31..8d6b3ef1 100644 --- a/retinal_rl/rl/sample_factory/models.py +++ b/retinal_rl/rl/sample_factory/models.py @@ -1,21 +1,22 @@ +import warnings +from enum import Enum from typing import Dict, Optional, Tuple + +import networkx as nx +import numpy as np +import torch +from sample_factory.algo.utils.context import global_model_factory +from sample_factory.algo.utils.tensor_dict import TensorDict from sample_factory.model.actor_critic import ActorCritic -from sample_factory.model.encoder import Encoder -from sample_factory.model.decoder import Decoder from sample_factory.model.core import ModelCore -from sample_factory.utils.typing import ActionSpace, Config, ObsSpace -from sample_factory.algo.utils.context import global_model_factory +from sample_factory.model.decoder import Decoder +from sample_factory.model.encoder import Encoder from sample_factory.model.model_utils import model_device -from sample_factory.algo.utils.tensor_dict import TensorDict -from torch import Tensor -import torch -import numpy as np -import networkx as nx -from retinal_rl.models.brain import Brain +from sample_factory.utils.typing import ActionSpace, Config, ObsSpace +from torch import Tensor, nn + from retinal_rl.rl.sample_factory.sf_interfaces import ActorCriticProtocol -import warnings -from enum import Enum -from torch import nn +from runner.util import create_brain #TODO: Remove runner reference! class CoreMode(Enum): @@ -30,7 +31,7 @@ def __init__(self, cfg: Config, obs_space: ObsSpace, action_space: ActionSpace): super().__init__(obs_space, action_space, cfg) self.set_brain( - Brain(**cfg.brain) + create_brain(cfg.brain) ) # TODO: Find way to instantiate brain outside dec_out_shape = self.brain.circuits[self.decoder_name].output_shape diff --git a/runner/util.py b/runner/util.py index 27f87ce7..6dc19972 100644 --- a/runner/util.py +++ b/runner/util.py @@ -6,6 +6,7 @@ import os import shutil from typing import Any, Dict, List, Tuple +from typing import Dict, List, cast import networkx as nx import torch @@ -16,6 +17,7 @@ from torch.optim.optimizer import Optimizer from retinal_rl.models.neural_circuit import NeuralCircuit +from retinal_rl.models.brain import Brain nx.DiGraph.__class_getitem__ = classmethod(lambda _, __: "nx.DiGraph") # type: ignore @@ -80,6 +82,19 @@ def delete_results(cfg: DictConfig) -> None: print("Deletion cancelled.") +def create_brain(brain_cfg: DictConfig) -> Brain: + sensors = OmegaConf.to_container(brain_cfg.sensors, resolve=True) + sensors = cast(Dict[str, List[int]], sensors) + + connections = OmegaConf.to_container(brain_cfg.connections, resolve=True) + connections = cast(List[List[str]], connections) + + connectome, circuits = assemble_neural_circuits( + brain_cfg.circuits, sensors, connections + ) + + return Brain(circuits, sensors, connectome) + def assemble_neural_circuits( circuits: DictConfig, sensors: Dict[str, List[int]], From 98eea38a596e7827cce064a518cb71bd9c602797 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Fri, 25 Oct 2024 17:42:19 +0200 Subject: [PATCH 06/15] fix: back and forth of config loses DictConfig . accessing --- retinal_rl/rl/sample_factory/models.py | 15 ++++++++------- retinal_rl/rl/sample_factory/sf_framework.py | 5 +++-- runner/util.py | 4 ++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/retinal_rl/rl/sample_factory/models.py b/retinal_rl/rl/sample_factory/models.py index 8d6b3ef1..a8e4f974 100644 --- a/retinal_rl/rl/sample_factory/models.py +++ b/retinal_rl/rl/sample_factory/models.py @@ -4,6 +4,7 @@ import networkx as nx import numpy as np +from omegaconf import DictConfig import torch from sample_factory.algo.utils.context import global_model_factory from sample_factory.algo.utils.tensor_dict import TensorDict @@ -16,7 +17,8 @@ from torch import Tensor, nn from retinal_rl.rl.sample_factory.sf_interfaces import ActorCriticProtocol -from runner.util import create_brain #TODO: Remove runner reference! +from runner.util import create_brain # TODO: Remove runner reference! +from retinal_rl.models.brain import Brain class CoreMode(Enum): @@ -25,21 +27,21 @@ class CoreMode(Enum): RNN = (2,) MULTI_MODULES = (3,) + class SampleFactoryBrain(ActorCritic, ActorCriticProtocol): def __init__(self, cfg: Config, obs_space: ObsSpace, action_space: ActionSpace): # Attention: make_actor_critic passes [cfg, obs_space, action_space], but ActorCritic takes the reversed order of arguments [obs_space, action_space, cfg] super().__init__(obs_space, action_space, cfg) - self.set_brain( - create_brain(cfg.brain) - ) # TODO: Find way to instantiate brain outside + self.set_brain(create_brain(DictConfig(cfg.brain))) + # TODO: Find way to instantiate brain outside dec_out_shape = self.brain.circuits[self.decoder_name].output_shape decoder_out_size = np.prod(dec_out_shape) self.critic_linear = nn.Linear(decoder_out_size, 1) self.action_parameterization = self.get_action_parameterization( decoder_out_size - ) # boils down to a linear layer mapping to num_action_outputs + ) # boils down to a linear layer mapping to num_action_outputs def set_brain(self, brain: Brain): """ @@ -154,10 +156,9 @@ def forward( def get_brain(self) -> Brain: return self.brain - # Methods need to be overwritten 'cause the use .encoders def device_for_input_tensor(self, input_tensor_name: str) -> torch.device: return model_device(self) def type_for_input_tensor(self, input_tensor_name: str) -> torch.dtype: - return torch.float32 \ No newline at end of file + return torch.float32 diff --git a/retinal_rl/rl/sample_factory/sf_framework.py b/retinal_rl/rl/sample_factory/sf_framework.py index 79ed53c6..4fdd9883 100644 --- a/retinal_rl/rl/sample_factory/sf_framework.py +++ b/retinal_rl/rl/sample_factory/sf_framework.py @@ -24,6 +24,7 @@ add_retinal_env_eval_args, retinal_override_defaults, ) +from runner.util import create_brain import json from retinal_rl.rl.sample_factory.environment import register_retinal_env # from retinal_rl.rl.sample_factory.observer import RetinalAlgoObserver @@ -73,7 +74,7 @@ def load_brain_from_checkpoint( for key in model_dict.keys(): if "brain" in key: brain_dict[key[6:]] = model_dict[key] - brain = Brain(**config.brain) + brain = create_brain(config.brain) if load_weights: brain.load_state_dict(brain_dict) brain.to(device) @@ -91,7 +92,7 @@ def load_brain_and_config( for key in model_dict: if "brain" in key: brain_dict[key[6:]] = model_dict[key] - brain = Brain(**config.brain) + brain = create_brain(config.brain) brain.load_state_dict(brain_dict) brain.to(device) return brain diff --git a/runner/util.py b/runner/util.py index 6dc19972..d1d8bed7 100644 --- a/runner/util.py +++ b/runner/util.py @@ -86,11 +86,11 @@ def create_brain(brain_cfg: DictConfig) -> Brain: sensors = OmegaConf.to_container(brain_cfg.sensors, resolve=True) sensors = cast(Dict[str, List[int]], sensors) - connections = OmegaConf.to_container(brain_cfg.connections, resolve=True) + connections = OmegaConf.to_container(brain_cfg.sensors, resolve=True) connections = cast(List[List[str]], connections) connectome, circuits = assemble_neural_circuits( - brain_cfg.circuits, sensors, connections + brain_cfg['circuits'], sensors, connections ) return Brain(circuits, sensors, connectome) From d908f2843a1e880e1c21186300d59893378ba7d8 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Mon, 28 Oct 2024 11:39:43 +0100 Subject: [PATCH 07/15] fix: update config after change in model --- resources/config_templates/user/brain/feedforward.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resources/config_templates/user/brain/feedforward.yaml b/resources/config_templates/user/brain/feedforward.yaml index b51a7082..95289777 100644 --- a/resources/config_templates/user/brain/feedforward.yaml +++ b/resources/config_templates/user/brain/feedforward.yaml @@ -16,7 +16,7 @@ circuits: stride: 2 activation: ${activation} action_decoder: - _target_: retinal_rl.models.circuits.fully_connected.FullyConnectedDecoder + _target_: retinal_rl.models.circuits.fully_connected.FullyConnected output_shape: ${action_decoder_out} hidden_units: ${latent_dimension} activation: ${activation} From a559a12e258787a0bcba4deedbbd9b73d9738d74 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Mon, 28 Oct 2024 11:40:47 +0100 Subject: [PATCH 08/15] fix: wrong reference in create_brain --- runner/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runner/util.py b/runner/util.py index d1d8bed7..6dc19972 100644 --- a/runner/util.py +++ b/runner/util.py @@ -86,11 +86,11 @@ def create_brain(brain_cfg: DictConfig) -> Brain: sensors = OmegaConf.to_container(brain_cfg.sensors, resolve=True) sensors = cast(Dict[str, List[int]], sensors) - connections = OmegaConf.to_container(brain_cfg.sensors, resolve=True) + connections = OmegaConf.to_container(brain_cfg.connections, resolve=True) connections = cast(List[List[str]], connections) connectome, circuits = assemble_neural_circuits( - brain_cfg['circuits'], sensors, connections + brain_cfg.circuits, sensors, connections ) return Brain(circuits, sensors, connectome) From 6161013da9a8b091b9f874a89046cbe16c1bc111 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Mon, 28 Oct 2024 11:48:32 +0100 Subject: [PATCH 09/15] fix: sample_factory needs optimizer in lowercase --- retinal_rl/rl/sample_factory/sf_framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/retinal_rl/rl/sample_factory/sf_framework.py b/retinal_rl/rl/sample_factory/sf_framework.py index 4fdd9883..934797b0 100644 --- a/retinal_rl/rl/sample_factory/sf_framework.py +++ b/retinal_rl/rl/sample_factory/sf_framework.py @@ -110,7 +110,7 @@ def to_sf_cfg(self, cfg: DictConfig) -> Config: self._set_cfg_cli_argument(sf_cfg, "env", cfg.dataset.env_name) self._set_cfg_cli_argument(sf_cfg, "input_satiety", cfg.dataset.input_satiety) self._set_cfg_cli_argument(sf_cfg, "device", cfg.system.device) - optimizer_name = str.split(cfg.optimizer.optimizer._target_, sep='.')[-1] + optimizer_name = str.lower(str.split(cfg.optimizer.optimizer._target_, sep='.')[-1]) self._set_cfg_cli_argument(sf_cfg, "optimizer", optimizer_name) self._set_cfg_cli_argument(sf_cfg, "brain", OmegaConf.to_object(cfg.brain)) From d1dabb57ef8cbaaf5a91fc665d6c043e7f0383bd Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Mon, 28 Oct 2024 11:59:22 +0100 Subject: [PATCH 10/15] autofix by ruff on changed files --- retinal_rl/rl/analysis/statistics.py | 17 +++++----- retinal_rl/rl/sample_factory/models.py | 10 ++---- retinal_rl/rl/sample_factory/sf_framework.py | 33 ++++++++++---------- runner/util.py | 5 ++- 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/retinal_rl/rl/analysis/statistics.py b/retinal_rl/rl/analysis/statistics.py index 86bfec4e..224664ef 100644 --- a/retinal_rl/rl/analysis/statistics.py +++ b/retinal_rl/rl/analysis/statistics.py @@ -1,19 +1,18 @@ -from typing import Dict, List, Tuple +import math import warnings +from typing import Dict, List, Tuple import numpy as np import torch -import torch.nn as nn from captum.attr import NeuronGradient from numpy.typing import NDArray -from torch import Tensor +from torch import Tensor, nn from torch.utils.data import Dataset +from tqdm import tqdm from retinal_rl.models.brain import Brain from retinal_rl.models.circuits.convolutional import ConvolutionalEncoder from retinal_rl.util import encoder_out_size, rf_size_and_start -from tqdm import tqdm -import math def gradient_receptive_fields( @@ -137,17 +136,17 @@ def get_input_output_shape(model: nn.Sequential): in_size = layer.in_features down_stream_linear = True break - elif isinstance(layer, nn.Conv2d): + if isinstance(layer, nn.Conv2d): num_outputs = layer.out_channels in_channels = layer.in_channels in_size = layer.in_channels * ((layer.kernel_size[0]-1)*layer.dilation[0]+1) ** 2 break - elif isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d): + if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d): for prev_layer in reversed(model[:-i-1]): if isinstance(prev_layer, nn.Conv2d): in_channels = prev_layer.out_channels break - elif isinstance(prev_layer, nn.Linear): + if isinstance(prev_layer, nn.Linear): in_channels=1 else: raise Exception("layer before pooling needs to be conv or linear") @@ -171,7 +170,7 @@ def get_input_output_shape(model: nn.Sequential): in_size = ( (in_size - 1) * layer.stride[0] - 2 * layer.padding[0] * down_stream_linear - + ((layer.kernel_size[0]-1)*layer.dilation[0]+1) + + ((layer.kernel_size[0]-1)*layer.dilation[0]+1) ) in_size = in_size**2 * in_channels elif isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d): diff --git a/retinal_rl/rl/sample_factory/models.py b/retinal_rl/rl/sample_factory/models.py index a8e4f974..c79f651e 100644 --- a/retinal_rl/rl/sample_factory/models.py +++ b/retinal_rl/rl/sample_factory/models.py @@ -1,24 +1,20 @@ import warnings from enum import Enum -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple import networkx as nx import numpy as np -from omegaconf import DictConfig import torch -from sample_factory.algo.utils.context import global_model_factory +from omegaconf import DictConfig from sample_factory.algo.utils.tensor_dict import TensorDict from sample_factory.model.actor_critic import ActorCritic -from sample_factory.model.core import ModelCore -from sample_factory.model.decoder import Decoder -from sample_factory.model.encoder import Encoder from sample_factory.model.model_utils import model_device from sample_factory.utils.typing import ActionSpace, Config, ObsSpace from torch import Tensor, nn +from retinal_rl.models.brain import Brain from retinal_rl.rl.sample_factory.sf_interfaces import ActorCriticProtocol from runner.util import create_brain # TODO: Remove runner reference! -from retinal_rl.models.brain import Brain class CoreMode(Enum): diff --git a/retinal_rl/rl/sample_factory/sf_framework.py b/retinal_rl/rl/sample_factory/sf_framework.py index 934797b0..e2b4e550 100644 --- a/retinal_rl/rl/sample_factory/sf_framework.py +++ b/retinal_rl/rl/sample_factory/sf_framework.py @@ -1,39 +1,38 @@ +import argparse +import json +import os +from argparse import Namespace from typing import Any, Dict, List, Optional, Tuple +# from retinal_rl.rl.sample_factory.observer import RetinalAlgoObserver +import torch from omegaconf import DictConfig +from omegaconf.omegaconf import OmegaConf +from sample_factory.algo.learning.learner import Learner from sample_factory.algo.utils.context import global_model_factory from sample_factory.algo.utils.misc import ExperimentStatus from sample_factory.cfg.arguments import ( + load_from_checkpoint, parse_full_cfg, parse_sf_args, ) +from sample_factory.enjoy import enjoy from sample_factory.train import make_runner -from sample_factory.utils.typing import Config -from sample_factory.cfg.arguments import load_from_checkpoint -from sample_factory.algo.learning.learner import Learner from sample_factory.utils.attr_dict import AttrDict +from sample_factory.utils.typing import Config from torch import Tensor -import argparse from torch.utils.data import Dataset -from retinal_rl.models.brain import Brain from retinal_rl.framework_interface import TrainingFramework -from retinal_rl.rl.sample_factory.models import SampleFactoryBrain +from retinal_rl.models.brain import Brain from retinal_rl.rl.sample_factory.arguments import ( add_retinal_env_args, add_retinal_env_eval_args, retinal_override_defaults, ) -from runner.util import create_brain -import json from retinal_rl.rl.sample_factory.environment import register_retinal_env -# from retinal_rl.rl.sample_factory.observer import RetinalAlgoObserver -import torch -from sample_factory.enjoy import enjoy - -import os -from argparse import Namespace -from omegaconf.omegaconf import OmegaConf +from retinal_rl.rl.sample_factory.models import SampleFactoryBrain +from runner.util import create_brain class SFFramework(TrainingFramework): @@ -71,7 +70,7 @@ def load_brain_from_checkpoint( config = DictConfig(config) model_dict: Dict[str, Any] = checkpoint_dict["model"] brain_dict: Dict[str, Any] = {} - for key in model_dict.keys(): + for key in model_dict: if "brain" in key: brain_dict[key[6:]] = model_dict[key] brain = create_brain(config.brain) @@ -157,7 +156,7 @@ def _get_default_cfg(envname: str = "") -> argparse.Namespace: sf_cfg = parse_full_cfg(parser, mock_argv) return sf_cfg - + @staticmethod def get_checkpoint(cfg: Config) -> tuple[Dict[str, Any], AttrDict]: """ diff --git a/runner/util.py b/runner/util.py index 6dc19972..7e24dbc5 100644 --- a/runner/util.py +++ b/runner/util.py @@ -5,8 +5,7 @@ import logging import os import shutil -from typing import Any, Dict, List, Tuple -from typing import Dict, List, cast +from typing import Any, Dict, List, Tuple, cast import networkx as nx import torch @@ -16,8 +15,8 @@ from torch import nn from torch.optim.optimizer import Optimizer -from retinal_rl.models.neural_circuit import NeuralCircuit from retinal_rl.models.brain import Brain +from retinal_rl.models.neural_circuit import NeuralCircuit nx.DiGraph.__class_getitem__ = classmethod(lambda _, __: "nx.DiGraph") # type: ignore From 30f3a2d37b2db617666d4ca033b03b6509edfacf Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Mon, 28 Oct 2024 12:00:51 +0100 Subject: [PATCH 11/15] bugfix: lint.sh with just --fix works now --- tests/ci/lint.sh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/ci/lint.sh b/tests/ci/lint.sh index cfa1c70d..5771c494 100755 --- a/tests/ci/lint.sh +++ b/tests/ci/lint.sh @@ -30,11 +30,6 @@ if [ "$1" = "--all" ]; then # Run ruff on all files with any remaining arguments apptainer exec "$CONTAINER" ruff check . "$@" else - # If first arg isn't --all, put it back in the argument list - if [ -n "$1" ]; then - set -- "$1" "$@" - fi - # Get changed Python files changed_files=$(git diff --name-only origin/master...HEAD -- '*.py') if [ -n "$changed_files" ]; then From 447d4fe7608201d9c17301178cc4df8ee4487fe1 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Mon, 28 Oct 2024 13:14:06 +0100 Subject: [PATCH 12/15] fix: merge isinstance calls, add/remove return statements etc --- retinal_rl/rl/analysis/statistics.py | 6 +++--- retinal_rl/rl/sample_factory/models.py | 4 ++-- retinal_rl/rl/sample_factory/sf_framework.py | 9 +++------ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/retinal_rl/rl/analysis/statistics.py b/retinal_rl/rl/analysis/statistics.py index 224664ef..0b2d6911 100644 --- a/retinal_rl/rl/analysis/statistics.py +++ b/retinal_rl/rl/analysis/statistics.py @@ -141,7 +141,7 @@ def get_input_output_shape(model: nn.Sequential): in_channels = layer.in_channels in_size = layer.in_channels * ((layer.kernel_size[0]-1)*layer.dilation[0]+1) ** 2 break - if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d): + if isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)): for prev_layer in reversed(model[:-i-1]): if isinstance(prev_layer, nn.Conv2d): in_channels = prev_layer.out_channels @@ -149,7 +149,7 @@ def get_input_output_shape(model: nn.Sequential): if isinstance(prev_layer, nn.Linear): in_channels=1 else: - raise Exception("layer before pooling needs to be conv or linear") + raise TypeError("layer before pooling needs to be conv or linear") _kernel_size = layer.kernel_size if isinstance(layer.kernel_size, int) else layer.kernel_size[0] in_size = _kernel_size**2 * in_channels break @@ -173,7 +173,7 @@ def get_input_output_shape(model: nn.Sequential): + ((layer.kernel_size[0]-1)*layer.dilation[0]+1) ) in_size = in_size**2 * in_channels - elif isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d): + elif isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)): for prev_layer in reversed(model[:-i-_first-1]): if isinstance(prev_layer, nn.Conv2d): in_channels = prev_layer.out_channels diff --git a/retinal_rl/rl/sample_factory/models.py b/retinal_rl/rl/sample_factory/models.py index c79f651e..9933deca 100644 --- a/retinal_rl/rl/sample_factory/models.py +++ b/retinal_rl/rl/sample_factory/models.py @@ -53,7 +53,7 @@ def set_brain(self, brain: Brain): @staticmethod def get_encoder_decoder(brain: Brain) -> Tuple[str, CoreMode, str]: - assert "vision" in brain.sensors.keys() # needed as input + assert "vision" in brain.sensors # needed as input # potential TODO: add other input sources if needed? vision_paths = [] @@ -62,7 +62,7 @@ def get_encoder_decoder(brain: Brain) -> Tuple[str, CoreMode, str]: vision_paths.append(nx.shortest_path(brain.connectome, "vision", node)) decoder = "action_decoder" # default assumption - if decoder in brain.circuits.keys(): # needed to produce output = decoder + if decoder in brain.circuits: # needed to produce output = decoder vision_path = nx.shortest_path(brain.connectome, "vision", "action_decoder") else: selected_path = 0 diff --git a/retinal_rl/rl/sample_factory/sf_framework.py b/retinal_rl/rl/sample_factory/sf_framework.py index e2b4e550..bd3a0f0c 100644 --- a/retinal_rl/rl/sample_factory/sf_framework.py +++ b/retinal_rl/rl/sample_factory/sf_framework.py @@ -58,7 +58,7 @@ def train(self): status = runner.init() if status == ExperimentStatus.SUCCESS: status = runner.run() - return status + print(status) @staticmethod def load_brain_from_checkpoint( @@ -126,9 +126,7 @@ def analyze( epoch: int, copy_checkpoint: bool = False, ): - - status = enjoy(self.sf_cfg) - return status + return enjoy(self.sf_cfg) @staticmethod def _set_cfg_cli_argument(cfg: Namespace, name: str, value: Any): @@ -154,8 +152,7 @@ def _get_default_cfg(envname: str = "") -> argparse.Namespace: # Actually, discuss that. Would avoid having a unified interface retinal_override_defaults(parser) - sf_cfg = parse_full_cfg(parser, mock_argv) - return sf_cfg + return parse_full_cfg(parser, mock_argv) @staticmethod def get_checkpoint(cfg: Config) -> tuple[Dict[str, Any], AttrDict]: From 39c4ba6b3d43a1eeadc51a898722d1812f07d1c1 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Mon, 28 Oct 2024 14:20:21 +0100 Subject: [PATCH 13/15] fix: optimizer has to be provided, remove check --- main.py | 15 ++++++++------- runner/util.py | 3 ++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index 71d68d97..4a2ee026 100644 --- a/main.py +++ b/main.py @@ -25,6 +25,9 @@ # Hydra entry point @hydra.main(config_path="config/base", config_name="config", version_base=None) def _program(cfg: DictConfig): + #TODO: Instead of doing checks of the config here, we should implement + # sth like the configstore which ensures config parameters are present + if cfg.command == "clean": delete_results(cfg) sys.exit(0) @@ -37,14 +40,12 @@ def _program(cfg: DictConfig): brain = create_brain(cfg.brain).to(device) - if hasattr(cfg, "optimizer"): - optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters()) - if hasattr(cfg.optimizer, "objective"): - objective = instantiate(cfg.optimizer.objective, brain=brain) - else: - warnings.warn("No objective specified, is that wanted?") + optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters()) + if hasattr(cfg.optimizer, "objective"): + objective = instantiate(cfg.optimizer.objective, brain=brain) + # TODO: RL framework currently can't use objective else: - warnings.warn("No optimizer config specified, is that wanted?") + warnings.warn("No objective specified, is that wanted?") if cfg.command == "scan": print(brain.scan()) diff --git a/runner/util.py b/runner/util.py index 7e24dbc5..fbfea1b8 100644 --- a/runner/util.py +++ b/runner/util.py @@ -92,7 +92,8 @@ def create_brain(brain_cfg: DictConfig) -> Brain: brain_cfg.circuits, sensors, connections ) - return Brain(circuits, sensors, connectome) + return Brain(circuits, sensors, connectome) + def assemble_neural_circuits( circuits: DictConfig, From 3478315d5d20f81b52292f900be5aca4d5512d85 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Mon, 28 Oct 2024 14:43:46 +0100 Subject: [PATCH 14/15] fix: split complex function --- retinal_rl/rl/analysis/statistics.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/retinal_rl/rl/analysis/statistics.py b/retinal_rl/rl/analysis/statistics.py index 0b2d6911..77689268 100644 --- a/retinal_rl/rl/analysis/statistics.py +++ b/retinal_rl/rl/analysis/statistics.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import torch @@ -117,17 +117,11 @@ def sum_collapse_output(out_tensor): return out_tensor -def get_input_output_shape(model: nn.Sequential): - """ - Calculates the 'minimal' input and output of a sequential model. - If last layer is a convolutional layer, output is assumed to be the number of channels (so 1x1 in space). - Takes into account if last layer is a pooling layer. - For linear layer obviously the number of out_features. - TODO: assert kernel sizes etc are quadratic / implement adaptation to non quadratic kernels - """ +def _find_last_layer_shape(model: nn.Sequential) -> Tuple[int, Optional[int], Optional[int], Optional[int], bool]: _first = 0 down_stream_linear = False num_outputs = None + in_size, in_channels = None, None for i, layer in enumerate(reversed(model)): _first += 1 if isinstance(layer, nn.Linear): @@ -153,7 +147,19 @@ def get_input_output_shape(model: nn.Sequential): _kernel_size = layer.kernel_size if isinstance(layer.kernel_size, int) else layer.kernel_size[0] in_size = _kernel_size**2 * in_channels break + return _first, num_outputs, in_size, in_channels, down_stream_linear + +def get_input_output_shape(model: nn.Sequential): + """ + Calculates the 'minimal' input and output of a sequential model. + If last layer is a convolutional layer, output is assumed to be the number of channels (so 1x1 in space). + Takes into account if last layer is a pooling layer. + For linear layer obviously the number of out_features. + TODO: assert kernel sizes etc are quadratic / implement adaptation to non quadratic kernels + TODO: Check if still needed, function near duplicate of some of Sachas code + """ + _first, num_outputs, in_size, in_channels, down_stream_linear = _find_last_layer_shape(model) for i, layer in enumerate(reversed(model[:-_first])): if isinstance(layer, nn.Linear): From 7168e02f6703584ada16de4600403a3c640f83d4 Mon Sep 17 00:00:00 2001 From: Fabio Seel Date: Tue, 29 Oct 2024 10:19:07 +0100 Subject: [PATCH 15/15] doc: mark rl analysis stuff as deprecated --- retinal_rl/rl/analysis/statistics.py | 81 ++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 21 deletions(-) diff --git a/retinal_rl/rl/analysis/statistics.py b/retinal_rl/rl/analysis/statistics.py index 77689268..4eb2107b 100644 --- a/retinal_rl/rl/analysis/statistics.py +++ b/retinal_rl/rl/analysis/statistics.py @@ -9,12 +9,14 @@ from torch import Tensor, nn from torch.utils.data import Dataset from tqdm import tqdm +from typing_extensions import deprecated from retinal_rl.models.brain import Brain from retinal_rl.models.circuits.convolutional import ConvolutionalEncoder from retinal_rl.util import encoder_out_size, rf_size_and_start +@deprecated("Use functions of retinal_rl.analysis.statistics") def gradient_receptive_fields( device: torch.device, enc: ConvolutionalEncoder ) -> Dict[str, NDArray[np.float64]]: @@ -52,10 +54,10 @@ def gradient_receptive_fields( # Assert min max is in bounds # potential TODO: change input size if rf is larger than actual input - h_min = max(0,h_min) - w_min = max(0,w_min) - hrf_size = min(hght,hrf_size) - wrf_size = min(wdth,wrf_size) + h_min = max(0, h_min) + w_min = max(0, w_min) + hrf_size = min(hght, hrf_size) + wrf_size = min(wdth, wrf_size) h_max = h_min + hrf_size w_max = w_min + wrf_size @@ -74,13 +76,18 @@ def gradient_receptive_fields( return stas -def _activation_triggered_average(model: nn.Sequential, n_batch: int = 2048, rf_size=None, device=None): + +def _activation_triggered_average( + model: nn.Sequential, n_batch: int = 2048, rf_size=None, device=None +): model.eval() if rf_size is None: _out_channels, input_size = get_input_output_shape(model) else: input_size = rf_size - input_tensor = torch.randn((n_batch, *input_size), requires_grad=False, device=device) + input_tensor = torch.randn( + (n_batch, *input_size), requires_grad=False, device=device + ) output = model(input_tensor) output = sum_collapse_output(output) input_tensor = input_tensor[:, None, :, :, :].expand( @@ -93,31 +100,50 @@ def _activation_triggered_average(model: nn.Sequential, n_batch: int = 2048, rf_ weighted = (weights * input_tensor).sum(0) return weighted.cpu().detach(), weight_sums.cpu().detach() + def activation_triggered_average( - model: nn.Sequential, n_batch: int = 2048, n_iter: int = 1, rf_size=None, device=None + model: nn.Sequential, + n_batch: int = 2048, + n_iter: int = 1, + rf_size=None, + device=None, ) -> Dict[str, NDArray[np.float64]]: # TODO: WIP warnings.warn("Code is not tested and might contain bugs.") stas: Dict[str, NDArray[np.float64]] = {} with torch.no_grad(): - for index, (layer_name, mdl) in tqdm(enumerate(model.named_children()), total=len(model)): - weighted, weight_sums = _activation_triggered_average(model[:index+1], n_batch, device=device) - for _ in tqdm(range(n_iter - 1), total=n_iter-1, leave=False): - it_weighted, it_weight_sums = _activation_triggered_average(model[:index+1], n_batch, rf_size, device=device) + for index, (layer_name, mdl) in tqdm( + enumerate(model.named_children()), total=len(model) + ): + weighted, weight_sums = _activation_triggered_average( + model[: index + 1], n_batch, device=device + ) + for _ in tqdm(range(n_iter - 1), total=n_iter - 1, leave=False): + it_weighted, it_weight_sums = _activation_triggered_average( + model[: index + 1], n_batch, rf_size, device=device + ) weighted += it_weighted weight_sums += it_weight_sums - stas[layer_name] = (weighted.cpu().detach() / weight_sums[:, None, None, None] / len(weight_sums)).numpy() + stas[layer_name] = ( + weighted.cpu().detach() + / weight_sums[:, None, None, None] + / len(weight_sums) + ).numpy() torch.cuda.empty_cache() return stas + +@deprecated("Use functions of retinal_rl.analysis.statistics") def sum_collapse_output(out_tensor): if len(out_tensor.shape) > 2: - sum_dims = [2+i for i in range(len(out_tensor.shape)-2)] + sum_dims = [2 + i for i in range(len(out_tensor.shape) - 2)] out_tensor = torch.sum(out_tensor, dim=sum_dims) return out_tensor -def _find_last_layer_shape(model: nn.Sequential) -> Tuple[int, Optional[int], Optional[int], Optional[int], bool]: +def _find_last_layer_shape( + model: nn.Sequential, +) -> Tuple[int, Optional[int], Optional[int], Optional[int], bool]: _first = 0 down_stream_linear = False num_outputs = None @@ -133,22 +159,31 @@ def _find_last_layer_shape(model: nn.Sequential) -> Tuple[int, Optional[int], Op if isinstance(layer, nn.Conv2d): num_outputs = layer.out_channels in_channels = layer.in_channels - in_size = layer.in_channels * ((layer.kernel_size[0]-1)*layer.dilation[0]+1) ** 2 + in_size = ( + layer.in_channels + * ((layer.kernel_size[0] - 1) * layer.dilation[0] + 1) ** 2 + ) break if isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)): - for prev_layer in reversed(model[:-i-1]): + for prev_layer in reversed(model[: -i - 1]): if isinstance(prev_layer, nn.Conv2d): in_channels = prev_layer.out_channels break if isinstance(prev_layer, nn.Linear): - in_channels=1 + in_channels = 1 else: raise TypeError("layer before pooling needs to be conv or linear") - _kernel_size = layer.kernel_size if isinstance(layer.kernel_size, int) else layer.kernel_size[0] + _kernel_size = ( + layer.kernel_size + if isinstance(layer.kernel_size, int) + else layer.kernel_size[0] + ) in_size = _kernel_size**2 * in_channels break return _first, num_outputs, in_size, in_channels, down_stream_linear + +@deprecated("Use functions of retinal_rl.analysis.statistics") def get_input_output_shape(model: nn.Sequential): """ Calculates the 'minimal' input and output of a sequential model. @@ -159,7 +194,9 @@ def get_input_output_shape(model: nn.Sequential): TODO: Check if still needed, function near duplicate of some of Sachas code """ - _first, num_outputs, in_size, in_channels, down_stream_linear = _find_last_layer_shape(model) + _first, num_outputs, in_size, in_channels, down_stream_linear = ( + _find_last_layer_shape(model) + ) for i, layer in enumerate(reversed(model[:-_first])): if isinstance(layer, nn.Linear): @@ -176,11 +213,11 @@ def get_input_output_shape(model: nn.Sequential): in_size = ( (in_size - 1) * layer.stride[0] - 2 * layer.padding[0] * down_stream_linear - + ((layer.kernel_size[0]-1)*layer.dilation[0]+1) + + ((layer.kernel_size[0] - 1) * layer.dilation[0] + 1) ) in_size = in_size**2 * in_channels elif isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)): - for prev_layer in reversed(model[:-i-_first-1]): + for prev_layer in reversed(model[: -i - _first - 1]): if isinstance(prev_layer, nn.Conv2d): in_channels = prev_layer.out_channels break @@ -196,6 +233,8 @@ def get_input_output_shape(model: nn.Sequential): input_size = (in_channels, in_size, in_size) return num_outputs, input_size + +@deprecated("Use functions of retinal_rl.analysis.statistics") def get_reconstructions( device: torch.device, brain: Brain,