Feature: Add new CLIP-IQA metric
StefanPetersTM committed Dec 20, 2024
1 parent 99f9ce2 commit 7229473
Showing 2 changed files with 87 additions and 2 deletions.
46 changes: 45 additions & 1 deletion lensless/eval/benchmark.py
@@ -8,12 +8,14 @@


from lensless.utils.dataset import DiffuserCamTestDataset
from lensless.utils.dataset import HFDataset
from lensless.utils.io import save_image
from waveprop.noise import add_shot_noise
from tqdm import tqdm
import os
import numpy as np
import wandb
from lensless.eval.metric import clip_iqa

try:
    import torch
@@ -102,6 +104,7 @@ def benchmark(
        ),
        "SSIM": StructuralSimilarityIndexMeasure(reduction=None, data_range=(0, 1)).to(device),
        "ReconstructionError": None,
        "CLIP-IQA": clip_iqa,
    }

    metrics_values = {key: [] for key in metrics}
@@ -241,6 +244,10 @@ def benchmark(
                metrics_values[metric].append(
                    metrics[metric](prediction, lensed).cpu().item()
                )
            elif metric == "CLIP-IQA":
                metrics_values[metric].append(
                    metrics[metric](prediction, lensed).cpu().item()
                )
            elif metric == "MSE":
                metrics_values[metric].append(
                    metrics[metric](prediction, lensed).cpu().item() * len(batch[0])
@@ -350,7 +357,44 @@ def benchmark(
device = "cpu"

    # prepare dataset
    dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample)
    # dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample)
    dataset = HFDataset(
        huggingface_repo='Lensless/TapeCam-Mirflickr-Ambient-100',
        cache_dir=None,
        psf='psf.png',
        single_channel_psf=False,
        split="test",
        display_res=[600, 600],
        rotate=False,
        flipud=False,
        flip_lensed=False,
        downsample=1,
        downsample_lensed=2,
        alignment={'top_left': [85, 185], 'height': 178},
        save_psf=True,
        n_files=None,
        simulation_config={
            'grayscale': False,
            'output_dim': None,
            'object_height': 0.04,
            'flip': True,
            'random_shift': False,
            'random_vflip': 0.5,
            'random_hflip': 0.5,
            'random_rotate': False,
            'scene2mask': 0.1,
            'mask2sensor': 0.009,
            'deadspace': True,
            'use_waveprop': False,
            'sensor': 'rgb',  # Replace with the correct value if different
        },
        per_pixel_color_shift=True,
        per_pixel_color_shift_range=[0.8, 1.2],
        bg_snr_range=None,
        bg_fp=None,
        force_rgb=False,
        simulate_lensless=False,
    )

    # prepare model
    psf = dataset.psf.to(device)
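For orientation, a rough sketch of how the updated benchmark could be invoked once a reconstruction model is available (the `model` variable and the dict-of-averages return value are assumptions for illustration, not taken from this commit):

from lensless.eval.benchmark import benchmark

# `model` is a placeholder for any reconstruction algorithm accepted by benchmark()
results = benchmark(model, dataset, batchsize=4)
print(results["CLIP-IQA"])  # assuming benchmark() returns per-metric averages keyed by name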
43 changes: 42 additions & 1 deletion lensless/eval/metric.py
@@ -112,10 +112,19 @@
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
import lpips as lpips_lib
import torch
import torch.nn.functional as F
from torchmetrics.multimodal import CLIPImageQualityAssessment
from scipy.ndimage import rotate
from lensless.utils.image import resize


# Initialize the CLIP-IQA model once at module load
clip_iqa_model = CLIPImageQualityAssessment(
    model_name_or_path="clip_iqa",
    prompts=("noisiness",),  # TODO: change if a different prompt/criterion is required
).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
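# NOTE: with a single prompt, this torchmetrics metric returns one score per image in [0, 1];
# with multiple prompts it returns a dict of score tensors keyed by prompt name.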


def mse(true, est, normalize=True):
    """
    Compute the mean-squared error between two images. The closer to 0, the
@@ -260,7 +269,6 @@ def lpips(true, est, normalize=True):
    )
    return loss_fn.forward(true, est).squeeze().item()


def extract(
    estimate, original, vertical_crop=None, horizontal_crop=None, rotation=0, verbose=False
):
@@ -329,3 +337,36 @@ def extract(
        print(img_resize.max())

    return estimate, img_resize


def clip_iqa(true, est, normalize=True):
    """
    Compute the CLIP Image Quality Assessment (CLIP-IQA) score of the estimated image.

    CLIP-IQA is a no-reference metric: the ground-truth image is accepted only to match
    the signature of the other metrics and is not used.

    Args:
        true (Tensor): The ground truth image tensor (unused).
        est (Tensor): The estimated image tensor of shape (N, C, H, W).
        normalize (bool, optional): If True, rescale the estimate to [0, 1] before
            computing the score. Default is True.

    Returns:
        Tensor: Scalar tensor with the CLIP-IQA score averaged over the batch.
    """
    with torch.no_grad():
        if normalize:
            est = est / est.max()

        # Resize images to 224x224, the input resolution expected by CLIP
        est_resized = F.interpolate(
            est, size=(224, 224), mode="bilinear", align_corners=False
        )

        # Move to the metric's device and average the per-image scores over the batch
        return clip_iqa_model(est_resized.to(clip_iqa_model.device)).mean()
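For reference, a minimal usage sketch of the new metric on its own (tensor shapes and values below are illustrative only):

import torch

from lensless.eval.metric import clip_iqa

# hypothetical batch of reconstructions and ground truths, shape (N, C, H, W), values in [0, 1]
prediction = torch.rand(4, 3, 270, 480)
lensed = torch.rand(4, 3, 270, 480)  # ignored: CLIP-IQA is a no-reference metric

score = clip_iqa(prediction, lensed)  # scalar tensor, mean CLIP-IQA over the batch
print(score.item())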
