Refactor core function, Introduce test suite, pylint config and CI #2

Merged
merged 6 commits into from
Jun 5, 2024
26 changes: 26 additions & 0 deletions .github/workflows/lint.yml
@@ -0,0 +1,26 @@
name: pylint

on:
  pull_request:
  workflow_dispatch:

jobs:
  checks:
    runs-on: ubuntu-20.04
    strategy:
      max-parallel: 4
      matrix:
        python-version: [3.7, 3.9]

    steps:
      - uses: actions/checkout@v1
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install tox
      - name: Check lint
        run: tox -e py$(echo ${{ matrix.python-version }} | tr -d .)-lint
26 changes: 26 additions & 0 deletions .github/workflows/tox.yml
@@ -0,0 +1,26 @@
name: tox

on:
  pull_request:
  workflow_dispatch:

jobs:
  checks:
    runs-on: ubuntu-20.04
    strategy:
      max-parallel: 4
      matrix:
        python-version: [3.7, 3.9]

    steps:
      - uses: actions/checkout@v1
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install tox
      - name: Test with tox
        run: tox -e py$(echo ${{ matrix.python-version }} | tr -d .)
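
Note on the two workflows: both derive the tox environment name from the matrix entry by stripping the dot from the Python version (`tr -d .`) and, for lint, appending `-lint`. A minimal Python sketch of that mapping is below. `tox_env_name` is a hypothetical helper written only for illustration; it is not part of this PR.

```python
def tox_env_name(python_version: str, suffix: str = "") -> str:
    """Mirror `py$(echo <version> | tr -d .)<suffix>` from the workflow steps."""
    return "py" + python_version.replace(".", "") + suffix


# The two matrix entries used by the workflows in this PR.
for version in ("3.7", "3.9"):
    print(tox_env_name(version), tox_env_name(version, "-lint"))
```

If the matrix grows, it may be worth quoting the versions: an unquoted `3.10` is read by YAML as the float `3.1`, which would resolve to the wrong environment name.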
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
# pre-compiled spectrum and models
*.npy
*.h5

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
7 changes: 5 additions & 2 deletions .pylintrc
@@ -6,7 +6,7 @@ disable=
E1123, # issues between pylint and tensorflow since 2.2.0
E1120, # see pylint#3613
C3001, # lambda function as variable

C0116, C0114, # docstring
[FORMAT]
max-line-length=100
max-args=12
@@ -15,4 +15,7 @@ max-args=12
min-similarity-lines=6
ignore-comments=yes
ignore-docstrings=yes
ignore-imports=no
ignore-imports=no

[TYPECHECK]
ignored-modules=torch
21 changes: 18 additions & 3 deletions README.md
@@ -1,10 +1,13 @@
# Horama: A Compact Library for Feature Visualization Experiments

@todo add notebooks
@todo add illustration images + logo

Horama provides the implementation code for the research paper:

- *Unlocking Feature Visualization for Deeper Networks with MAgnitude Constrained Optimization* by Thomas Fel*, Thibaut Boissin*, Victor Boutin*, Agustin Picard*, Paul Novello*, Julien Colin, Drew Linsley, Tom Rousseau, Rémi Cadène, Laurent Gardes, Thomas Serre. [Read the paper on arXiv](https://arxiv.org/abs/2211.10154).

In addition, this repository introduces various feature visualization methods, including a reimagined approach to the [incredible work of the Clarity team](https://distill.pub/2017/feature-visualization/) and an implementation of [Feature Accentuation](https://arxiv.org/abs/2402.10039) from Hamblin & al. For an official reproduction of distill's work complete with comprehensive notebooks, we highly recommend Lucent. However, Horama focuses on experimentation within PyTorch, offering a compact and modifiable codebase.
In addition, this repository introduces various feature visualization methods, including a reimagined approach to the [incredible work of the Clarity team](https://distill.pub/2017/feature-visualization/) and an implementation of [Feature Accentuation](https://arxiv.org/abs/2402.10039) from Hamblin & al. For an official reproduction of distill's work complete with comprehensive notebooks, we highly recommend [Lucent](https://github.com/greentfrapp/lucent). However, Horama focuses on **experimentation** within PyTorch, offering a compact and easily hackable codebase.

# 🚀 Getting Started with Horama

@@ -31,11 +34,21 @@ objective = lambda images: torch.mean(model(images)[:, 1])

image1, alpha1 = maco(objective)
plot_maco(image1, alpha1)
plt.show()

image2, alpha2 = fourier(objective)
plot_maco(image2, alpha2)
plt.show()
```

# Notebooks

@todo: fourier, maco for various models on timm
@todo: cossim vs logits
@todo: speedup process, what parameters to change
@todo: feature inversion
@todo: feature accentuation

# Complete API

Complete API Guide
@@ -76,15 +89,17 @@ When optimizing, it's crucial to fine-tune the hyperparameters. Parameters like
@article{fel2023maco,
title={Unlocking Feature Visualization for Deeper Networks with MAgnitude Constrained Optimization},
author={Fel, Thomas and Boissin, Thibaut and Boutin, Victor and Picard, Agustin and Novello, Paul and Colin, Julien and Linsley, Drew and Rousseau, Tom and Cadène, Rémi and Gardes, Laurent and Serre, Thomas},
journal={Advances in Neural Information Processing Systems (NeurIPS)},
year={2023},
}
```

# Additional Resources
For a simpler and maintenance-friendly implementation for TensorFlow and more on feature visualization methods, check out the Xplique toolbox.

A simpler, maintained implementation of the code for TensorFlow, along with the other feature visualization methods used in the paper, is available in the [Xplique toolbox](https://github.com/deel-ai/xplique). Additionally, we have created a website called the [LENS Project](https://github.com/serre-lab/Lens), which features the 1000 classes of ImageNet.

# Authors of the code
For a code faithful to the original work of the Clarity team, we highly recommend [Lucent](https://github.com/greentfrapp/lucent).

# Authors

- [Thomas Fel](https://thomasfel.fr) - [email protected], PhD Student DEEL (ANITI), Brown University
2 changes: 1 addition & 1 deletion horama/__init__.py
@@ -14,4 +14,4 @@
from .maco_fv import maco
from .fourier_fv import fourier
from .plots import plot_maco
from .losses import dot_cossim
from .losses import dot_cossim
21 changes: 14 additions & 7 deletions horama/common.py
@@ -1,13 +1,17 @@
import torch
from torchvision.ops import roi_align


def standardize(tensor):
    # standardizes the tensor to have 0 mean and unit variance
    tensor = tensor - torch.mean(tensor)
    tensor = tensor / (torch.std(tensor) + 1e-4)
    return tensor
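
Note on `standardize`: it centers the values and divides by the standard deviation plus a small epsilon, so a constant input does not divide by zero. A minimal standard-library sketch of the same arithmetic is below. `standardize_list` is a hypothetical name, and the sketch works on a flat list rather than a tensor; `statistics.stdev` is the sample (n - 1) estimator, which matches the default behaviour of `torch.std`.

```python
import statistics


def standardize_list(values, eps=1e-4):
    """Center a flat list and scale it by (sample std + eps)."""
    mean = statistics.fmean(values)
    centered = [v - mean for v in values]
    std = statistics.stdev(centered)  # sample standard deviation (n - 1)
    return [c / (std + eps) for c in centered]


print(standardize_list([1.0, 2.0, 3.0]))
```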


def recorrelate_colors(image, device):
    # recorrelates the colors of the images
    assert len(image.shape) == 3

    # tensor for color correlation svd square root
    color_correlation_svd_sqrt = torch.tensor(
@@ -17,9 +21,6 @@ def recorrelate_colors(image, device):
        dtype=torch.float32
    ).to(device)

    # recorrelates the colors of the images
    assert len(image.shape) == 3

    permuted_image = image.permute(1, 2, 0).contiguous()
    flat_image = permuted_image.view(-1, 3)

@@ -28,8 +29,11 @@ def recorrelate_colors(image, device):

    return recorrelated_image

def optimization_step(objective_function, image, box_size, noise_level, number_of_crops_per_iteration, model_input_size):

def optimization_step(objective_function, image, box_size, noise_level,
                      number_of_crops_per_iteration, model_input_size):
    # performs an optimization step on the generated image
    # pylint: disable=C0103
    assert box_size[1] >= box_size[0]
    assert len(image.shape) == 3

@@ -39,7 +43,8 @@ def optimization_step(objective_function, image, box_size, noise_level, number_o
    # generate random boxes
    x0 = 0.5 + torch.randn((number_of_crops_per_iteration,), device=device) * 0.15
    y0 = 0.5 + torch.randn((number_of_crops_per_iteration,), device=device) * 0.15
    delta_x = torch.rand((number_of_crops_per_iteration,), device=device) * (box_size[1] - box_size[0]) + box_size[1]
    delta_x = torch.rand((number_of_crops_per_iteration,),
                         device=device) * (box_size[1] - box_size[0]) + box_size[1]
    delta_y = delta_x

    boxes = torch.stack([torch.zeros((number_of_crops_per_iteration,), device=device),
@@ -48,11 +53,13 @@ def optimization_step(objective_function, image, box_size, noise_level, number_o
                         x0 + delta_x * 0.5,
                         y0 + delta_y * 0.5], dim=1) * image.shape[1]

    cropped_and_resized_images = roi_align(image.unsqueeze(0), boxes, output_size=(model_input_size, model_input_size)).squeeze(0)
    cropped_and_resized_images = roi_align(image.unsqueeze(
        0), boxes, output_size=(model_input_size, model_input_size)).squeeze(0)

    # add normal and uniform noise for better robustness
    cropped_and_resized_images.add_(torch.randn_like(cropped_and_resized_images) * noise_level)
    cropped_and_resized_images.add_((torch.rand_like(cropped_and_resized_images) - 0.5) * noise_level)
    cropped_and_resized_images.add_(
        (torch.rand_like(cropped_and_resized_images) - 0.5) * noise_level)

    # compute the score and loss
    score = objective_function(cropped_and_resized_images)
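
Note on the crop-size line in `optimization_step`: as written, `torch.rand(...) * (box_size[1] - box_size[0]) + box_size[1]` draws sizes from `[box_size[1], 2 * box_size[1] - box_size[0])` rather than from `[box_size[0], box_size[1])`. The sketch below reproduces both expressions with the standard library's `random` in place of torch. Whether the offset was meant to be `box_size[0]` is an assumption, not something this PR states, and `sample_crop_size` is a hypothetical name.

```python
import random


def sample_crop_size(box_size, offset, rng):
    """Draw one relative crop size: uniform width (high - low), shifted by offset."""
    low, high = box_size
    return rng.random() * (high - low) + offset


rng = random.Random(0)
box_size = (0.20, 0.25)  # default used by fourier() in this PR

# Offset as it appears in the diff (box_size[1]).
as_written = [sample_crop_size(box_size, box_size[1], rng) for _ in range(1000)]
# Hypothetical alternative offset (box_size[0]), shown only for comparison.
alternative = [sample_crop_size(box_size, box_size[0], rng) for _ in range(1000)]

print(min(as_written), max(as_written))
print(min(alternative), max(alternative))
```

This refactor leaves the expression unchanged, so the behaviour is the same before and after the PR; the note only describes which range the current code samples.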
39 changes: 26 additions & 13 deletions horama/fourier_fv.py
@@ -1,7 +1,9 @@
import torch
from .common import standardize, recorrelate_colors, optimization_step
from tqdm import tqdm

from .common import standardize, recorrelate_colors, optimization_step


def fft_2d_freq(width, height):
    # calculate the 2D frequency grid for FFT
    freq_y = torch.fft.fftfreq(height).unsqueeze(1)
@@ -11,21 +13,26 @@ def fft_2d_freq(width, height):

    return torch.sqrt(freq_x**2 + freq_y**2)


def get_fft_scale(width, height, decay_power=1.0):
    # generate the FFT scale based on the image size and decay power
    # generate the scaler that account for power decay in FFT space
    frequencies = fft_2d_freq(width, height)

    fft_scale = 1.0 / torch.maximum(frequencies, torch.tensor(1.0 / max(width, height))) ** decay_power
    fft_scale = 1.0 / torch.maximum(frequencies,
                                    torch.tensor(1.0 / max(width, height))) ** decay_power
    fft_scale = fft_scale * torch.sqrt(torch.tensor(width * height).float())

    return fft_scale.to(torch.complex64)

def init_olah_buffer(width, height, std=1.0):
    # initialize the Olah buffer with a random spectrum

def init_lucid_buffer(width, height, std=1.0):
    # initialize the buffer with a random spectrum a la Lucid
    spectrum_shape = (3, width, height // 2 + 1)
    random_spectrum = torch.complex(torch.randn(spectrum_shape) * std, torch.randn(spectrum_shape) * std)
    random_spectrum = torch.complex(torch.randn(spectrum_shape) * std,
                                    torch.randn(spectrum_shape) * std)
    return random_spectrum


def fourier_preconditionner(spectrum, spectrum_scaler, values_range, device):
    # precondition the Fourier spectrum and convert it to spatial domain
    assert spectrum.shape[0] == 3
@@ -37,16 +44,21 @@ def fourier_preconditionner(spectrum, spectrum_scaler, values_range, device):
    spatial_image = standardize(spatial_image)
    color_recorrelated_image = recorrelate_colors(spatial_image, device)

    image = torch.sigmoid(color_recorrelated_image) * (values_range[1] - values_range[0]) + values_range[0]
    image = torch.sigmoid(
        color_recorrelated_image) * (values_range[1] - values_range[0]) + values_range[0]
    return image

def fourier(objective_function, decay_power=1.5, total_steps=1000, learning_rate=1.0, image_size=1280, model_input_size=224,
            noise=0.05, values_range=(-2.5, 2.5), crops_per_iteration=6, box_size=(0.20, 0.25), device='cuda'):
    # perform the Olah optimization process

def fourier(
        objective_function, decay_power=1.5, total_steps=1000, learning_rate=1.0, image_size=1280,
        model_input_size=224, noise=0.05, values_range=(-2.5, 2.5),
        crops_per_iteration=6, box_size=(0.20, 0.25),
        device='cuda'):
    # perform the Lucid (Olah & al.) optimization process
    assert values_range[1] >= values_range[0]
    assert box_size[1] >= box_size[0]

    spectrum = init_olah_buffer(image_size, image_size, std=1.0)
    spectrum = init_lucid_buffer(image_size, image_size, std=1.0)
    spectrum_scaler = get_fft_scale(image_size, image_size, decay_power)

    spectrum = spectrum.to(device)
@@ -56,11 +68,12 @@ def fourier(objective_function, decay_power=1.5, total_steps=1000, learning_rate
    optimizer = torch.optim.NAdam([spectrum], lr=learning_rate)
    transparency_accumulator = torch.zeros((3, image_size, image_size)).to(device)

    for step in tqdm(range(total_steps)):
    for _ in tqdm(range(total_steps)):
        optimizer.zero_grad()

        image = fourier_preconditionner(spectrum, spectrum_scaler, values_range, device)
        loss, img = optimization_step(objective_function, image, box_size, noise, crops_per_iteration, model_input_size)
        loss, img = optimization_step(objective_function, image, box_size,
                                      noise, crops_per_iteration, model_input_size)
        loss.backward()
        transparency_accumulator += torch.abs(img.grad)
        optimizer.step()
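
Note on `get_fft_scale`: the scaler is `1 / max(f, 1 / max(width, height)) ** decay_power`, multiplied by `sqrt(width * height)`, so low frequencies get the largest weight and the floor keeps the zero-frequency term finite. A NumPy sketch of that formula is below. The diff hides how `freq_x` is built, so the half-axis `rfftfreq` grid here is an assumption, as is the `(height, width // 2 + 1)` layout; `fft_scale_sketch` is a hypothetical name.

```python
import numpy as np


def fft_scale_sketch(width, height, decay_power=1.0):
    """1 / max(f, 1/max(w, h)) ** decay_power, times sqrt(w * h)."""
    # Assumption: full frequency axis vertically, real-FFT half axis horizontally.
    freq_y = np.fft.fftfreq(height)[:, None]
    freq_x = np.fft.rfftfreq(width)[None, :]
    frequencies = np.sqrt(freq_x ** 2 + freq_y ** 2)

    # The floor avoids dividing by zero at the DC component.
    floor = 1.0 / max(width, height)
    scale = 1.0 / np.maximum(frequencies, floor) ** decay_power
    return scale * np.sqrt(width * height)


scale = fft_scale_sketch(8, 8, decay_power=1.0)
print(scale.shape)
print(scale[0, 0], scale[0, -1])  # DC weight vs. highest horizontal frequency
```

A larger `decay_power` steepens the falloff, which is why `fourier()` exposes it as a tunable parameter (default 1.5 in this PR).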
5 changes: 3 additions & 2 deletions horama/losses.py
@@ -1,14 +1,15 @@
import torch


def cosine_similarity(tensor_a, tensor_b):
    # calculate cosine similarity
    norm_dims = list(range(1, len(tensor_a.shape)))
    tensor_a = torch.nn.functional.normalize(tensor_a.float(), dim=norm_dims)
    tensor_b = torch.nn.functional.normalize(tensor_b.float(), dim=norm_dims)
    return torch.sum(tensor_a * tensor_b, dim=norm_dims)


def dot_cossim(tensor_a, tensor_b, cossim_pow=2.0):
    # compute dot product scaled by cosine similarity
    # see https://github.com/tensorflow/lucid/issues/116
    cosim = torch.clamp(cosine_similarity(tensor_a, tensor_b), min=1e-1) ** cossim_pow
    dot = torch.sum(tensor_a * tensor_b)
    return dot * cosim
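
Note on `dot_cossim`: the loss is the raw dot product multiplied by the cosine similarity, clamped below at 0.1 and raised to `cossim_pow`. A pure-Python sketch for a single pair of flat, non-zero vectors is below. It is only an illustration of the arithmetic: the real function works on batched tensors, computing the cosine per batch element and the dot product over the whole tensor. The `_sketch` names are hypothetical.

```python
import math


def cosine_similarity_sketch(a, b):
    """Cosine similarity of two flat, non-zero vectors."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def dot_cossim_sketch(a, b, cossim_pow=2.0):
    """Dot product scaled by the clamped cosine similarity raised to cossim_pow."""
    cosim = max(cosine_similarity_sketch(a, b), 1e-1) ** cossim_pow
    dot = sum(x * y for x, y in zip(a, b))
    return dot * cosim


print(dot_cossim_sketch([2.0, 0.0], [3.0, 0.0]))   # aligned vectors
print(dot_cossim_sketch([1.0, 1.0], [1.0, 0.0]))   # 45 degrees apart
print(dot_cossim_sketch([1.0, 0.0], [-1.0, 0.0]))  # opposite directions
```

The clamp matters for the last case: without it, squaring a negative cosine would make the scale factor positive and as large as for aligned vectors.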