-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Flatten objectives #16
Changes from all commits
5721a8b
8e17514
bd82381
b80c0eb
f32f741
2c44b6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,41 @@ | ||
# Number of training epochs | ||
num_epochs: 100 | ||
|
||
# The optimizer to use | ||
optimizer: # torch.optim Class and parameters | ||
_target_: torch.optim.Adam | ||
lr: 0.0003 | ||
|
||
goal: | ||
recon: | ||
min_epoch: 0 # Epoch to start optimizer | ||
max_epoch: 100 # Epoch to stop optimizer | ||
losses: # Weighted optimizer losses as defined in retinal-rl | ||
- _target_: retinal_rl.models.loss.ReconstructionLoss | ||
weight: ${recon_weight_retina} | ||
- _target_: retinal_rl.classification.loss.ClassificationLoss | ||
weight: ${eval:'1-${recon_weight_retina}'} | ||
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction | ||
- retina | ||
decode: | ||
min_epoch: 0 # Epoch to start optimizer | ||
max_epoch: 100 # Epoch to stop optimizer | ||
losses: # Weighted optimizer losses as defined in retinal-rl | ||
- _target_: retinal_rl.models.loss.ReconstructionLoss | ||
weight: 1 | ||
target_circuits: # Circuit parameters to optimize with this optimizer. We train the retina and the decoder exclusively to maximize reconstruction | ||
- decoder | ||
- inferotemporal_decoder | ||
mixed: | ||
min_epoch: 0 | ||
max_epoch: 100 | ||
losses: | ||
- _target_: retinal_rl.models.loss.ReconstructionLoss | ||
weight: ${recon_weight_thalamus} | ||
- _target_: retinal_rl.classification.loss.ClassificationLoss | ||
weight: ${eval:'1-${recon_weight_thalamus}'} | ||
target_circuits: # The thalamus is somewhat sensitive to task losses | ||
- thalamus | ||
cortex: | ||
min_epoch: 0 | ||
max_epoch: 100 | ||
losses: | ||
- _target_: retinal_rl.models.loss.ReconstructionLoss | ||
weight: ${recon_weight_cortex} | ||
- _target_: retinal_rl.classification.loss.ClassificationLoss | ||
weight: ${eval:'1-${recon_weight_cortex}'} | ||
target_circuits: # Visual cortex and downstream layers are driven by the task | ||
- visual_cortex | ||
- inferotemporal | ||
class: | ||
min_epoch: 0 | ||
max_epoch: 100 | ||
losses: | ||
- _target_: retinal_rl.classification.loss.ClassificationLoss | ||
weight: 1 | ||
- _target_: retinal_rl.classification.loss.PercentCorrect | ||
weight: 0 | ||
target_circuits: # Visual cortex and downstream layers are driven by the task | ||
- prefrontal | ||
- classifier | ||
# The objective function | ||
objective: | ||
_target_: retinal_rl.models.objective.Objective | ||
losses: | ||
- _target_: retinal_rl.classification.loss.PercentCorrect | ||
- _target_: retinal_rl.classification.loss.ClassificationLoss | ||
target_circuits: # Circuits whose parameters this classification loss trains (full task pathway) | ||
- retina | ||
- thalamus | ||
- visual_cortex | ||
- inferotemporal | ||
- prefrontal | ||
- classifier | ||
weights: | ||
- ${eval:'1-${recon_weight_retina}'} | ||
- ${eval:'1-${recon_weight_thalamus}'} | ||
- ${eval:'1-${recon_weight_cortex}'} | ||
- 1 | ||
- 1 | ||
- 1 | ||
- _target_: retinal_rl.models.loss.ReconstructionLoss | ||
target_circuits: # Circuits whose parameters this reconstruction loss trains (encoder stages plus decoders) | ||
- retina | ||
- thalamus | ||
- visual_cortex | ||
- decoder | ||
- inferotemporal_decoder | ||
weights: | ||
- ${recon_weight_retina} | ||
- ${recon_weight_thalamus} | ||
- ${recon_weight_cortex} | ||
- 1 | ||
- 1 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,12 +13,13 @@ | |
from matplotlib.axes import Axes | ||
from matplotlib.figure import Figure | ||
from matplotlib.lines import Line2D | ||
from matplotlib.patches import Circle, Wedge | ||
from matplotlib.ticker import MaxNLocator | ||
from torch import Tensor | ||
from torchvision.utils import make_grid | ||
|
||
from retinal_rl.models.brain import Brain | ||
from retinal_rl.models.goal import ContextT, Goal | ||
from retinal_rl.models.objective import ContextT, Objective | ||
from retinal_rl.util import FloatArray | ||
|
||
|
||
|
@@ -107,15 +108,7 @@ def plot_transforms( | |
return fig | ||
|
||
|
||
def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: | ||
"""Visualize the Brain's connectome organized by depth and highlight optimizer targets using border colors. | ||
|
||
Args: | ||
---- | ||
- brain: The Brain instance | ||
- brain_optimizer: The BrainOptimizer instance | ||
|
||
""" | ||
def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> Figure: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the one line comment you added to the other functions would be nice here as well as the function name is not fully explaining what to expect There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is my strategy. I'm actually not deleting as many comments as I thought I would. I've removed the warning, but wherever a one liner of comment would be helpful I leave it/add it. |
||
graph = brain.connectome | ||
|
||
# Compute the depth of each node | ||
|
@@ -138,44 +131,102 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: | |
pos[node] = ((i - width / 2) / (width + 1), -(max_depth - depth) / max_depth) | ||
|
||
# Set up the plot | ||
fig = plt.figure(figsize=(12, 10)) | ||
fig, ax = plt.subplots(figsize=(12, 10)) | ||
|
||
# Draw edges | ||
nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True) | ||
nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True, ax=ax) | ||
|
||
# Color scheme for different node types | ||
color_map = {"sensor": "lightblue", "circuit": "lightgreen"} | ||
|
||
# Generate colors for optimizers | ||
optimizer_colors = sns.color_palette("husl", len(goal.losses)) | ||
# Generate colors for losses | ||
loss_colors = sns.color_palette("husl", len(objective.losses)) | ||
|
||
# Prepare node colors and edge colors | ||
node_colors: List[str] = [] | ||
edge_colors: List[Tuple[float, float, float]] = [] | ||
# Draw nodes | ||
for node in graph.nodes(): | ||
x, y = pos[node] | ||
|
||
# Determine node type and base color | ||
if node in brain.sensors: | ||
node_colors.append(color_map["sensor"]) | ||
base_color = color_map["sensor"] | ||
else: | ||
node_colors.append(color_map["circuit"]) | ||
|
||
# Determine if the node is targeted by an optimizer | ||
edge_color = "none" | ||
for i, optimizer_name in enumerate(goal.losses.keys()): | ||
if node in goal.target_circuits[optimizer_name]: | ||
edge_color = optimizer_colors[i] | ||
break | ||
edge_colors.append(edge_color) | ||
|
||
# Draw nodes with a single call | ||
nx.draw_networkx_nodes( | ||
graph, | ||
pos, | ||
node_color=node_colors, | ||
edgecolors=edge_colors, | ||
node_size=4000, | ||
linewidths=5, | ||
base_color = color_map["circuit"] | ||
|
||
# Draw base circle | ||
circle = Circle((x, y), 0.05, facecolor=base_color, edgecolor="black") | ||
ax.add_patch(circle) | ||
|
||
# Determine which losses target this node | ||
targeting_losses = [ | ||
loss for loss in objective.losses if node in loss.target_circuits | ||
] | ||
|
||
if targeting_losses: | ||
# Calculate angle for each loss | ||
angle_per_loss = 360 / len(targeting_losses) | ||
|
||
# Draw a wedge for each targeting loss | ||
for i, loss in enumerate(targeting_losses): | ||
start_angle = i * angle_per_loss | ||
wedge = Wedge( | ||
(x, y), | ||
0.07, | ||
start_angle, | ||
start_angle + angle_per_loss, | ||
width=0.02, | ||
facecolor=loss_colors[objective.losses.index(loss)], | ||
) | ||
ax.add_patch(wedge) | ||
|
||
# Draw labels | ||
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold", ax=ax) | ||
|
||
# Add a legend for losses | ||
legend_elements = [ | ||
Line2D( | ||
[0], | ||
[0], | ||
marker="o", | ||
color="w", | ||
label=f"Loss: {loss.__class__.__name__}", | ||
markerfacecolor=color, | ||
markersize=15, | ||
) | ||
for loss, color in zip(objective.losses, loss_colors) | ||
] | ||
|
||
# Add legend elements for sensor and circuit | ||
legend_elements.extend( | ||
[ | ||
Line2D( | ||
[0], | ||
[0], | ||
marker="o", | ||
color="w", | ||
label="Sensor", | ||
markerfacecolor=color_map["sensor"], | ||
markersize=15, | ||
), | ||
Line2D( | ||
[0], | ||
[0], | ||
marker="o", | ||
color="w", | ||
label="Circuit", | ||
markerfacecolor=color_map["circuit"], | ||
markersize=15, | ||
), | ||
] | ||
) | ||
|
||
plt.legend(handles=legend_elements, loc="center left", bbox_to_anchor=(1, 0.5)) | ||
|
||
plt.title("Brain Connectome and Loss Targets") | ||
plt.tight_layout() | ||
plt.axis("equal") | ||
plt.axis("off") | ||
|
||
return fig | ||
# Draw labels | ||
nx.draw_networkx_labels(graph, pos, font_size=8, font_weight="bold") | ||
|
||
|
@@ -192,7 +243,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: | |
markersize=15, | ||
markeredgewidth=3, | ||
) | ||
for name, color in zip(goal.losses.keys(), optimizer_colors) | ||
for name, color in zip(objective.losses.keys(), optimizer_colors) | ||
] | ||
|
||
# Add legend elements for sensor and circuit | ||
|
@@ -229,13 +280,7 @@ def plot_brain_and_optimizers(brain: Brain, goal: Goal[ContextT]) -> Figure: | |
|
||
|
||
def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Figure: | ||
"""Plot the receptive field sizes for each layer of the convolutional part of the network. | ||
|
||
Args: | ||
---- | ||
- results: Dictionary containing the results from cnn_statistics function | ||
|
||
""" | ||
"""Plot the receptive field sizes for each layer of the convolutional part of the network.""" | ||
# Get visual field size from the input shape | ||
input_shape = results["input"]["shape"] | ||
[_, height, width] = list(input_shape) | ||
|
@@ -300,17 +345,7 @@ def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Fig | |
|
||
|
||
def plot_histories(histories: Dict[str, List[float]]) -> Figure: | ||
"""Plot training and test losses over epochs. | ||
|
||
Args: | ||
---- | ||
histories (Dict[str, List[float]]): Dictionary containing training and test loss histories. | ||
|
||
Returns: | ||
------- | ||
Figure: Matplotlib figure containing the plotted histories. | ||
|
||
""" | ||
"""Plot training and test losses over epochs.""" | ||
train_metrics = [ | ||
key.split("_", 1)[1] for key in histories.keys() if key.startswith("train_") | ||
] | ||
|
@@ -467,23 +502,7 @@ def plot_reconstructions( | |
test_estimates: List[Tuple[Tensor, int]], | ||
num_samples: int, | ||
) -> Figure: | ||
"""Plot original and reconstructed images for both training and test sets, including the classes. | ||
|
||
Args: | ||
---- | ||
train_sources (List[Tuple[Tensor, int]]): List of original source images and their classes. | ||
train_inputs (List[Tuple[Tensor, int]]): List of original training images and their classes. | ||
train_estimates (List[Tuple[Tensor, int]]): List of reconstructed training images and their predicted classes. | ||
test_sources (List[Tuple[Tensor, int]]): List of original source images and their classes. | ||
test_inputs (List[Tuple[Tensor, int]]): List of original test images and their classes. | ||
test_estimates (List[Tuple[Tensor, int]]): List of reconstructed test images and their predicted classes. | ||
num_samples (int): The number of samples to plot. | ||
|
||
Returns: | ||
------- | ||
Figure: The matplotlib Figure object with the plotted images. | ||
|
||
""" | ||
"""Plot original and reconstructed images for both training and test sets, including the classes.""" | ||
fig, axes = plt.subplots(6, num_samples, figsize=(15, 10)) | ||
|
||
for i in range(num_samples): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should decide on ruff vs pylint now, else we also have to maintain this rule checking for both equally etc :D
I don't have a strong opinion on it, but if ruff integrates nicer in your IDE we can use that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have a strong opinion either, but I think the main thing is making sure we only have to track one set of such things in ruff. Maybe I'd say if you'd like me to take care of it, then I'll manage the ruff/toml stuff. If you care about extra pylint features though, then it's up to you to maintain it in toml, so that I can then get the desired behaviour.