Flatten objectives #16

Merged · 6 commits · Oct 18, 2024
8 changes: 3 additions & 5 deletions main.py
@@ -9,10 +9,8 @@
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from retinal_rl.classification.loss import ClassificationContext
from retinal_rl.framework_interface import TrainingFramework
from retinal_rl.models.brain import Brain
from retinal_rl.models.goal import Goal
from retinal_rl.rl.sample_factory.sf_framework import SFFramework
from runner.analyze import analyze
from runner.dataset import get_datasets
@@ -40,8 +38,8 @@ def _program(cfg: DictConfig):

brain = Brain(**cfg.brain).to(device)
if hasattr(cfg, "optimizer"):
goal = Goal[ClassificationContext](brain, dict(cfg.optimizer.goal))
optimizer = instantiate(cfg.optimizer.optimizer, brain.parameters())
objective = instantiate(cfg.optimizer.objective, brain=brain)
else:
warnings.warn("No optimizer config specified, is that wanted?")

@@ -68,7 +66,7 @@ def _program(cfg: DictConfig):
cfg,
device,
brain,
goal,
objective,
optimizer,
train_set,
test_set,
@@ -82,7 +80,7 @@ def _program(cfg: DictConfig):
cfg,
device,
brain,
goal,
objective,
histories,
train_set,
test_set,
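In short, `main.py` now builds a single `Objective` from config where it previously built a `Goal[ClassificationContext]`. A minimal sketch of what the new call amounts to, assuming Hydra's default recursive instantiation and that `Objective` takes the brain plus the list of loss instances declared under `cfg.optimizer.objective` (inferred from this diff, not from a documented API):

```python
# Sketch only: a toy config with the same shape as cfg.optimizer.objective.
# The Objective signature (brain plus nested losses) is an assumption based
# on this PR's config template and plot.py changes.
from hydra.utils import instantiate
from omegaconf import OmegaConf

toy_objective_cfg = OmegaConf.create(
    {
        "_target_": "retinal_rl.models.objective.Objective",
        "losses": [
            {"_target_": "retinal_rl.classification.loss.ClassificationLoss"},
            {"_target_": "retinal_rl.models.loss.ReconstructionLoss"},
        ],
    }
)

# instantiate() imports each _target_, builds the nested losses recursively,
# and forwards the extra keyword argument to the top-level constructor.
objective = instantiate(toy_objective_cfg, brain=brain)  # `brain` as built earlier in _program
```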
5 changes: 4 additions & 1 deletion pyproject.toml
Contributor:
I think we should decide on ruff vs pylint now; otherwise we also have to maintain this rule checking for both equally, etc. :D
I don't have a strong opinion on it, but if ruff integrates more nicely with your IDE we can use that.

Contributor Author:
I don't have a strong opinion either, but I think the main thing is making sure we only have to track one set of such things in ruff. Maybe I'd say if you'd like me to take care of it, then I'll manage the ruff/toml stuff. If you care about extra pylint features though, then it's up to you to maintain it in toml, so that I can then get the desired behaviour.

@@ -20,7 +20,10 @@ select = [
"I", # Import conventions
]

ignore = ["E501"] # Example: Ignore line length warnings
ignore = [
"E501", # Example: Ignore line length warnings
"D", # Ignore all docstring-related warnings
]

[tool.ruff.format]
docstring-code-format = true
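Note that the new `"D"` entry turns off ruff's entire pydocstyle rule family, so docstring presence and style are no longer enforced, which is consistent with the docstring trimming in `plot.py` below. As a small illustration with a hypothetical function (not from this repo), assuming the docstring rules are enabled by the `select` list above (only partially shown here):

```python
# With "D" in the ignore list this passes ruff cleanly; otherwise pydocstyle
# rule D103 ("Missing docstring in public function") would be reported here.
def preprocess(batch):
    return batch / 255.0
```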
89 changes: 37 additions & 52 deletions resources/config_templates/user/optimizer/class-recon.yaml
@@ -1,56 +1,41 @@
# Number of training epochs
num_epochs: 100

# The optimizer to use
optimizer: # torch.optim Class and parameters
  _target_: torch.optim.Adam
  lr: 0.0003

goal:
  recon:
    min_epoch: 0 # Epoch to start optimizer
    max_epoch: 100 # Epoch to stop optimizer
    losses: # Weighted optimizer losses as defined in retinal-rl
      - _target_: retinal_rl.models.loss.ReconstructionLoss
        weight: ${recon_weight_retina}
      - _target_: retinal_rl.classification.loss.ClassificationLoss
        weight: ${eval:'1-${recon_weight_retina}'}
    target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
      - retina
  decode:
    min_epoch: 0 # Epoch to start optimizer
    max_epoch: 100 # Epoch to stop optimizer
    losses: # Weighted optimizer losses as defined in retinal-rl
      - _target_: retinal_rl.models.loss.ReconstructionLoss
        weight: 1
    target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
      - decoder
      - inferotemporal_decoder
  mixed:
    min_epoch: 0
    max_epoch: 100
    losses:
      - _target_: retinal_rl.models.loss.ReconstructionLoss
        weight: ${recon_weight_thalamus}
      - _target_: retinal_rl.classification.loss.ClassificationLoss
        weight: ${eval:'1-${recon_weight_thalamus}'}
    target_circuits: # The thalamus is somewhat sensitive to task losses
      - thalamus
  cortex:
    min_epoch: 0
    max_epoch: 100
    losses:
      - _target_: retinal_rl.models.loss.ReconstructionLoss
        weight: ${recon_weight_cortex}
      - _target_: retinal_rl.classification.loss.ClassificationLoss
        weight: ${eval:'1-${recon_weight_cortex}'}
    target_circuits: # Visual cortex and downstream layers are driven by the task
      - visual_cortex
      - inferotemporal
  class:
    min_epoch: 0
    max_epoch: 100
    losses:
      - _target_: retinal_rl.classification.loss.ClassificationLoss
        weight: 1
      - _target_: retinal_rl.classification.loss.PercentCorrect
        weight: 0
    target_circuits: # Visual cortex and downstream layers are driven by the task
      - prefrontal
      - classifier

# The objective function
objective:
  _target_: retinal_rl.models.objective.Objective
  losses:
    - _target_: retinal_rl.classification.loss.PercentCorrect
    - _target_: retinal_rl.classification.loss.ClassificationLoss
      target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
        - retina
        - thalamus
        - visual_cortex
        - inferotemporal
        - prefrontal
        - classifier
      weights:
        - ${eval:'1-${recon_weight_retina}'}
        - ${eval:'1-${recon_weight_thalamus}'}
        - ${eval:'1-${recon_weight_cortex}'}
        - 1
        - 1
        - 1
    - _target_: retinal_rl.models.loss.ReconstructionLoss
      target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction
        - retina
        - thalamus
        - visual_cortex
        - decoder
        - inferotemporal_decoder
      weights:
        - ${recon_weight_retina}
        - ${recon_weight_thalamus}
        - ${recon_weight_cortex}
        - 1
        - 1
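Compared with the old per-stage `goal` blocks, the flattened layout attaches a `target_circuits` list and a positionally aligned `weights` list to each loss. A minimal sketch of that pairing, using placeholder values for the `${recon_weight_*}` interpolations; how `retinal_rl.models.objective.Objective` actually consumes these pairs is defined in the repo and may differ:

```python
# Illustrative only: weights[i] applies to target_circuits[i] for this loss.
recon_weight_retina = 0.99   # placeholder for ${recon_weight_retina}
recon_weight_thalamus = 0.9  # placeholder for ${recon_weight_thalamus}
recon_weight_cortex = 0.5    # placeholder for ${recon_weight_cortex}

reconstruction_targets = [
    "retina", "thalamus", "visual_cortex", "decoder", "inferotemporal_decoder"
]
reconstruction_weights = [
    recon_weight_retina, recon_weight_thalamus, recon_weight_cortex, 1, 1
]

per_circuit = dict(zip(reconstruction_targets, reconstruction_weights))
assert per_circuit["retina"] == 0.99  # early circuits are dominated by reconstruction
```

The classification loss mirrors this with `1 - recon_weight_*` on the same early circuits, so for retina, thalamus, and visual cortex the two losses' weights still sum to one, matching the old `recon`/`mixed`/`cortex` goals.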
163 changes: 91 additions & 72 deletions retinal_rl/analysis/plot.py
@@ -13,12 +13,13 @@
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from matplotlib.patches import Circle, Wedge
from matplotlib.ticker import MaxNLocator
from torch import Tensor
from torchvision.utils import make_grid

from retinal_rl.models.brain import Brain
from retinal_rl.models.goal import ContextT, Goal
from retinal_rl.models.objective import ContextT, Objective
from retinal_rl.util import FloatArray


@@ -107,15 +108,7 @@ def plot_transforms(
return fig


def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
"""Visualize the Brain's connectome organized by depth and highlight optimizer targets using border colors.

Args:
----
- brain: The Brain instance
- brain_optimizer: The BrainOptimizer instance

"""
def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> Figure:
Contributor:
I think the one-line comment you added to the other functions would be nice here as well, since the function name doesn't fully explain what to expect.

Contributor Author:
This is my strategy. I'm actually not deleting as many comments as I thought I would. I've removed the warning, but wherever a one-line comment would be helpful I leave it or add it.

graph = brain.connectome

# Compute the depth of each node
@@ -138,44 +131,102 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
pos[node] = ((i - width / 2) / (width + 1), -(max_depth - depth) / max_depth)

# Set up the plot
fig = plt.figure(figsize=(12, 10))
fig, ax = plt.subplots(figsize=(12, 10))

# Draw edges
nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True)
nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True, ax=ax)

# Color scheme for different node types
color_map = {"sensor": "lightblue", "circuit": "lightgreen"}

# Generate colors for optimizers
optimizer_colors = sns.color_palette("husl", len(goal.losses))
# Generate colors for losses
loss_colors = sns.color_palette("husl", len(objective.losses))

# Prepare node colors and edge colors
node_colors: List[str] = []
edge_colors: List[Tuple[float, float, float]] = []
# Draw nodes
for node in graph.nodes():
x, y = pos[node]

# Determine node type and base color
if node in brain.sensors:
node_colors.append(color_map["sensor"])
base_color = color_map["sensor"]
else:
node_colors.append(color_map["circuit"])

# Determine if the node is targeted by an optimizer
edge_color = "none"
for i, optimizer_name in enumerate(goal.losses.keys()):
if node in goal.target_circuits[optimizer_name]:
edge_color = optimizer_colors[i]
break
edge_colors.append(edge_color)

# Draw nodes with a single call
nx.draw_networkx_nodes(
graph,
pos,
node_color=node_colors,
edgecolors=edge_colors,
node_size=4000,
linewidths=5,
base_color = color_map["circuit"]

# Draw base circle
circle = Circle((x, y), 0.05, facecolor=base_color, edgecolor="black")
ax.add_patch(circle)

# Determine which losses target this node
targeting_losses = [
loss for loss in objective.losses if node in loss.target_circuits
]

if targeting_losses:
# Calculate angle for each loss
angle_per_loss = 360 / len(targeting_losses)

# Draw a wedge for each targeting loss
for i, loss in enumerate(targeting_losses):
start_angle = i * angle_per_loss
wedge = Wedge(
(x, y),
0.07,
start_angle,
start_angle + angle_per_loss,
width=0.02,
facecolor=loss_colors[objective.losses.index(loss)],
)
ax.add_patch(wedge)

# Draw labels
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold", ax=ax)

# Add a legend for losses
legend_elements = [
Line2D(
[0],
[0],
marker="o",
color="w",
label=f"Loss: {loss.__class__.__name__}",
markerfacecolor=color,
markersize=15,
)
for loss, color in zip(objective.losses, loss_colors)
]

# Add legend elements for sensor and circuit
legend_elements.extend(
[
Line2D(
[0],
[0],
marker="o",
color="w",
label="Sensor",
markerfacecolor=color_map["sensor"],
markersize=15,
),
Line2D(
[0],
[0],
marker="o",
color="w",
label="Circuit",
markerfacecolor=color_map["circuit"],
markersize=15,
),
]
)

plt.legend(handles=legend_elements, loc="center left", bbox_to_anchor=(1, 0.5))

plt.title("Brain Connectome and Loss Targets")
plt.tight_layout()
plt.axis("equal")
plt.axis("off")

return fig
# Draw labels
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold")

@@ -192,7 +243,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:
markersize=15,
markeredgewidth=3,
)
for name, color in zip(goal.losses.keys(), optimizer_colors)
for name, color in zip(objective.losses.keys(), optimizer_colors)
]

# Add legend elements for sensor and circuit
@@ -229,13 +280,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure:


def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Figure:
"""Plot the receptive field sizes for each layer of the convolutional part of the network.

Args:
----
- results: Dictionary containing the results from cnn_statistics function

"""
"""Plot the receptive field sizes for each layer of the convolutional part of the network."""
# Get visual field size from the input shape
input_shape = results["input"]["shape"]
[_, height, width] = list(input_shape)
@@ -300,17 +345,7 @@ def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Figure:


def plot_histories(histories: Dict[str, List[float]]) -> Figure:
"""Plot training and test losses over epochs.

Args:
----
histories (Dict[str, List[float]]): Dictionary containing training and test loss histories.

Returns:
-------
Figure: Matplotlib figure containing the plotted histories.

"""
"""Plot training and test losses over epochs."""
train_metrics = [
key.split("_", 1)[1] for key in histories.keys() if key.startswith("train_")
]
@@ -467,23 +502,7 @@ def plot_reconstructions(
test_estimates: List[Tuple[Tensor, int]],
num_samples: int,
) -> Figure:
"""Plot original and reconstructed images for both training and test sets, including the classes.

Args:
----
train_sources (List[Tuple[Tensor, int]]): List of original source images and their classes.
train_inputs (List[Tuple[Tensor, int]]): List of original training images and their classes.
train_estimates (List[Tuple[Tensor, int]]): List of reconstructed training images and their predicted classes.
test_sources (List[Tuple[Tensor, int]]): List of original source images and their classes.
test_inputs (List[Tuple[Tensor, int]]): List of original test images and their classes.
test_estimates (List[Tuple[Tensor, int]]): List of reconstructed test images and their predicted classes.
num_samples (int): The number of samples to plot.

Returns:
-------
Figure: The matplotlib Figure object with the plotted images.

"""
"""Plot original and reconstructed images for both training and test sets, including the classes."""
fig, axes = plt.subplots(6, num_samples, figsize=(15, 10))

for i in range(num_samples):