Paths and file management #61

Merged
merged 10 commits on Nov 8, 2024
121 changes: 70 additions & 51 deletions retinal_rl/analysis/plot.py
@@ -6,39 +6,53 @@
import networkx as nx
import numpy as np
import seaborn as sns
import torch
from matplotlib import gridspec, patches
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from matplotlib.patches import Circle, Wedge
from matplotlib.ticker import MaxNLocator
from numpy import fft
from torch import Tensor
from torchvision.utils import make_grid

from retinal_rl.models.brain import Brain
from retinal_rl.models.objective import ContextT, Objective
from retinal_rl.util import FloatArray


def make_image_grid(arrays: List[FloatArray], nrow: int) -> FloatArray:
"""Create a grid of images from a list of numpy arrays."""
# Assuming arrays are [C, H, W]
n = len(arrays)
if not n:
return np.array([])

ncol = nrow
nrow = (n - 1) // ncol + 1

nchns, hght, wdth = arrays[0].shape
grid = np.zeros((nchns, hght * nrow, wdth * ncol))

for idx, array in enumerate(arrays):
i = idx // ncol
j = idx % ncol
grid[:, i * hght : (i + 1) * hght, j * wdth : (j + 1) * wdth] = array

return grid
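
This helper stands in for `torchvision.utils.make_grid`, and from the code above the `nrow` argument keeps torchvision's meaning of images per row (it is used as the number of columns). A minimal usage sketch, assuming `FloatArray` is a plain float numpy array with layout [C, H, W]:

```python
# Minimal sketch: build a grid from six dummy [C, H, W] arrays (assumed FloatArray alias).
import numpy as np

arrays = [np.random.rand(3, 32, 32) for _ in range(6)]
grid = make_image_grid(arrays, nrow=3)     # nrow = images per row, as in torchvision's make_grid
assert grid.shape == (3, 2 * 32, 3 * 32)   # 6 images laid out in 2 rows of 3 columns
```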


def plot_transforms(
source_transforms: Dict[str, Dict[float, List[torch.Tensor]]],
noise_transforms: Dict[str, Dict[float, List[torch.Tensor]]],
source_transforms: Dict[str, Dict[float, List[FloatArray]]],
noise_transforms: Dict[str, Dict[float, List[FloatArray]]],
) -> Figure:
"""Use the result of the transform_base_images function to plot the effects of source and noise transforms on images.
"""Plot effects of source and noise transforms on images.

Args:
----
source_transforms: A dictionary of source transforms and their effects on images.
noise_transforms: A dictionary of noise transforms and their effects on images.
source_transforms: Dictionary of source transforms (numpy arrays)
noise_transforms: Dictionary of noise transforms (numpy arrays)

Returns:
-------
Figure: A matplotlib Figure containing the plotted transforms.

Figure containing the plotted transforms
"""
# Determine the number of transforms and images
num_source_transforms = len(source_transforms)
num_noise_transforms = len(noise_transforms)
num_transforms = num_source_transforms + num_noise_transforms
@@ -48,7 +62,6 @@ def plot_transforms(
]
)

# Create a figure with subplots for each transform
fig, axs = plt.subplots(num_transforms, 1, figsize=(20, 5 * num_transforms))
if num_transforms == 1:
axs = [axs]
@@ -62,19 +75,20 @@

# Create a grid of images for each step
images = [
make_grid(
torch.stack([img * 0.5 + 0.5 for img in transform_data[step]]),
make_image_grid(
[(img * 0.5 + 0.5) for img in transform_data[step]],
nrow=num_images,
)
for step in steps
]
grid = make_grid(images, nrow=len(steps))
grid = make_image_grid(images, nrow=len(steps))

# Display the grid
ax.imshow(grid.permute(1, 2, 0))
# Move channels last for imshow
grid_display = np.transpose(grid, (1, 2, 0))
ax.imshow(grid_display)
ax.set_title(f"Source Transform: {transform_name}")
ax.set_xticks(
[(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))]
[(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))]
)
ax.set_xticklabels([f"{step:.2f}" for step in steps])
ax.set_yticks([])
@@ -88,19 +102,20 @@

# Create a grid of images for each step
images = [
make_grid(
torch.stack([img * 0.5 + 0.5 for img in transform_data[step]]),
make_image_grid(
[(img * 0.5 + 0.5) for img in transform_data[step]],
nrow=num_images,
)
for step in steps
]
grid = make_grid(images, nrow=len(steps))
grid = make_image_grid(images, nrow=len(steps))

# Display the grid
ax.imshow(grid.permute(1, 2, 0))
# Move channels last for imshow
grid_display = np.transpose(grid, (1, 2, 0))
ax.imshow(grid_display)
ax.set_title(f"Noise Transform: {transform_name}")
ax.set_xticks(
[(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))]
[(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))]
)
ax.set_xticklabels([f"{step:.2f}" for step in steps])
ax.set_yticks([])
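
For display, the numpy path mirrors the old tensor path: `np.transpose(grid, (1, 2, 0))` on a [C, H, W] array yields the same [H, W, C] layout that `grid.permute(1, 2, 0)` produced for a tensor, which is also why the tick positions now read the width from `grid_display.shape[1]` instead of `grid.shape[2]`. A small sketch of that assumed equivalence, with dummy data:

```python
# Sketch: channels-last conversion, numpy transpose vs. torch permute (dummy [C, H, W] array).
import numpy as np
import torch

chw = np.random.rand(3, 4, 5).astype(np.float32)
hwc_np = np.transpose(chw, (1, 2, 0))
hwc_torch = torch.from_numpy(chw).permute(1, 2, 0).numpy()
assert np.array_equal(hwc_np, hwc_torch) and hwc_np.shape == (4, 5, 3)
```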
@@ -237,16 +252,17 @@ def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> F
return fig


def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Figure:
def plot_receptive_field_sizes(
input_shape: Tuple[int, ...], layers: Dict[str, Dict[str, FloatArray]]
) -> Figure:
"""Plot the receptive field sizes for each layer of the convolutional part of the network."""
# Get visual field size from the input shape
input_shape = results["input"]["shape"]
[_, height, width] = list(input_shape)

# Calculate receptive field sizes for each layer
rf_sizes: List[Tuple[int, int]] = []
layer_names: List[str] = []
for name, layer_data in results.items():
for name, layer_data in layers.items():
if name == "input":
continue
rf = layer_data["receptive_fields"]
@@ -357,22 +373,26 @@ def plot_histories(histories: Dict[str, List[float]]) -> Figure:


def plot_channel_statistics(
layer_data: Dict[str, FloatArray], layer_name: str, channel: int
receptive_fields: FloatArray,
spectral: Dict[str, FloatArray],
histogram: Dict[str, FloatArray],
layer_name: str,
channel: int,
) -> Figure:
"""Plot receptive fields, pixel histograms, and autocorrelation plots for a single channel in a layer."""
fig, axs = plt.subplots(2, 3, figsize=(20, 10))
fig.suptitle(f"Layer: {layer_name}, Channel: {channel}", fontsize=16)

# Receptive Fields
rf = layer_data["receptive_fields"][channel]
rf = receptive_fields[channel]
_plot_receptive_fields(axs[0, 0], rf)
axs[0, 0].set_title("Receptive Field")
axs[0, 0].set_xlabel("X")
axs[0, 0].set_ylabel("Y")

# Pixel Histograms
hist = layer_data["pixel_histograms"][channel]
bin_edges = layer_data["histogram_bin_edges"]
hist = histogram["channel_histograms"][channel]
bin_edges = histogram["bin_edges"]
axs[1, 0].bar(
bin_edges[:-1],
hist,
@@ -387,7 +407,7 @@ def plot_channel_statistics(

# Autocorrelation plots
# Plot average 2D autocorrelation and variance
autocorr = fft.fftshift(layer_data["mean_autocorr"][channel])
autocorr = fft.fftshift(spectral["mean_autocorr"][channel])
h, w = autocorr.shape
extent = [-w // 2, w // 2, -h // 2, h // 2]
im = axs[0, 1].imshow(
@@ -399,7 +419,7 @@ def plot_channel_statistics(
fig.colorbar(im, ax=axs[0, 1])
_set_integer_ticks(axs[0, 1])

autocorr_sd = fft.fftshift(np.sqrt(layer_data["var_autocorr"][channel]))
autocorr_sd = fft.fftshift(np.sqrt(spectral["var_autocorr"][channel]))
im = axs[0, 2].imshow(
autocorr_sd, cmap="inferno", origin="lower", extent=extent, vmin=0
)
@@ -411,7 +431,7 @@

# Plot average 2D power spectrum
log_power_spectrum = fft.fftshift(
np.log1p(layer_data["mean_power_spectrum"][channel])
np.log1p(spectral["mean_power_spectrum"][channel])
)
h, w = log_power_spectrum.shape

@@ -425,7 +445,7 @@
_set_integer_ticks(axs[1, 1])

log_power_spectrum_sd = fft.fftshift(
np.log1p(np.sqrt(layer_data["var_power_spectrum"][channel]))
np.log1p(np.sqrt(spectral["var_power_spectrum"][channel]))
)
im = axs[1, 2].imshow(
log_power_spectrum_sd,
@@ -450,16 +470,15 @@ def _set_integer_ticks(ax: Axes):
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
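
The autocorrelation and spectrum panels in `plot_channel_statistics` all follow the same convention: `fft.fftshift` moves the zero-lag (or zero-frequency) component to the center of the array, and the `extent` list maps pixel indices onto centered lag/frequency coordinates for `imshow`. A hedged sketch of that pattern in isolation, with a random stand-in array rather than the project's analysis output:

```python
# Sketch of the fftshift + extent convention used in plot_channel_statistics (dummy data).
import matplotlib.pyplot as plt
import numpy as np
from numpy import fft

autocorr = np.random.rand(9, 9)              # stand-in for spectral["mean_autocorr"][channel]
shifted = fft.fftshift(autocorr)             # zero-lag component moved to the center pixel
h, w = shifted.shape
extent = [-w // 2, w // 2, -h // 2, h // 2]  # pixel indices mapped to centered lag coordinates

fig, ax = plt.subplots()
im = ax.imshow(shifted, cmap="inferno", origin="lower", extent=extent)
fig.colorbar(im, ax=ax)
```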


# Function to plot the original and reconstructed images
def plot_reconstructions(
normalization_mean: List[float],
normalization_std: List[float],
train_sources: List[Tuple[Tensor, int]],
train_inputs: List[Tuple[Tensor, int]],
train_estimates: List[Tuple[Tensor, int]],
test_sources: List[Tuple[Tensor, int]],
test_inputs: List[Tuple[Tensor, int]],
test_estimates: List[Tuple[Tensor, int]],
train_sources: List[Tuple[FloatArray, int]],
train_inputs: List[Tuple[FloatArray, int]],
train_estimates: List[Tuple[FloatArray, int]],
test_sources: List[Tuple[FloatArray, int]],
test_inputs: List[Tuple[FloatArray, int]],
test_estimates: List[Tuple[FloatArray, int]],
num_samples: int,
) -> Figure:
"""Plot original and reconstructed images for both training and test sets, including the classes."""
@@ -474,27 +493,28 @@ def plot_reconstructions(
test_recon, test_pred = test_estimates[i]

# Unnormalize the original images using the normalization lists
# Arrays are already [C, H, W], need to move channels to last dimension
train_source = (
train_source.permute(1, 2, 0).numpy() * normalization_std
np.transpose(train_source, (1, 2, 0)) * normalization_std
+ normalization_mean
)
train_input = (
train_input.permute(1, 2, 0).numpy() * normalization_std
np.transpose(train_input, (1, 2, 0)) * normalization_std
+ normalization_mean
)
train_recon = (
train_recon.permute(1, 2, 0).numpy() * normalization_std
np.transpose(train_recon, (1, 2, 0)) * normalization_std
+ normalization_mean
)
test_source = (
test_source.permute(1, 2, 0).numpy() * normalization_std
np.transpose(test_source, (1, 2, 0)) * normalization_std
+ normalization_mean
)
test_input = (
test_input.permute(1, 2, 0).numpy() * normalization_std + normalization_mean
np.transpose(test_input, (1, 2, 0)) * normalization_std + normalization_mean
)
test_recon = (
test_recon.permute(1, 2, 0).numpy() * normalization_std + normalization_mean
np.transpose(test_recon, (1, 2, 0)) * normalization_std + normalization_mean
)

axes[0, i].imshow(np.clip(train_source, 0, 1))
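
The unnormalization step relies on numpy broadcasting: after `np.transpose(img, (1, 2, 0))` the channel axis is last, so the per-channel `normalization_std` and `normalization_mean` lists broadcast across it. A minimal sketch with placeholder statistics, not the project's actual normalization constants:

```python
# Sketch of channels-last unnormalization via broadcasting (placeholder per-channel stats).
import numpy as np

normalization_mean = [0.5, 0.5, 0.5]
normalization_std = [0.25, 0.25, 0.25]

img_chw = np.random.rand(3, 32, 32)                           # normalized image, [C, H, W]
img_hwc = np.transpose(img_chw, (1, 2, 0))                    # channels last: [H, W, C]
restored = img_hwc * normalization_std + normalization_mean   # lists broadcast over the channel axis
restored = np.clip(restored, 0, 1)                            # clamp for imshow, as above
```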
@@ -522,7 +542,6 @@ def plot_reconstructions(
axes[5, i].set_title(f"Pred: {test_pred}")

# Set y-axis labels for each row

fig.text(
0.02,
0.90,