From 5d87dea1e3fc69216c3dce9bc5e030ce90d4ad60 Mon Sep 17 00:00:00 2001 From: alex404 Date: Thu, 31 Oct 2024 09:03:40 +0100 Subject: [PATCH 1/9] WIP. Converted return values of statistics functions to dataclasses. Now need to adapt plot and analyze. --- retinal_rl/analysis/statistics.py | 198 +++++++++++--------- runner/frameworks/classification/analyze.py | 188 +++++++++++++------ 2 files changed, 235 insertions(+), 151 deletions(-) diff --git a/retinal_rl/analysis/statistics.py b/retinal_rl/analysis/statistics.py index fae77531..d1f12d2a 100644 --- a/retinal_rl/analysis/statistics.py +++ b/retinal_rl/analysis/statistics.py @@ -1,7 +1,8 @@ """Functions for analysis and statistics on a Brain model.""" import logging -from typing import Dict, List, Tuple, cast +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, cast import numpy as np import torch @@ -21,9 +22,76 @@ logger = logging.getLogger(__name__) +### Dataclasses ### + + +@dataclass +class TransformStatistics: + """Results of applying transformations to images.""" + + source_transforms: Dict[str, Dict[float, List[Tensor]]] + noise_transforms: Dict[str, Dict[float, List[Tensor]]] + + +@dataclass +class Reconstructions: + """Set of source images, inputs, and their reconstructions.""" + + sources: List[Tuple[Tensor, int]] + inputs: List[Tuple[Tensor, int]] + estimates: List[Tuple[Tensor, int]] + + +@dataclass +class ReconstructionStatistics: + """Results of image reconstruction for both training and test sets.""" + + train: Reconstructions + test: Reconstructions + + +@dataclass +class SpectralAnalysis: + """Results of spectral analysis for a layer.""" + + mean_power_spectrum: FloatArray + var_power_spectrum: FloatArray + mean_autocorr: FloatArray + var_autocorr: FloatArray + + +@dataclass +class HistogramAnalysis: + """Results of histogram analysis for a layer.""" + + channel_histograms: FloatArray + bin_edges: FloatArray + + +@dataclass +class LayerStatistics: + """Statistics for 
a single layer.""" + + receptive_fields: FloatArray + num_channels: int + spectral: Optional[SpectralAnalysis] = None + histogram: Optional[HistogramAnalysis] = None + + +@dataclass +class CNNStatistics: + """Complete statistics for a CNN model.""" + + input_shape: Tuple[int, ...] # nclrs, hght, wdth + layers: Dict[str, LayerStatistics] + + +### Functions ### + + def transform_base_images( imageset: Imageset, num_steps: int, num_images: int -) -> Dict[str, Dict[str, Dict[float, List[Tensor]]]]: +) -> TransformStatistics: """Apply transformations to a set of images from an Imageset.""" images: List[Image.Image] = [] @@ -59,7 +127,7 @@ def transform_base_images( imageset.to_tensor(transform.transform(img, step)) ) - return results + return TransformStatistics(**results) def reconstruct_images( @@ -69,15 +137,13 @@ def reconstruct_images( test_set: Imageset, train_set: Imageset, sample_size: int, -) -> Dict[str, List[Tuple[Tensor, int]]]: +) -> ReconstructionStatistics: """Compute reconstructions of a set of training and test images using a Brain model.""" brain.eval() # Set the model to evaluation mode def collect_reconstructions( imageset: Imageset, sample_size: int - ) -> Tuple[ - List[Tuple[Tensor, int]], List[Tuple[Tensor, int]], List[Tuple[Tensor, int]] - ]: + ) -> Reconstructions: """Collect reconstructions for a subset of a dataset.""" source_subset: List[Tuple[Tensor, int]] = [] input_subset: List[Tuple[Tensor, int]] = [] @@ -97,24 +163,13 @@ def collect_reconstructions( input_subset.append((img.cpu(), k)) estimates.append((rec_img.cpu(), pred_k)) - return source_subset, input_subset, estimates + return Reconstructions(source_subset, input_subset, estimates) - train_source, train_input, train_estimates = collect_reconstructions( - train_set, sample_size - ) - test_source, test_input, test_estimates = collect_reconstructions( - test_set, sample_size + return ReconstructionStatistics( + collect_reconstructions(train_set, sample_size), + 
collect_reconstructions(test_set, sample_size), ) - return { - "train_sources": train_source, - "train_inputs": train_input, - "train_estimates": train_estimates, - "test_sources": test_source, - "test_inputs": test_input, - "test_estimates": test_estimates, - } - def cnn_statistics( device: torch.device, @@ -122,20 +177,8 @@ def cnn_statistics( brain: Brain, channel_analysis: bool, max_sample_size: int = 0, -) -> Dict[str, Dict[str, FloatArray]]: - """Compute statistics for a convolutional encoder model. - - Args: - device: The device to run computations on. - imageset: The dataset to analyze. - brain: The trained Brain model. - channel_analysis: Whether to compute channel-wise statistics (histograms, spectra). - max_sample_size: Maximum number of samples to use. If 0, use all samples. - - Returns: - A nested dictionary containing statistics for the input and each layer. - When channel_analysis is False, only receptive_fields and num_channels are computed. - """ +) -> CNNStatistics: + """Compute statistics for a convolutional encoder model.""" brain.eval() brain.to(device) input_shape, cnn_layers = get_cnn_circuit(brain) @@ -144,7 +187,7 @@ def cnn_statistics( dataloader = _prepare_dataset(imageset, max_sample_size) # Initialize results - results: Dict[str, Dict[str, FloatArray]] = { + results = { "input": _analyze_input(device, dataloader, input_shape, channel_analysis) } @@ -168,7 +211,7 @@ def cnn_statistics( device, dataloader, head_layers, input_shape, out_channels, channel_analysis ) - return results + return CNNStatistics(input_shape, results) def _prepare_dataset( @@ -225,34 +268,22 @@ def _analyze_layer( input_shape: Tuple[int, ...], out_channels: int, channel_analysis: bool = True, -) -> Dict[str, FloatArray]: +) -> LayerStatistics: """Analyze statistics for a single layer.""" head_model = nn.Sequential(*head_layers) - results: Dict[str, FloatArray] = {} # Always compute receptive fields - results["receptive_fields"] = _compute_receptive_fields( - 
device, head_layers, input_shape, out_channels - ) - results["num_channels"] = np.array(out_channels, dtype=np.float64) + rfs = _compute_receptive_fields(device, head_layers, input_shape, out_channels) + + layer_spectral = None + layer_histograms = None # Compute channel-wise statistics only if requested if channel_analysis: layer_spectral = _layer_spectral_analysis(device, dataloader, head_model) layer_histograms = _layer_pixel_histograms(device, dataloader, head_model) - results.update( - { - "pixel_histograms": layer_histograms["channel_histograms"], - "histogram_bin_edges": layer_histograms["bin_edges"], - "mean_power_spectrum": layer_spectral["mean_power_spectrum"], - "var_power_spectrum": layer_spectral["var_power_spectrum"], - "mean_autocorr": layer_spectral["mean_autocorr"], - "var_autocorr": layer_spectral["var_autocorr"], - } - ) - - return results + return LayerStatistics(rfs, out_channels, layer_spectral, layer_histograms) def _analyze_input( @@ -260,31 +291,22 @@ def _analyze_input( dataloader: DataLoader[Tuple[Tensor, Tensor, int]], input_shape: Tuple[int, ...], channel_analysis: bool, -) -> Dict[str, FloatArray]: +) -> LayerStatistics: """Analyze statistics for the input layer.""" - nclrs = input_shape[0] - results: Dict[str, FloatArray] = { - "receptive_fields": np.eye(nclrs)[:, :, np.newaxis, np.newaxis], - "shape": np.array(input_shape, dtype=np.float64), - "num_channels": np.array(nclrs, dtype=np.float64), - } + + input_spectral = None + input_histograms = None if channel_analysis: input_spectral = _layer_spectral_analysis(device, dataloader, nn.Identity()) input_histograms = _layer_pixel_histograms(device, dataloader, nn.Identity()) - results.update( - { - "pixel_histograms": input_histograms["channel_histograms"], - "histogram_bin_edges": input_histograms["bin_edges"], - "mean_power_spectrum": input_spectral["mean_power_spectrum"], - "var_power_spectrum": input_spectral["var_power_spectrum"], - "mean_autocorr": input_spectral["mean_autocorr"], 
- "var_autocorr": input_spectral["var_autocorr"], - } - ) - - return results + return LayerStatistics( + np.eye(input_shape[0])[:, :, np.newaxis, np.newaxis], + input_shape[0], + input_spectral, + input_histograms, + ) def _layer_pixel_histograms( @@ -292,7 +314,7 @@ def _layer_pixel_histograms( dataloader: DataLoader[Tuple[Tensor, Tensor, int]], model: nn.Module, num_bins: int = 20, -) -> Dict[str, FloatArray]: +) -> HistogramAnalysis: """Compute histograms of pixel/activation values for each channel across all data in an imageset.""" _, first_batch, _ = next(iter(dataloader)) with torch.no_grad(): @@ -335,19 +357,17 @@ def _layer_pixel_histograms( bin_width = (hist_range[1] - hist_range[0]) / num_bins normalized_histograms = histograms / (total_elements * bin_width / num_channels) - return { - "bin_edges": np.linspace( - hist_range[0], hist_range[1], num_bins + 1, dtype=np.float64 - ), - "channel_histograms": normalized_histograms.cpu().numpy(), - } + return HistogramAnalysis( + normalized_histograms.cpu().numpy(), + np.linspace(hist_range[0], hist_range[1], num_bins + 1, dtype=np.float64), + ) def _layer_spectral_analysis( device: torch.device, dataloader: DataLoader[Tuple[Tensor, Tensor, int]], model: nn.Module, -) -> Dict[str, FloatArray]: +) -> SpectralAnalysis: """Compute spectral analysis statistics for each channel across all data in an imageset.""" _, first_batch, _ = next(iter(dataloader)) with torch.no_grad(): @@ -392,9 +412,9 @@ def _layer_spectral_analysis( var_power_spectrum = m2_power_spectrum / count - (mean_power_spectrum / count) ** 2 var_autocorr = m2_autocorr / count - (mean_autocorr / count) ** 2 - return { - "mean_power_spectrum": mean_power_spectrum.cpu().numpy(), - "var_power_spectrum": var_power_spectrum.cpu().numpy(), - "mean_autocorr": mean_autocorr.cpu().numpy(), - "var_autocorr": var_autocorr.cpu().numpy(), - } + return SpectralAnalysis( + mean_power_spectrum.cpu().numpy(), + var_power_spectrum.cpu().numpy(), + 
mean_autocorr.cpu().numpy(), + var_autocorr.cpu().numpy(), + ) diff --git a/runner/frameworks/classification/analyze.py b/runner/frameworks/classification/analyze.py index 66c9781f..3553f2ad 100644 --- a/runner/frameworks/classification/analyze.py +++ b/runner/frameworks/classification/analyze.py @@ -1,9 +1,11 @@ +import json import logging import os import shutil -from typing import Dict, List +from typing import Any, Dict, List import matplotlib.pyplot as plt +import numpy as np import torch import wandb from matplotlib.figure import Figure @@ -34,60 +36,130 @@ init_dir = "initialization_analysis" -def analyze( +class NumpyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + return super().default(obj) + + +def save_statistics(cfg: DictConfig, stats: Dict[str, Any], epoch: int) -> None: + """Save statistics to a JSON file in the plot directory.""" + stats_dir = os.path.join(cfg.path.plot_dir, "statistics") + os.makedirs(stats_dir, exist_ok=True) + + filename = os.path.join(stats_dir, f"epoch_{epoch}_stats.json") + with open(filename, "w") as f: + json.dump(stats, f, cls=NumpyEncoder) + + if cfg.logging.use_wandb: + wandb.save(filename, base_path=cfg.path.plot_dir, policy="now") + + +def collect_statistics( cfg: DictConfig, device: torch.device, brain: Brain, objective: Objective[ContextT], - histories: Dict[str, List[float]], train_set: Imageset, test_set: Imageset, epoch: int, - copy_checkpoint: bool = False, -): - if not cfg.logging.use_wandb: - _plot_and_save_histories(cfg, histories) +) -> Dict[str, Any]: + """Collect all statistics without plotting.""" + stats = {} - cnn_analysis = cnn_statistics( + # Collect CNN statistics + cnn_stats = cnn_statistics( device, test_set, brain, cfg.logging.channel_analysis, cfg.logging.plot_sample_size, ) + stats["cnn_analysis"] = cnn_stats + # Collect reconstruction statistics if applicable + reconstruction_decoders = [ + loss.target_decoder + for loss in 
objective.losses + if isinstance(loss, ReconstructionLoss) + ] + + if reconstruction_decoders: + rec_stats = {} + for decoder in reconstruction_decoders: + rec_dict = reconstruct_images( + device, brain, decoder, train_set, test_set, 5 + ) + rec_stats[str(decoder)] = rec_dict + stats["reconstruction_analysis"] = rec_stats + + # If it's initialization epoch, collect additional stats if epoch == 0: - _perform_initialization_analysis(cfg, brain, objective, train_set, cnn_analysis) + # Save brain summary + stats["brain_summary"] = brain.scan() + + # Save transform statistics + transforms = transform_base_images(train_set, num_steps=5, num_images=2) + stats["transforms"] = transforms - _analyze_layers(cfg, cnn_analysis, epoch, copy_checkpoint) + return stats - _perform_reconstruction_analysis( - cfg, device, brain, objective, train_set, test_set, epoch, copy_checkpoint + +def analyze( + cfg: DictConfig, + device: torch.device, + brain: Brain, + objective: Objective[ContextT], + histories: Dict[str, List[float]], + train_set: Imageset, + test_set: Imageset, + epoch: int, + copy_checkpoint: bool = False, +): + # First collect all statistics + stats = collect_statistics( + cfg, device, brain, objective, train_set, test_set, epoch ) + # Save statistics to file + save_statistics(cfg, stats, epoch) -def _plot_and_save_histories(cfg: DictConfig, histories: Dict[str, List[float]]): - hist_fig = plot_histories(histories) - _save_figure(cfg, "", "histories", hist_fig) - plt.close(hist_fig) + # Plot histories if not using wandb + if not cfg.logging.use_wandb: + _plot_and_save_histories(cfg, histories) + + # Plot CNN analysis + cnn_analysis = stats["cnn_analysis"] + + if epoch == 0: + _plot_initialization_analysis(cfg, brain, objective, train_set, cnn_analysis) + + _plot_layers(cfg, cnn_analysis, epoch, copy_checkpoint) + + # Plot reconstruction analysis if available + if "reconstruction_analysis" in stats: + _plot_reconstruction_analysis( + cfg, train_set, 
stats["reconstruction_analysis"], epoch, copy_checkpoint + ) -def _perform_initialization_analysis( +def _plot_initialization_analysis( cfg: DictConfig, brain: Brain, objective: Objective[ContextT], train_set: Imageset, cnn_analysis: Dict[str, Dict[str, FloatArray]], ): - summary = brain.scan() + # Save brain summary to file filepath = os.path.join(cfg.path.run_dir, "brain_summary.txt") - with open(filepath, "w") as f: - f.write(summary) + f.write(brain.scan()) if cfg.logging.use_wandb: wandb.save(filepath, base_path=cfg.path.run_dir, policy="now") + # Plot various initialization analyses rf_sizes_fig = plot_receptive_field_sizes(cnn_analysis) _process_figure(cfg, False, rf_sizes_fig, init_dir, "receptive_field_sizes", 0) @@ -98,10 +170,39 @@ def _perform_initialization_analysis( transforms_fig = plot_transforms(**transforms) _process_figure(cfg, False, transforms_fig, init_dir, "transforms", 0) - _analyze_input_layer(cfg, cnn_analysis["input"], cfg.logging.channel_analysis) + _plot_input_layer(cfg, cnn_analysis["input"], cfg.logging.channel_analysis) -def _analyze_layers( +def _plot_reconstruction_analysis( + cfg: DictConfig, + train_set: Imageset, + rec_stats: Dict[str, Dict], + epoch: int, + copy_checkpoint: bool, +): + norm_means, norm_stds = train_set.normalization_stats + for decoder, rec_dict in rec_stats.items(): + recon_fig = plot_reconstructions( + norm_means, norm_stds, **rec_dict, num_samples=5 + ) + _process_figure( + cfg, + copy_checkpoint, + recon_fig, + "reconstruction", + f"{decoder}_reconstructions", + epoch, + ) + + +# Rest of the helper functions remain the same +def _plot_and_save_histories(cfg: DictConfig, histories: Dict[str, List[float]]): + hist_fig = plot_histories(histories) + _save_figure(cfg, "", "histories", hist_fig) + plt.close(hist_fig) + + +def _plot_layers( cfg: DictConfig, cnn_analysis: Dict[str, Dict[str, FloatArray]], epoch: int, @@ -109,7 +210,7 @@ def _analyze_layers( ): for layer_name, layer_data in cnn_analysis.items(): if 
layer_name != "input": - _analyze_regular_layer( + _plot_regular_layer( cfg, layer_name, layer_data, @@ -119,7 +220,7 @@ def _analyze_layers( ) -def _analyze_input_layer( +def _plot_input_layer( cfg: DictConfig, layer_data: Dict[str, FloatArray], channel_analysis: bool, @@ -136,7 +237,7 @@ def _analyze_input_layer( ) -def _analyze_regular_layer( +def _plot_regular_layer( cfg: DictConfig, layer_name: str, layer_data: Dict[str, FloatArray], @@ -163,38 +264,6 @@ def _analyze_regular_layer( ) -def _perform_reconstruction_analysis( - cfg: DictConfig, - device: torch.device, - brain: Brain, - objective: Objective[ContextT], - train_set: Imageset, - test_set: Imageset, - epoch: int, - copy_checkpoint: bool, -): - reconstruction_decoders = [ - loss.target_decoder - for loss in objective.losses - if isinstance(loss, ReconstructionLoss) - ] - - for decoder in reconstruction_decoders: - norm_means, norm_stds = train_set.normalization_stats - rec_dict = reconstruct_images(device, brain, decoder, train_set, test_set, 5) - recon_fig = plot_reconstructions( - norm_means, norm_stds, **rec_dict, num_samples=5 - ) - _process_figure( - cfg, - copy_checkpoint, - recon_fig, - "reconstruction", - f"{decoder}_reconstructions", - epoch, - ) - - def _save_figure(cfg: DictConfig, sub_dir: str, file_name: str, fig: Figure) -> None: dir = os.path.join(cfg.path.plot_dir, sub_dir) os.makedirs(dir, exist_ok=True) @@ -213,18 +282,13 @@ def _checkpoint_copy(cfg: DictConfig, sub_dir: str, file_name: str, epoch: int) def _wandb_title(title: str) -> str: - # Split the title by slashes parts = title.split("/") def capitalize_part(part: str) -> str: - # Split the part by dashes words = part.split("_") - # Capitalize each word capitalized_words = [word.capitalize() for word in words] - # Join the words with spaces return " ".join(capitalized_words) - # Capitalize each part, then join with slashes capitalized_parts = [capitalize_part(part) for part in parts] return "/".join(capitalized_parts) From 
1810be19fb96b6cdda5ea76dcb9aa2bea607a3f1 Mon Sep 17 00:00:00 2001 From: alex404 Date: Thu, 31 Oct 2024 11:31:22 +0100 Subject: [PATCH 2/9] Simulations run and plots and data are saved. Going to refactor and refine some more. --- retinal_rl/analysis/plot.py | 120 ++++++---- retinal_rl/analysis/statistics.py | 26 +-- runner/frameworks/classification/analyze.py | 244 +++++++++----------- 3 files changed, 196 insertions(+), 194 deletions(-) diff --git a/retinal_rl/analysis/plot.py b/retinal_rl/analysis/plot.py index ff2a3d1e..ce89df1a 100644 --- a/retinal_rl/analysis/plot.py +++ b/retinal_rl/analysis/plot.py @@ -6,7 +6,6 @@ import networkx as nx import numpy as np import seaborn as sns -import torch from matplotlib import gridspec, patches from matplotlib.axes import Axes from matplotlib.figure import Figure @@ -14,8 +13,6 @@ from matplotlib.patches import Circle, Wedge from matplotlib.ticker import MaxNLocator from numpy import fft -from torch import Tensor -from torchvision.utils import make_grid from retinal_rl.models.brain import Brain from retinal_rl.models.objective import ContextT, Objective @@ -23,22 +20,18 @@ def plot_transforms( - source_transforms: Dict[str, Dict[float, List[torch.Tensor]]], - noise_transforms: Dict[str, Dict[float, List[torch.Tensor]]], + source_transforms: Dict[str, Dict[float, List[FloatArray]]], + noise_transforms: Dict[str, Dict[float, List[FloatArray]]], ) -> Figure: - """Use the result of the transform_base_images function to plot the effects of source and noise transforms on images. + """Plot effects of source and noise transforms on images. Args: - ---- - source_transforms: A dictionary of source transforms and their effects on images. - noise_transforms: A dictionary of noise transforms and their effects on images. + source_transforms: Dictionary of source transforms (numpy arrays) + noise_transforms: Dictionary of noise transforms (numpy arrays) Returns: - ------- - Figure: A matplotlib Figure containing the plotted transforms. 
- + Figure containing the plotted transforms """ - # Determine the number of transforms and images num_source_transforms = len(source_transforms) num_noise_transforms = len(noise_transforms) num_transforms = num_source_transforms + num_noise_transforms @@ -48,13 +41,32 @@ def plot_transforms( ] ) - # Create a figure with subplots for each transform fig, axs = plt.subplots(num_transforms, 1, figsize=(20, 5 * num_transforms)) if num_transforms == 1: axs = [axs] transform_index = 0 + def make_image_grid(arrays: List[FloatArray], nrow: int) -> FloatArray: + """Create a grid of images from a list of numpy arrays.""" + # Assuming arrays are [C, H, W] + n = len(arrays) + if not n: + return np.array([]) + + ncol = nrow + nrow = (n - 1) // ncol + 1 + + C, H, W = arrays[0].shape + grid = np.zeros((C, H * nrow, W * ncol)) + + for idx, array in enumerate(arrays): + i = idx // ncol + j = idx % ncol + grid[:, i * H : (i + 1) * H, j * W : (j + 1) * W] = array + + return grid + # Plot source transforms for transform_name, transform_data in source_transforms.items(): ax = axs[transform_index] @@ -62,19 +74,20 @@ def plot_transforms( # Create a grid of images for each step images = [ - make_grid( - torch.stack([img * 0.5 + 0.5 for img in transform_data[step]]), + make_image_grid( + [(img * 0.5 + 0.5) for img in transform_data[step]], nrow=num_images, ) for step in steps ] - grid = make_grid(images, nrow=len(steps)) + grid = make_image_grid(images, nrow=len(steps)) - # Display the grid - ax.imshow(grid.permute(1, 2, 0)) + # Move channels last for imshow + grid_display = np.transpose(grid, (1, 2, 0)) + ax.imshow(grid_display) ax.set_title(f"Source Transform: {transform_name}") ax.set_xticks( - [(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))] + [(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))] ) ax.set_xticklabels([f"{step:.2f}" for step in steps]) ax.set_yticks([]) @@ -88,19 +101,20 @@ def plot_transforms( # Create a grid of images for each 
step images = [ - make_grid( - torch.stack([img * 0.5 + 0.5 for img in transform_data[step]]), + make_image_grid( + [(img * 0.5 + 0.5) for img in transform_data[step]], nrow=num_images, ) for step in steps ] - grid = make_grid(images, nrow=len(steps)) + grid = make_image_grid(images, nrow=len(steps)) - # Display the grid - ax.imshow(grid.permute(1, 2, 0)) + # Move channels last for imshow + grid_display = np.transpose(grid, (1, 2, 0)) + ax.imshow(grid_display) ax.set_title(f"Noise Transform: {transform_name}") ax.set_xticks( - [(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))] + [(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))] ) ax.set_xticklabels([f"{step:.2f}" for step in steps]) ax.set_yticks([]) @@ -237,16 +251,17 @@ def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> F return fig -def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Figure: +def plot_receptive_field_sizes( + input_shape: Tuple[int, ...], layers: Dict[str, Dict[str, FloatArray]] +) -> Figure: """Plot the receptive field sizes for each layer of the convolutional part of the network.""" # Get visual field size from the input shape - input_shape = results["input"]["shape"] [_, height, width] = list(input_shape) # Calculate receptive field sizes for each layer rf_sizes: List[Tuple[int, int]] = [] layer_names: List[str] = [] - for name, layer_data in results.items(): + for name, layer_data in layers.items(): if name == "input": continue rf = layer_data["receptive_fields"] @@ -357,22 +372,26 @@ def plot_histories(histories: Dict[str, List[float]]) -> Figure: def plot_channel_statistics( - layer_data: Dict[str, FloatArray], layer_name: str, channel: int + receptive_fields: FloatArray, + spectral: Dict[str, FloatArray], + histogram: Dict[str, FloatArray], + layer_name: str, + channel: int, ) -> Figure: """Plot receptive fields, pixel histograms, and autocorrelation plots for a single channel in a layer.""" 
fig, axs = plt.subplots(2, 3, figsize=(20, 10)) fig.suptitle(f"Layer: {layer_name}, Channel: {channel}", fontsize=16) # Receptive Fields - rf = layer_data["receptive_fields"][channel] + rf = receptive_fields[channel] _plot_receptive_fields(axs[0, 0], rf) axs[0, 0].set_title("Receptive Field") axs[0, 0].set_xlabel("X") axs[0, 0].set_ylabel("Y") # Pixel Histograms - hist = layer_data["pixel_histograms"][channel] - bin_edges = layer_data["histogram_bin_edges"] + hist = histogram["channel_histograms"][channel] + bin_edges = histogram["bin_edges"] axs[1, 0].bar( bin_edges[:-1], hist, @@ -387,7 +406,7 @@ def plot_channel_statistics( # Autocorrelation plots # Plot average 2D autocorrelation and variance - autocorr = fft.fftshift(layer_data["mean_autocorr"][channel]) + autocorr = fft.fftshift(spectral["mean_autocorr"][channel]) h, w = autocorr.shape extent = [-w // 2, w // 2, -h // 2, h // 2] im = axs[0, 1].imshow( @@ -399,7 +418,7 @@ def plot_channel_statistics( fig.colorbar(im, ax=axs[0, 1]) _set_integer_ticks(axs[0, 1]) - autocorr_sd = fft.fftshift(np.sqrt(layer_data["var_autocorr"][channel])) + autocorr_sd = fft.fftshift(np.sqrt(spectral["var_autocorr"][channel])) im = axs[0, 2].imshow( autocorr_sd, cmap="inferno", origin="lower", extent=extent, vmin=0 ) @@ -411,7 +430,7 @@ def plot_channel_statistics( # Plot average 2D power spectrum log_power_spectrum = fft.fftshift( - np.log1p(layer_data["mean_power_spectrum"][channel]) + np.log1p(spectral["mean_power_spectrum"][channel]) ) h, w = log_power_spectrum.shape @@ -425,7 +444,7 @@ def plot_channel_statistics( _set_integer_ticks(axs[1, 1]) log_power_spectrum_sd = fft.fftshift( - np.log1p(np.sqrt(layer_data["var_power_spectrum"][channel])) + np.log1p(np.sqrt(spectral["var_power_spectrum"][channel])) ) im = axs[1, 2].imshow( log_power_spectrum_sd, @@ -450,16 +469,15 @@ def _set_integer_ticks(ax: Axes): ax.yaxis.set_major_locator(MaxNLocator(integer=True)) -# Function to plot the original and reconstructed images def 
plot_reconstructions( normalization_mean: List[float], normalization_std: List[float], - train_sources: List[Tuple[Tensor, int]], - train_inputs: List[Tuple[Tensor, int]], - train_estimates: List[Tuple[Tensor, int]], - test_sources: List[Tuple[Tensor, int]], - test_inputs: List[Tuple[Tensor, int]], - test_estimates: List[Tuple[Tensor, int]], + train_sources: List[Tuple[FloatArray, int]], + train_inputs: List[Tuple[FloatArray, int]], + train_estimates: List[Tuple[FloatArray, int]], + test_sources: List[Tuple[FloatArray, int]], + test_inputs: List[Tuple[FloatArray, int]], + test_estimates: List[Tuple[FloatArray, int]], num_samples: int, ) -> Figure: """Plot original and reconstructed images for both training and test sets, including the classes.""" @@ -474,27 +492,28 @@ def plot_reconstructions( test_recon, test_pred = test_estimates[i] # Unnormalize the original images using the normalization lists + # Arrays are already [C, H, W], need to move channels to last dimension train_source = ( - train_source.permute(1, 2, 0).numpy() * normalization_std + np.transpose(train_source, (1, 2, 0)) * normalization_std + normalization_mean ) train_input = ( - train_input.permute(1, 2, 0).numpy() * normalization_std + np.transpose(train_input, (1, 2, 0)) * normalization_std + normalization_mean ) train_recon = ( - train_recon.permute(1, 2, 0).numpy() * normalization_std + np.transpose(train_recon, (1, 2, 0)) * normalization_std + normalization_mean ) test_source = ( - test_source.permute(1, 2, 0).numpy() * normalization_std + np.transpose(test_source, (1, 2, 0)) * normalization_std + normalization_mean ) test_input = ( - test_input.permute(1, 2, 0).numpy() * normalization_std + normalization_mean + np.transpose(test_input, (1, 2, 0)) * normalization_std + normalization_mean ) test_recon = ( - test_recon.permute(1, 2, 0).numpy() * normalization_std + normalization_mean + np.transpose(test_recon, (1, 2, 0)) * normalization_std + normalization_mean ) axes[0, 
i].imshow(np.clip(train_source, 0, 1)) @@ -522,7 +541,6 @@ def plot_reconstructions( axes[5, i].set_title(f"Pred: {test_pred}") # Set y-axis labels for each row - fig.text( 0.02, 0.90, diff --git a/retinal_rl/analysis/statistics.py b/retinal_rl/analysis/statistics.py index d1f12d2a..3aa9c59b 100644 --- a/retinal_rl/analysis/statistics.py +++ b/retinal_rl/analysis/statistics.py @@ -29,17 +29,17 @@ class TransformStatistics: """Results of applying transformations to images.""" - source_transforms: Dict[str, Dict[float, List[Tensor]]] - noise_transforms: Dict[str, Dict[float, List[Tensor]]] + source_transforms: Dict[str, Dict[float, List[FloatArray]]] + noise_transforms: Dict[str, Dict[float, List[FloatArray]]] @dataclass class Reconstructions: """Set of source images, inputs, and their reconstructions.""" - sources: List[Tuple[Tensor, int]] - inputs: List[Tuple[Tensor, int]] - estimates: List[Tuple[Tensor, int]] + sources: List[Tuple[FloatArray, int]] + inputs: List[Tuple[FloatArray, int]] + estimates: List[Tuple[FloatArray, int]] @dataclass @@ -102,7 +102,7 @@ def transform_base_images( src, _ = base_dataset[np.random.randint(base_len)] images.append(src) - results: Dict[str, Dict[str, Dict[float, List[Tensor]]]] = { + results: Dict[str, Dict[str, Dict[float, List[FloatArray]]]] = { "source_transforms": {}, "noise_transforms": {}, } @@ -124,7 +124,7 @@ def transform_base_images( results[category][transform.name][step] = [] for img in images: results[category][transform.name][step].append( - imageset.to_tensor(transform.transform(img, step)) + imageset.to_tensor(transform.transform(img, step)).cpu().numpy() ) return TransformStatistics(**results) @@ -145,9 +145,9 @@ def collect_reconstructions( imageset: Imageset, sample_size: int ) -> Reconstructions: """Collect reconstructions for a subset of a dataset.""" - source_subset: List[Tuple[Tensor, int]] = [] - input_subset: List[Tuple[Tensor, int]] = [] - estimates: List[Tuple[Tensor, int]] = [] + source_subset: 
List[Tuple[FloatArray, int]] = [] + input_subset: List[Tuple[FloatArray, int]] = [] + estimates: List[Tuple[FloatArray, int]] = [] indices = torch.randperm(imageset.epoch_len())[:sample_size] with torch.no_grad(): # Disable gradient computation @@ -159,9 +159,9 @@ def collect_reconstructions( response = brain(stimulus) rec_img = response[decoder].squeeze(0) pred_k = response["classifier"].argmax().item() - source_subset.append((src.cpu(), k)) - input_subset.append((img.cpu(), k)) - estimates.append((rec_img.cpu(), pred_k)) + source_subset.append((src.cpu().numpy(), k)) + input_subset.append((img.cpu().numpy(), k)) + estimates.append((rec_img.cpu().numpy(), pred_k)) return Reconstructions(source_subset, input_subset, estimates) diff --git a/runner/frameworks/classification/analyze.py b/runner/frameworks/classification/analyze.py index 3553f2ad..e2394e88 100644 --- a/runner/frameworks/classification/analyze.py +++ b/runner/frameworks/classification/analyze.py @@ -2,6 +2,8 @@ import logging import os import shutil +from dataclasses import asdict +from pathlib import Path from typing import Any, Dict, List import matplotlib.pyplot as plt @@ -21,6 +23,8 @@ plot_transforms, ) from retinal_rl.analysis.statistics import ( + CNNStatistics, + LayerStatistics, cnn_statistics, reconstruct_images, transform_base_images, @@ -29,7 +33,9 @@ from retinal_rl.models.brain import Brain from retinal_rl.models.loss import ReconstructionLoss from retinal_rl.models.objective import ContextT, Objective -from retinal_rl.util import FloatArray + +### Infrastructure ### + logger = logging.getLogger(__name__) @@ -37,38 +43,35 @@ class NumpyEncoder(json.JSONEncoder): - def default(self, obj): + """JSON encoder that handles numpy arrays.""" + + def default(self, obj: Any) -> Any: if isinstance(obj, np.ndarray): return obj.tolist() return super().default(obj) -def save_statistics(cfg: DictConfig, stats: Dict[str, Any], epoch: int) -> None: - """Save statistics to a JSON file in the plot 
directory.""" - stats_dir = os.path.join(cfg.path.plot_dir, "statistics") - os.makedirs(stats_dir, exist_ok=True) +### Analysis ### - filename = os.path.join(stats_dir, f"epoch_{epoch}_stats.json") - with open(filename, "w") as f: - json.dump(stats, f, cls=NumpyEncoder) - if cfg.logging.use_wandb: - wandb.save(filename, base_path=cfg.path.plot_dir, policy="now") - - -def collect_statistics( +def analyze( cfg: DictConfig, device: torch.device, brain: Brain, objective: Objective[ContextT], + histories: Dict[str, List[float]], train_set: Imageset, test_set: Imageset, epoch: int, -) -> Dict[str, Any]: - """Collect all statistics without plotting.""" - stats = {} + copy_checkpoint: bool = False, +): + results_dir = Path(cfg.path.data_dir) / "analyses" + results_dir.mkdir(exist_ok=True) + + if not cfg.logging.use_wandb: + _plot_and_save_histories(cfg, histories) - # Collect CNN statistics + # Get CNN statistics and save them cnn_stats = cnn_statistics( device, test_set, @@ -76,141 +79,69 @@ def collect_statistics( cfg.logging.channel_analysis, cfg.logging.plot_sample_size, ) - stats["cnn_analysis"] = cnn_stats - # Collect reconstruction statistics if applicable - reconstruction_decoders = [ - loss.target_decoder - for loss in objective.losses - if isinstance(loss, ReconstructionLoss) - ] - - if reconstruction_decoders: - rec_stats = {} - for decoder in reconstruction_decoders: - rec_dict = reconstruct_images( - device, brain, decoder, train_set, test_set, 5 - ) - rec_stats[str(decoder)] = rec_dict - stats["reconstruction_analysis"] = rec_stats + # Save CNN statistics + with open(results_dir / f"cnn_stats_epoch_{epoch}.json", "w") as f: + json.dump(asdict(cnn_stats), f, cls=NumpyEncoder) - # If it's initialization epoch, collect additional stats if epoch == 0: - # Save brain summary - stats["brain_summary"] = brain.scan() + _perform_initialization_analysis(cfg, brain, objective, train_set, cnn_stats) - # Save transform statistics - transforms = 
transform_base_images(train_set, num_steps=5, num_images=2) - stats["transforms"] = transforms + _analyze_layers(cfg, cnn_stats, epoch, copy_checkpoint) - return stats - - -def analyze( - cfg: DictConfig, - device: torch.device, - brain: Brain, - objective: Objective[ContextT], - histories: Dict[str, List[float]], - train_set: Imageset, - test_set: Imageset, - epoch: int, - copy_checkpoint: bool = False, -): - # First collect all statistics - stats = collect_statistics( - cfg, device, brain, objective, train_set, test_set, epoch + _perform_reconstruction_analysis( + cfg, device, brain, objective, train_set, test_set, epoch, copy_checkpoint ) - # Save statistics to file - save_statistics(cfg, stats, epoch) - - # Plot histories if not using wandb - if not cfg.logging.use_wandb: - _plot_and_save_histories(cfg, histories) - - # Plot CNN analysis - cnn_analysis = stats["cnn_analysis"] - - if epoch == 0: - _plot_initialization_analysis(cfg, brain, objective, train_set, cnn_analysis) - _plot_layers(cfg, cnn_analysis, epoch, copy_checkpoint) - - # Plot reconstruction analysis if available - if "reconstruction_analysis" in stats: - _plot_reconstruction_analysis( - cfg, train_set, stats["reconstruction_analysis"], epoch, copy_checkpoint - ) +def _plot_and_save_histories(cfg: DictConfig, histories: Dict[str, List[float]]): + hist_fig = plot_histories(histories) + _save_figure(cfg, "", "histories", hist_fig) + plt.close(hist_fig) -def _plot_initialization_analysis( +def _perform_initialization_analysis( cfg: DictConfig, brain: Brain, objective: Objective[ContextT], train_set: Imageset, - cnn_analysis: Dict[str, Dict[str, FloatArray]], + cnn_stats: CNNStatistics, ): - # Save brain summary to file - filepath = os.path.join(cfg.path.run_dir, "brain_summary.txt") - with open(filepath, "w") as f: - f.write(brain.scan()) + summary = brain.scan() + filepath = Path(cfg.path.run_dir) / "brain_summary.txt" + filepath.write_text(summary) if cfg.logging.use_wandb: - wandb.save(filepath, 
base_path=cfg.path.run_dir, policy="now") + wandb.save(str(filepath), base_path=cfg.path.run_dir, policy="now") - # Plot various initialization analyses - rf_sizes_fig = plot_receptive_field_sizes(cnn_analysis) + # TODO: This is a bit of a hack, we should refactor this to get the relevant information out of cnn_stats + rf_sizes_fig = plot_receptive_field_sizes(**asdict(cnn_stats)) _process_figure(cfg, False, rf_sizes_fig, init_dir, "receptive_field_sizes", 0) graph_fig = plot_brain_and_optimizers(brain, objective) _process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0) transforms = transform_base_images(train_set, num_steps=5, num_images=2) - transforms_fig = plot_transforms(**transforms) - _process_figure(cfg, False, transforms_fig, init_dir, "transforms", 0) - - _plot_input_layer(cfg, cnn_analysis["input"], cfg.logging.channel_analysis) - - -def _plot_reconstruction_analysis( - cfg: DictConfig, - train_set: Imageset, - rec_stats: Dict[str, Dict], - epoch: int, - copy_checkpoint: bool, -): - norm_means, norm_stds = train_set.normalization_stats - for decoder, rec_dict in rec_stats.items(): - recon_fig = plot_reconstructions( - norm_means, norm_stds, **rec_dict, num_samples=5 - ) - _process_figure( - cfg, - copy_checkpoint, - recon_fig, - "reconstruction", - f"{decoder}_reconstructions", - epoch, - ) + # Save transform statistics + transform_path = Path(cfg.path.run_dir) / "results" / "transforms.json" + with open(transform_path, "w") as f: + json.dump(asdict(transforms), f, cls=NumpyEncoder) + transforms_fig = plot_transforms(**asdict(transforms)) + _process_figure(cfg, False, transforms_fig, init_dir, "transforms", 0) -# Rest of the helper functions remain the same -def _plot_and_save_histories(cfg: DictConfig, histories: Dict[str, List[float]]): - hist_fig = plot_histories(histories) - _save_figure(cfg, "", "histories", hist_fig) - plt.close(hist_fig) + _analyze_input_layer(cfg, cnn_stats.layers["input"], cfg.logging.channel_analysis) -def 
_plot_layers( +def _analyze_layers( cfg: DictConfig, - cnn_analysis: Dict[str, Dict[str, FloatArray]], + cnn_stats: CNNStatistics, epoch: int, copy_checkpoint: bool, ): - for layer_name, layer_data in cnn_analysis.items(): + for layer_name, layer_data in cnn_stats.layers.items(): if layer_name != "input": - _plot_regular_layer( + _analyze_regular_layer( cfg, layer_name, layer_data, @@ -220,40 +151,47 @@ def _plot_layers( ) -def _plot_input_layer( +def _analyze_input_layer( cfg: DictConfig, - layer_data: Dict[str, FloatArray], + layer_statistics: LayerStatistics, channel_analysis: bool, ): - layer_rfs = layer_receptive_field_plots(layer_data["receptive_fields"]) + layer_rfs = layer_receptive_field_plots(layer_statistics.receptive_fields) _process_figure(cfg, False, layer_rfs, init_dir, "input_rfs", 0) if channel_analysis: - num_channels = int(layer_data["num_channels"]) + layer_dict = asdict(layer_statistics) + num_channels = int(layer_dict.pop("num_channels")) for channel in range(num_channels): - channel_fig = plot_channel_statistics(layer_data, "input", channel) + channel_fig = plot_channel_statistics( + **layer_dict, layer_name="input", channel=channel + ) _process_figure( cfg, False, channel_fig, init_dir, f"input_channel_{channel}", 0 ) -def _plot_regular_layer( +def _analyze_regular_layer( cfg: DictConfig, layer_name: str, - layer_data: Dict[str, FloatArray], + layer_statistics: LayerStatistics, epoch: int, copy_checkpoint: bool, channel_analysis: bool, ): - layer_rfs = layer_receptive_field_plots(layer_data["receptive_fields"]) + layer_rfs = layer_receptive_field_plots(layer_statistics.receptive_fields) _process_figure( cfg, copy_checkpoint, layer_rfs, "receptive_fields", f"{layer_name}", epoch ) if channel_analysis: - num_channels = int(layer_data["num_channels"]) + layer_dict = asdict(layer_statistics) + num_channels = int(layer_dict.pop("num_channels")) for channel in range(num_channels): - channel_fig = plot_channel_statistics(layer_data, layer_name, 
channel) + channel_fig = plot_channel_statistics( + **layer_dict, layer_name=layer_name, channel=channel + ) + _process_figure( cfg, copy_checkpoint, @@ -264,6 +202,47 @@ def _plot_regular_layer( ) +def _perform_reconstruction_analysis( + cfg: DictConfig, + device: torch.device, + brain: Brain, + objective: Objective[ContextT], + train_set: Imageset, + test_set: Imageset, + epoch: int, + copy_checkpoint: bool, +): + reconstruction_decoders = [ + loss.target_decoder + for loss in objective.losses + if isinstance(loss, ReconstructionLoss) + ] + + for decoder in reconstruction_decoders: + norm_means, norm_stds = train_set.normalization_stats + rec_dict = asdict( + reconstruct_images(device, brain, decoder, train_set, test_set, 5) + ) + recon_fig = plot_reconstructions( + norm_means, + norm_stds, + *rec_dict["train"].values(), + *rec_dict["test"].values(), + num_samples=5, + ) + _process_figure( + cfg, + copy_checkpoint, + recon_fig, + "reconstruction", + f"{decoder}_reconstructions", + epoch, + ) + + +### Helper Functions ### + + def _save_figure(cfg: DictConfig, sub_dir: str, file_name: str, fig: Figure) -> None: dir = os.path.join(cfg.path.plot_dir, sub_dir) os.makedirs(dir, exist_ok=True) @@ -282,13 +261,18 @@ def _checkpoint_copy(cfg: DictConfig, sub_dir: str, file_name: str, epoch: int) def _wandb_title(title: str) -> str: + # Split the title by slashes parts = title.split("/") def capitalize_part(part: str) -> str: + # Split the part by underscores words = part.split("_") + # Capitalize each word capitalized_words = [word.capitalize() for word in words] + # Join the words with spaces return " ".join(capitalized_words) + # Capitalize each part, then join with slashes capitalized_parts = [capitalize_part(part) for part in parts] return "/".join(capitalized_parts) From f16ccc8ed359a0ea28e1253c04ac3d6d0ccd8738 Mon Sep 17 00:00:00 2001 From: alex404 Date: Thu, 31 Oct 2024 12:26:13 +0100 Subject: [PATCH 3/9] Moved initialize.py over to pathlib style.
Some more refinement, but things are looking much better imho --- retinal_rl/analysis/plot.py | 6 +- runner/frameworks/classification/analyze.py | 224 ++++++++++++++---- .../frameworks/classification/initialize.py | 110 ++++++--- runner/util.py | 14 +- 4 files changed, 265 insertions(+), 89 deletions(-) diff --git a/retinal_rl/analysis/plot.py b/retinal_rl/analysis/plot.py index ce89df1a..ff25e690 100644 --- a/retinal_rl/analysis/plot.py +++ b/retinal_rl/analysis/plot.py @@ -57,13 +57,13 @@ def make_image_grid(arrays: List[FloatArray], nrow: int) -> FloatArray: ncol = nrow nrow = (n - 1) // ncol + 1 - C, H, W = arrays[0].shape - grid = np.zeros((C, H * nrow, W * ncol)) + nchns, hght, wdth = arrays[0].shape + grid = np.zeros((nchns, hght * nrow, wdth * ncol)) for idx, array in enumerate(arrays): i = idx // ncol j = idx % ncol - grid[:, i * H : (i + 1) * H, j * W : (j + 1) * W] = array + grid[:, i * hght : (i + 1) * hght, j * wdth : (j + 1) * wdth] = array return grid diff --git a/runner/frameworks/classification/analyze.py b/runner/frameworks/classification/analyze.py index e2394e88..5fb0e792 100644 --- a/runner/frameworks/classification/analyze.py +++ b/runner/frameworks/classification/analyze.py @@ -1,6 +1,5 @@ import json import logging -import os import shutil from dataclasses import asdict from pathlib import Path @@ -65,76 +64,168 @@ def analyze( epoch: int, copy_checkpoint: bool = False, ): - results_dir = Path(cfg.path.data_dir) / "analyses" - results_dir.mkdir(exist_ok=True) + ## DictConfig - if not cfg.logging.use_wandb: - _plot_and_save_histories(cfg, histories) + # Path creation + run_dir = Path(cfg.path.run_dir) + run_dir.mkdir(exist_ok=True) + + plot_dir = Path(cfg.path.plot_dir) + plot_dir.mkdir(exist_ok=True) + + checkpoint_plot_dir = Path(cfg.path.checkpoint_plot_dir) + checkpoint_plot_dir.mkdir(exist_ok=True) + + analyses_dir = Path(cfg.path.data_dir) / "analyses" + analyses_dir.mkdir(exist_ok=True) + + # Variables + use_wandb = 
cfg.logging.use_wandb + channel_analysis = cfg.logging.channel_analysis + plot_sample_size = cfg.logging.plot_sample_size + + ## Analysis + + if not use_wandb: + _plot_and_save_histories(plot_dir, histories) # Get CNN statistics and save them cnn_stats = cnn_statistics( device, test_set, brain, - cfg.logging.channel_analysis, - cfg.logging.plot_sample_size, + channel_analysis, + plot_sample_size, ) # Save CNN statistics - with open(results_dir / f"cnn_stats_epoch_{epoch}.json", "w") as f: + with open(analyses_dir / f"cnn_stats_epoch_{epoch}.json", "w") as f: json.dump(asdict(cnn_stats), f, cls=NumpyEncoder) if epoch == 0: - _perform_initialization_analysis(cfg, brain, objective, train_set, cnn_stats) + _perform_initialization_analysis( + channel_analysis, + analyses_dir, + use_wandb, + plot_dir, + checkpoint_plot_dir, + run_dir, + brain, + objective, + train_set, + cnn_stats, + ) - _analyze_layers(cfg, cnn_stats, epoch, copy_checkpoint) + _analyze_layers( + channel_analysis, + use_wandb, + plot_dir, + checkpoint_plot_dir, + cnn_stats, + epoch, + copy_checkpoint, + ) _perform_reconstruction_analysis( - cfg, device, brain, objective, train_set, test_set, epoch, copy_checkpoint + use_wandb, + plot_dir, + checkpoint_plot_dir, + device, + brain, + objective, + train_set, + test_set, + epoch, + copy_checkpoint, ) + hist_fig = plot_histories(histories) + _save_figure(plot_dir, "", "histories", hist_fig) + plt.close(hist_fig) + -def _plot_and_save_histories(cfg: DictConfig, histories: Dict[str, List[float]]): +def _plot_and_save_histories(plot_dir: Path, histories: Dict[str, List[float]]): hist_fig = plot_histories(histories) - _save_figure(cfg, "", "histories", hist_fig) + _save_figure(plot_dir, "", "histories", hist_fig) plt.close(hist_fig) def _perform_initialization_analysis( - cfg: DictConfig, + channel_analysis: bool, + analyses_dir: Path, + use_wandb: bool, + plot_dir: Path, + checkpoint_plot_dir: Path, + run_dir: Path, brain: Brain, objective: Objective[ContextT], 
train_set: Imageset, cnn_stats: CNNStatistics, ): summary = brain.scan() - filepath = Path(cfg.path.run_dir) / "brain_summary.txt" + filepath = run_dir / "brain_summary.txt" filepath.write_text(summary) - if cfg.logging.use_wandb: - wandb.save(str(filepath), base_path=cfg.path.run_dir, policy="now") + if use_wandb: + wandb.save(str(filepath), base_path=run_dir, policy="now") # TODO: This is a bit of a hack, we should refactor this to get the relevant information out of cnn_stats rf_sizes_fig = plot_receptive_field_sizes(**asdict(cnn_stats)) - _process_figure(cfg, False, rf_sizes_fig, init_dir, "receptive_field_sizes", 0) + _process_figure( + use_wandb, + plot_dir, + checkpoint_plot_dir, + False, + rf_sizes_fig, + init_dir, + "receptive_field_sizes", + 0, + ) graph_fig = plot_brain_and_optimizers(brain, objective) - _process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0) + _process_figure( + use_wandb, + plot_dir, + checkpoint_plot_dir, + False, + graph_fig, + init_dir, + "brain_graph", + 0, + ) transforms = transform_base_images(train_set, num_steps=5, num_images=2) # Save transform statistics - transform_path = Path(cfg.path.run_dir) / "results" / "transforms.json" + transform_path = analyses_dir / "transforms.json" with open(transform_path, "w") as f: json.dump(asdict(transforms), f, cls=NumpyEncoder) transforms_fig = plot_transforms(**asdict(transforms)) - _process_figure(cfg, False, transforms_fig, init_dir, "transforms", 0) + _process_figure( + use_wandb, + plot_dir, + checkpoint_plot_dir, + False, + transforms_fig, + init_dir, + "transforms", + 0, + ) - _analyze_input_layer(cfg, cnn_stats.layers["input"], cfg.logging.channel_analysis) + _analyze_input_layer( + use_wandb, + plot_dir, + checkpoint_plot_dir, + cnn_stats.layers["input"], + channel_analysis, + ) def _analyze_layers( - cfg: DictConfig, + channel_analysis: bool, + use_wandb: bool, + plot_dir: Path, + checkpoint_plot_dir: Path, cnn_stats: CNNStatistics, epoch: int, copy_checkpoint: bool, 
@@ -142,22 +233,35 @@ def _analyze_layers( for layer_name, layer_data in cnn_stats.layers.items(): if layer_name != "input": _analyze_regular_layer( - cfg, + use_wandb, + plot_dir, + checkpoint_plot_dir, layer_name, layer_data, epoch, copy_checkpoint, - cfg.logging.channel_analysis, + channel_analysis, ) def _analyze_input_layer( - cfg: DictConfig, + use_wandb: bool, + plot_dir: Path, + checkpoint_plot_dir: Path, layer_statistics: LayerStatistics, channel_analysis: bool, ): layer_rfs = layer_receptive_field_plots(layer_statistics.receptive_fields) - _process_figure(cfg, False, layer_rfs, init_dir, "input_rfs", 0) + _process_figure( + use_wandb, + plot_dir, + checkpoint_plot_dir, + False, + layer_rfs, + init_dir, + "input_rfs", + 0, + ) if channel_analysis: layer_dict = asdict(layer_statistics) @@ -167,12 +271,21 @@ def _analyze_input_layer( **layer_dict, layer_name="input", channel=channel ) _process_figure( - cfg, False, channel_fig, init_dir, f"input_channel_{channel}", 0 + use_wandb, + plot_dir, + checkpoint_plot_dir, + False, + channel_fig, + init_dir, + f"input_channel_{channel}", + 0, ) def _analyze_regular_layer( - cfg: DictConfig, + use_wandb: bool, + plot_dir: Path, + checkpoint_plot_dir: Path, layer_name: str, layer_statistics: LayerStatistics, epoch: int, @@ -181,7 +294,14 @@ def _analyze_regular_layer( ): layer_rfs = layer_receptive_field_plots(layer_statistics.receptive_fields) _process_figure( - cfg, copy_checkpoint, layer_rfs, "receptive_fields", f"{layer_name}", epoch + use_wandb, + plot_dir, + checkpoint_plot_dir, + copy_checkpoint, + layer_rfs, + "receptive_fields", + f"{layer_name}", + epoch, ) if channel_analysis: @@ -193,7 +313,9 @@ def _analyze_regular_layer( ) _process_figure( - cfg, + use_wandb, + plot_dir, + checkpoint_plot_dir, copy_checkpoint, channel_fig, f"{layer_name}_layer_channel_analysis", @@ -203,7 +325,9 @@ def _analyze_regular_layer( def _perform_reconstruction_analysis( - cfg: DictConfig, + use_wandb: bool, + plot_dir: Path, + 
checkpoint_plot_dir: Path, device: torch.device, brain: Brain, objective: Objective[ContextT], @@ -231,7 +355,9 @@ def _perform_reconstruction_analysis( num_samples=5, ) _process_figure( - cfg, + use_wandb, + plot_dir, + checkpoint_plot_dir, copy_checkpoint, recon_fig, "reconstruction", @@ -243,19 +369,21 @@ def _perform_reconstruction_analysis( ### Helper Functions ### -def _save_figure(cfg: DictConfig, sub_dir: str, file_name: str, fig: Figure) -> None: - dir = os.path.join(cfg.path.plot_dir, sub_dir) - os.makedirs(dir, exist_ok=True) - file_name = os.path.join(dir, f"{file_name}.png") - fig.savefig(file_name) +def _save_figure(plot_dir: Path, sub_dir: str, file_name: str, fig: Figure) -> None: + dir = plot_dir / sub_dir + dir.mkdir(exist_ok=True) + file_path = dir / f"{file_name}.png" + fig.savefig(file_path) -def _checkpoint_copy(cfg: DictConfig, sub_dir: str, file_name: str, epoch: int) -> None: - src_path = os.path.join(cfg.path.plot_dir, sub_dir, f"{file_name}.png") +def _checkpoint_copy( + plot_dir: Path, checkpoint_plot_dir: Path, sub_dir: str, file_name: str, epoch: int +) -> None: + src_path = plot_dir / sub_dir / f"{file_name}.png" - dest_dir = os.path.join(cfg.path.checkpoint_plot_dir, f"epoch_{epoch}", sub_dir) - os.makedirs(dest_dir, exist_ok=True) - dest_path = os.path.join(dest_dir, f"{file_name}.png") + dest_dir = checkpoint_plot_dir / f"epoch_{epoch}" / sub_dir + dest_dir.mkdir(parents=True, exist_ok=True) + dest_path = dest_dir / f"{file_name}.png" shutil.copy2(src_path, dest_path) @@ -278,19 +406,21 @@ def capitalize_part(part: str) -> str: def _process_figure( - cfg: DictConfig, + use_wandb: bool, + plot_dir: Path, + checkpoint_plot_dir: Path, copy_checkpoint: bool, fig: Figure, sub_dir: str, file_name: str, epoch: int, ) -> None: - if cfg.logging.use_wandb: + if use_wandb: title = f"{_wandb_title(sub_dir)}/{_wandb_title(file_name)}" img = wandb.Image(fig) wandb.log({title: img}, commit=False) else: - _save_figure(cfg, sub_dir, file_name, fig) 
+ _save_figure(plot_dir, sub_dir, file_name, fig) if copy_checkpoint: - _checkpoint_copy(cfg, sub_dir, file_name, epoch) + _checkpoint_copy(plot_dir, checkpoint_plot_dir, sub_dir, file_name, epoch) plt.close(fig) diff --git a/runner/frameworks/classification/initialize.py b/runner/frameworks/classification/initialize.py index 5956c267..d3203d68 100644 --- a/runner/frameworks/classification/initialize.py +++ b/runner/frameworks/classification/initialize.py @@ -2,7 +2,9 @@ ### Imports ### import logging -import os +from dataclasses import dataclass +from os import getenv +from pathlib import Path from typing import Any, Dict, List, Tuple, cast import omegaconf @@ -15,68 +17,112 @@ from retinal_rl.models.brain import Brain from runner.util import save_checkpoint +### Infrastructure ### + + # Initialize the logger logger = logging.getLogger(__name__) +@dataclass +class InitConfig: + """Configuration for initialization.""" + + # Paths + data_dir: Path + checkpoint_dir: Path + plot_dir: Path + wandb_dir: Path + + # WandB settings + use_wandb: bool + wandb_project: str + wandb_entity: str | None + wandb_preempt: bool + + # Run settings + run_name: str + max_checkpoints: int + + @classmethod + def from_dict_config(cls, cfg: DictConfig) -> "InitConfig": + """Create InitConfig from a DictConfig.""" + return cls( + data_dir=Path(cfg.path.data_dir), + checkpoint_dir=Path(cfg.path.checkpoint_dir), + plot_dir=Path(cfg.path.plot_dir), + wandb_dir=Path(cfg.path.wandb_dir), + use_wandb=cfg.logging.use_wandb, + wandb_project=cfg.logging.wandb_project, + wandb_entity=None + if cfg.logging.wandb_entity == "default" + else cfg.logging.wandb_entity, + wandb_preempt=cfg.logging.wandb_preempt, + run_name=cfg.run_name, + max_checkpoints=cfg.logging.max_checkpoints, + ) + + +### Initialization ### + + def initialize( - cfg: DictConfig, + dict_cfg: DictConfig, brain: Brain, optimizer: Optimizer, ) -> Tuple[Brain, Optimizer, Dict[str, List[float]], int]: """Initialize the Brain, Optimizers, 
and training histories. Checks whether the experiment directory exists and loads the model and history if it does. Otherwise, initializes a new model and history.""" - wandb_sweep_id = os.getenv("WANDB_SWEEP_ID", "local") + + cfg = InitConfig.from_dict_config(dict_cfg) + wandb_sweep_id = getenv("WANDB_SWEEP_ID", "local") logger.info(f"Run Name: {cfg.run_name}") logger.info(f"(WANDB) Sweep ID: {wandb_sweep_id}") # If continuing from a previous run, load the model and history - if os.path.exists(cfg.path.data_dir): + if cfg.data_dir.exists(): return _initialize_reload(cfg, brain, optimizer) # else, initialize a new model and history + logger.info( + f"Experiment data path {cfg.data_dir} does not exist. Initializing {cfg.run_name}." + ) return _initialize_create(cfg, brain, optimizer) def _initialize_create( - cfg: DictConfig, + cfg: InitConfig, brain: Brain, optimizer: Optimizer, ) -> Tuple[Brain, Optimizer, Dict[str, List[float]], int]: epoch = 0 - logger.info( - f"Experiment path {cfg.path.run_dir} does not exist. Initializing {cfg.run_name}." 
- ) - # initialize the training histories histories: Dict[str, List[float]] = {} - # create the directories - os.makedirs(cfg.path.data_dir) - os.makedirs(cfg.path.checkpoint_dir) - if not cfg.logging.use_wandb: - os.makedirs(cfg.path.plot_dir) - + cfg.data_dir.mkdir(parents=True, exist_ok=True) + cfg.checkpoint_dir.mkdir(parents=True, exist_ok=True) + if not cfg.use_wandb: + cfg.plot_dir.mkdir(parents=True, exist_ok=True) else: - os.makedirs(cfg.path.wandb_dir) + cfg.wandb_dir.mkdir(parents=True, exist_ok=True) # convert DictConfig to dict dict_conf = omegaconf.OmegaConf.to_container( cfg, resolve=True, throw_on_missing=True ) dict_conf = cast(Dict[str, Any], dict_conf) - entity = cfg.logging.wandb_entity + entity = cfg.wandb_entity if entity == "default": entity = None wandb.init( - project=cfg.logging.wandb_project, + project=cfg.wandb_project, entity=entity, group=HydraConfig.get().runtime.choices.experiment, job_type=HydraConfig.get().runtime.choices.brain, config=dict_conf, name=cfg.run_name, id=cfg.run_name, - dir=cfg.path.wandb_dir, + dir=cfg.wandb_dir, ) - if cfg.logging.wandb_preempt: + if cfg.wandb_preempt: wandb.mark_preempting() wandb.define_metric("Epoch") @@ -84,9 +130,9 @@ def _initialize_create( wandb.define_metric("Test/*", step_metric="Epoch") save_checkpoint( - cfg.path.data_dir, - cfg.path.checkpoint_dir, - cfg.logging.max_checkpoints, + cfg.data_dir, + cfg.checkpoint_dir, + cfg.max_checkpoints, brain, optimizer, histories, @@ -97,15 +143,15 @@ def _initialize_create( def _initialize_reload( - cfg: DictConfig, brain: Brain, optimizer: Optimizer + cfg: InitConfig, brain: Brain, optimizer: Optimizer ) -> Tuple[Brain, Optimizer, Dict[str, List[float]], int]: logger.info( - f"Experiment dir {cfg.path.run_dir} exists. Loading existing model and history." + f"Experiment data dir {cfg.data_dir} exists. Loading existing model and history." 
) - checkpoint_file = os.path.join(cfg.path.data_dir, "current_checkpoint.pt") + checkpoint_file = cfg.data_dir / "current_checkpoint.pt" # check if files don't exist - if not os.path.exists(checkpoint_file): + if not checkpoint_file.exists(): logger.error(f"File not found: {checkpoint_file}") raise FileNotFoundError("Checkpoint file does not exist.") @@ -115,22 +161,22 @@ def _initialize_reload( optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) completed_epochs = checkpoint["completed_epochs"] history = checkpoint["histories"] - entity = cfg.logging.wandb_entity + entity = cfg.wandb_entity if entity == "default": entity = None - if cfg.logging.use_wandb: + if cfg.use_wandb: wandb.init( - project=cfg.logging.wandb_project, + project=cfg.wandb_project, entity=entity, group=HydraConfig.get().runtime.choices.experiment, job_type=HydraConfig.get().runtime.choices.brain, name=cfg.run_name, id=cfg.run_name, resume="must", - dir=cfg.path.wandb_dir, + dir=cfg.wandb_dir, ) - if cfg.logging.wandb_preempt: + if cfg.wandb_preempt: wandb.mark_preempting() return brain, optimizer, history, completed_epochs diff --git a/runner/util.py b/runner/util.py index fbfea1b8..1ae4e4f0 100644 --- a/runner/util.py +++ b/runner/util.py @@ -5,6 +5,7 @@ import logging import os import shutil +from pathlib import Path from typing import Any, Dict, List, Tuple, cast import networkx as nx @@ -25,8 +26,8 @@ def save_checkpoint( - data_dir: str, - checkpoint_dir: str, + data_dir: Path, + checkpoint_dir: Path, max_checkpoints: int, brain: nn.Module, optimizer: Optimizer, @@ -34,8 +35,8 @@ def save_checkpoint( completed_epochs: int, ) -> None: """Save a checkpoint of the model and optimizer state.""" - current_file = os.path.join(data_dir, "current_checkpoint.pt") - checkpoint_file = os.path.join(checkpoint_dir, f"epoch_{completed_epochs}.pt") + current_file = data_dir / "current_checkpoint.pt" + checkpoint_file = checkpoint_dir / f"epoch_{completed_epochs}.pt" checkpoint_dict: Dict[str, 
Any] = { "completed_epochs": completed_epochs, "brain_state_dict": brain.state_dict(), @@ -59,11 +60,10 @@ def save_checkpoint( os.remove(os.path.join(checkpoint_dir, checkpoints.pop())) -def delete_results(cfg: DictConfig) -> None: +def delete_results(run_dir: Path) -> None: """Delete the results directory.""" - run_dir: str = cfg.path.run_dir - if not os.path.exists(run_dir): + if not run_dir.exists(): print(f"Directory {run_dir} does not exist.") return From 2df03373731deddb5d00bf915386dde1912286e8 Mon Sep 17 00:00:00 2001 From: alex404 Date: Thu, 31 Oct 2024 12:36:24 +0100 Subject: [PATCH 4/9] Added necessary pathlibbing to train --- runner/frameworks/classification/train.py | 31 ++++++++++++++++------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/runner/frameworks/classification/train.py b/runner/frameworks/classification/train.py index dcf187c6..1c52195e 100644 --- a/runner/frameworks/classification/train.py +++ b/runner/frameworks/classification/train.py @@ -2,6 +2,7 @@ import logging import time +from pathlib import Path from typing import Dict, List import torch @@ -47,11 +48,23 @@ def train( history (Dict[str, List[float]]): The training history. """ + + use_wandb = cfg.logging.use_wandb + + data_dir = Path(cfg.path.data_dir) + checkpoint_dir = Path(cfg.path.checkpoint_dir) + + max_checkpoints = cfg.logging.max_checkpoints + checkpoint_step = cfg.logging.checkpoint_step + + num_epochs = cfg.optimizer.num_epochs + num_workers = cfg.system.num_workers + trainloader = DataLoader( - train_set, batch_size=64, shuffle=True, num_workers=cfg.system.num_workers + train_set, batch_size=64, shuffle=True, num_workers=num_workers ) testloader = DataLoader( - test_set, batch_size=64, shuffle=False, num_workers=cfg.system.num_workers + test_set, batch_size=64, shuffle=False, num_workers=num_workers ) wall_time = time.time() @@ -103,7 +116,7 @@ def train( wall_time = new_wall_time logger.info(f"Initialization complete. 
Wall Time: {epoch_wall_time:.2f}s.") - if cfg.logging.use_wandb: + if use_wandb: _wandb_log_statistics(initial_epoch, epoch_wall_time, history) else: @@ -111,7 +124,7 @@ def train( f"Reloading complete. Resuming training from epoch {initial_epoch}." ) - for epoch in range(initial_epoch + 1, cfg.optimizer.num_epochs + 1): + for epoch in range(initial_epoch + 1, num_epochs + 1): brain, history = run_epoch( device, brain, @@ -128,13 +141,13 @@ def train( wall_time = new_wall_time logger.info(f"Epoch {epoch} complete. Wall Time: {epoch_wall_time:.2f}s.") - if epoch % cfg.logging.checkpoint_step == 0: + if epoch % checkpoint_step == 0: logger.info("Saving checkpoint and plots.") save_checkpoint( - cfg.path.data_dir, - cfg.path.checkpoint_dir, - cfg.logging.max_checkpoints, + data_dir, + checkpoint_dir, + max_checkpoints, brain, optimizer, history, @@ -154,7 +167,7 @@ def train( ) logger.info("Analysis complete.") - if cfg.logging.use_wandb: + if use_wandb: _wandb_log_statistics(epoch, epoch_wall_time, history) From 661f78f0f29c663945fa16fafca360d98d3de678 Mon Sep 17 00:00:00 2001 From: alex404 Date: Thu, 31 Oct 2024 13:07:59 +0100 Subject: [PATCH 5/9] Fixed wandb initialization --- runner/frameworks/classification/initialize.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/runner/frameworks/classification/initialize.py b/runner/frameworks/classification/initialize.py index d3203d68..4213eec1 100644 --- a/runner/frameworks/classification/initialize.py +++ b/runner/frameworks/classification/initialize.py @@ -85,11 +85,18 @@ def initialize( logger.info( f"Experiment data path {cfg.data_dir} does not exist. Initializing {cfg.run_name}." 
) - return _initialize_create(cfg, brain, optimizer) + + cfg_backup = omegaconf.OmegaConf.to_container( + dict_cfg, resolve=True, throw_on_missing=True + ) + cfg_backup = cast(Dict[str, Any], cfg_backup) + + return _initialize_create(cfg, cfg_backup, brain, optimizer) def _initialize_create( cfg: InitConfig, + cfg_backup: dict[Any, Any], brain: Brain, optimizer: Optimizer, ) -> Tuple[Brain, Optimizer, Dict[str, List[float]], int]: @@ -104,10 +111,6 @@ def _initialize_create( else: cfg.wandb_dir.mkdir(parents=True, exist_ok=True) # convert DictConfig to dict - dict_conf = omegaconf.OmegaConf.to_container( - cfg, resolve=True, throw_on_missing=True - ) - dict_conf = cast(Dict[str, Any], dict_conf) entity = cfg.wandb_entity if entity == "default": entity = None @@ -116,7 +119,7 @@ def _initialize_create( entity=entity, group=HydraConfig.get().runtime.choices.experiment, job_type=HydraConfig.get().runtime.choices.brain, - config=dict_conf, + config=cfg_backup, name=cfg.run_name, id=cfg.run_name, dir=cfg.wandb_dir, From f95d7e1d5679ca3428a930bfaec990d6a0bc9798 Mon Sep 17 00:00:00 2001 From: alex404 Date: Thu, 31 Oct 2024 15:08:32 +0100 Subject: [PATCH 6/9] Forgot to save reconstructions too --- runner/frameworks/classification/analyze.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/runner/frameworks/classification/analyze.py b/runner/frameworks/classification/analyze.py index 5fb0e792..1e0b86dc 100644 --- a/runner/frameworks/classification/analyze.py +++ b/runner/frameworks/classification/analyze.py @@ -347,6 +347,11 @@ def _perform_reconstruction_analysis( rec_dict = asdict( reconstruct_images(device, brain, decoder, train_set, test_set, 5) ) + # Save the reconstructions + rec_path = plot_dir / f"{decoder}_reconstructions_epoch_{epoch}.json" + with open(rec_path, "w") as f: + json.dump(rec_dict, f, cls=NumpyEncoder) + recon_fig = plot_reconstructions( norm_means, norm_stds, From cddd8981cd8f64d78b06d4f129ce2a2de73527ec Mon Sep 17 00:00:00 2001 From: alex404 
Date: Thu, 31 Oct 2024 15:14:49 +0100 Subject: [PATCH 7/9] For real this time --- runner/frameworks/classification/analyze.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/runner/frameworks/classification/analyze.py b/runner/frameworks/classification/analyze.py index 1e0b86dc..013ba78a 100644 --- a/runner/frameworks/classification/analyze.py +++ b/runner/frameworks/classification/analyze.py @@ -105,8 +105,8 @@ def analyze( if epoch == 0: _perform_initialization_analysis( channel_analysis, - analyses_dir, use_wandb, + analyses_dir, plot_dir, checkpoint_plot_dir, run_dir, @@ -128,6 +128,7 @@ def analyze( _perform_reconstruction_analysis( use_wandb, + analyses_dir, plot_dir, checkpoint_plot_dir, device, @@ -152,8 +153,8 @@ def _plot_and_save_histories(plot_dir: Path, histories: Dict[str, List[float]]): def _perform_initialization_analysis( channel_analysis: bool, - analyses_dir: Path, use_wandb: bool, + analyses_dir: Path, plot_dir: Path, checkpoint_plot_dir: Path, run_dir: Path, @@ -326,6 +327,7 @@ def _analyze_regular_layer( def _perform_reconstruction_analysis( use_wandb: bool, + analyses_dir: Path, plot_dir: Path, checkpoint_plot_dir: Path, device: torch.device, @@ -348,7 +350,7 @@ def _perform_reconstruction_analysis( reconstruct_images(device, brain, decoder, train_set, test_set, 5) ) # Save the reconstructions - rec_path = plot_dir / f"{decoder}_reconstructions_epoch_{epoch}.json" + rec_path = analyses_dir / f"{decoder}_reconstructions_epoch_{epoch}.json" with open(rec_path, "w") as f: json.dump(rec_dict, f, cls=NumpyEncoder) From cf9714812c5a4ade407adc2fb013efa646a34d17 Mon Sep 17 00:00:00 2001 From: Sacha Sokoloski Date: Wed, 6 Nov 2024 11:54:51 +0100 Subject: [PATCH 8/9] Review changes --- retinal_rl/analysis/plot.py | 41 ++++++++++++++++--------------- retinal_rl/analysis/statistics.py | 36 +++++++++++++++------------ 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/retinal_rl/analysis/plot.py 
b/retinal_rl/analysis/plot.py index ff25e690..6341a337 100644 --- a/retinal_rl/analysis/plot.py +++ b/retinal_rl/analysis/plot.py @@ -19,6 +19,27 @@ from retinal_rl.util import FloatArray +def make_image_grid(arrays: List[FloatArray], nrow: int) -> FloatArray: + """Create a grid of images from a list of numpy arrays.""" + # Assuming arrays are [C, H, W] + n = len(arrays) + if not n: + return np.array([]) + + ncol = nrow + nrow = (n - 1) // ncol + 1 + + nchns, hght, wdth = arrays[0].shape + grid = np.zeros((nchns, hght * nrow, wdth * ncol)) + + for idx, array in enumerate(arrays): + i = idx // ncol + j = idx % ncol + grid[:, i * hght : (i + 1) * hght, j * wdth : (j + 1) * wdth] = array + + return grid + + def plot_transforms( source_transforms: Dict[str, Dict[float, List[FloatArray]]], noise_transforms: Dict[str, Dict[float, List[FloatArray]]], @@ -47,26 +68,6 @@ def plot_transforms( transform_index = 0 - def make_image_grid(arrays: List[FloatArray], nrow: int) -> FloatArray: - """Create a grid of images from a list of numpy arrays.""" - # Assuming arrays are [C, H, W] - n = len(arrays) - if not n: - return np.array([]) - - ncol = nrow - nrow = (n - 1) // ncol + 1 - - nchns, hght, wdth = arrays[0].shape - grid = np.zeros((nchns, hght * nrow, wdth * ncol)) - - for idx, array in enumerate(arrays): - i = idx // ncol - j = idx % ncol - grid[:, i * hght : (i + 1) * hght, j * wdth : (j + 1) * wdth] = array - - return grid - # Plot source transforms for transform_name, transform_data in source_transforms.items(): ax = axs[transform_index] diff --git a/retinal_rl/analysis/statistics.py b/retinal_rl/analysis/statistics.py index 3aa9c59b..5daf8cce 100644 --- a/retinal_rl/analysis/statistics.py +++ b/retinal_rl/analysis/statistics.py @@ -102,32 +102,36 @@ def transform_base_images( src, _ = base_dataset[np.random.randint(base_len)] images.append(src) - results: Dict[str, Dict[str, Dict[float, List[FloatArray]]]] = { - "source_transforms": {}, - "noise_transforms": {}, - } + 
results = TransformStatistics( + source_transforms={}, + noise_transforms={}, + ) - transforms: List[Tuple[str, nn.Module]] = [] - transforms += [ - ("source_transforms", transform) for transform in imageset.source_transforms - ] - transforms += [ - ("noise_transforms", transform) for transform in imageset.noise_transforms - ] + for transform in imageset.source_transforms: + if isinstance(transform, ContinuousTransform): + results.source_transforms[transform.name] = {} + trans_range: Tuple[float, float] = transform.trans_range + transform_steps = np.linspace(*trans_range, num_steps) + for step in transform_steps: + results.source_transforms[transform.name][step] = [] + for img in images: + results.source_transforms[transform.name][step].append( + imageset.to_tensor(transform.transform(img, step)).cpu().numpy() + ) - for category, transform in transforms: + for transform in imageset.noise_transforms: if isinstance(transform, ContinuousTransform): - results[category][transform.name] = {} + results.noise_transforms[transform.name] = {} trans_range: Tuple[float, float] = transform.trans_range transform_steps = np.linspace(*trans_range, num_steps) for step in transform_steps: - results[category][transform.name][step] = [] + results.noise_transforms[transform.name][step] = [] for img in images: - results[category][transform.name][step].append( + results.noise_transforms[transform.name][step].append( imageset.to_tensor(transform.transform(img, step)).cpu().numpy() ) - return TransformStatistics(**results) + return results def reconstruct_images( From 96e44c3200cec62e7be06da866f8b8312871e925 Mon Sep 17 00:00:00 2001 From: Sacha Sokoloski Date: Thu, 7 Nov 2024 16:49:29 +0100 Subject: [PATCH 9/9] I think this is correct, but need to check on a machine with a gpu --- retinal_rl/analysis/statistics.py | 46 ++++++++++++++----------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/retinal_rl/analysis/statistics.py b/retinal_rl/analysis/statistics.py index 
5daf8cce..ca71acc0 100644 --- a/retinal_rl/analysis/statistics.py +++ b/retinal_rl/analysis/statistics.py @@ -102,36 +102,30 @@ def transform_base_images( src, _ = base_dataset[np.random.randint(base_len)] images.append(src) - results = TransformStatistics( + resultss = TransformStatistics( source_transforms={}, noise_transforms={}, ) - for transform in imageset.source_transforms: - if isinstance(transform, ContinuousTransform): - results.source_transforms[transform.name] = {} - trans_range: Tuple[float, float] = transform.trans_range - transform_steps = np.linspace(*trans_range, num_steps) - for step in transform_steps: - results.source_transforms[transform.name][step] = [] - for img in images: - results.source_transforms[transform.name][step].append( - imageset.to_tensor(transform.transform(img, step)).cpu().numpy() - ) - - for transform in imageset.noise_transforms: - if isinstance(transform, ContinuousTransform): - results.noise_transforms[transform.name] = {} - trans_range: Tuple[float, float] = transform.trans_range - transform_steps = np.linspace(*trans_range, num_steps) - for step in transform_steps: - results.noise_transforms[transform.name][step] = [] - for img in images: - results.noise_transforms[transform.name][step].append( - imageset.to_tensor(transform.transform(img, step)).cpu().numpy() - ) - - return results + for transforms, results in [ + (imageset.source_transforms, resultss.source_transforms), + (imageset.noise_transforms, resultss.noise_transforms), + ]: + for transform in transforms: + if isinstance(transform, ContinuousTransform): + results[transform.name] = {} + trans_range: Tuple[float, float] = transform.trans_range + transform_steps = np.linspace(*trans_range, num_steps) + for step in transform_steps: + results[transform.name][step] = [] + for img in images: + results[transform.name][step].append( + imageset.to_tensor(transform.transform(img, step)) + .cpu() + .numpy() + ) + + return resultss def reconstruct_images(