Skip to content

Commit

Permalink
added code
Browse files Browse the repository at this point in the history
  • Loading branch information
ashesh-0 committed Mar 25, 2024
1 parent 95e9377 commit c7b50ea
Show file tree
Hide file tree
Showing 350 changed files with 47,778 additions and 0 deletions.
Binary file not shown.
Binary file added denoisplit/__pycache__/losses.cpython-39.pyc
Binary file not shown.
Binary file added denoisplit/__pycache__/training.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file added denoisplit/__pycache__/utils.cpython-39.pyc
Binary file not shown.
Binary file not shown.
9 changes: 9 additions & 0 deletions denoisplit/analysis/checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import glob


def get_best_checkpoint(ckpt_dir):
output = []
for filename in glob.glob(ckpt_dir + "/*_best.ckpt"):
output.append(filename)
assert len(output) == 1, '\n'.join(output)
return output[0]
110 changes: 110 additions & 0 deletions denoisplit/analysis/critic_notebook_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
Functions used in Critic notebooks
"""
import numpy as np
import torch

from denoisplit.core.model_type import ModelType
from denoisplit.core.psnr import PSNR, RangeInvariantPsnr


def _get_critic_prediction(pred: torch.Tensor, tar: torch.Tensor, D) -> dict:
"""
Given a predicted image and a target image, here we return a per sample prediction of
the critic regarding whether they belong to real or predicted images.
Args:
pred: predicted image
tar: target image
D: discriminator model
"""
pred_label = D(pred)
tar_label = D(tar)
pred_label = torch.sigmoid(pred_label)
tar_label = torch.sigmoid(tar_label)
N = len(pred_label)
pred_label = pred_label.view(N, -1)
tar_label = tar_label.view(N, -1)
return {
'generated': {
'mu': pred_label.mean(dim=1),
'std': pred_label.std(dim=1)
},
'target': {
'mu': tar_label.mean(dim=1),
'std': tar_label.std(dim=1)
}
}


def get_critic_prediction(model, pred_normalized, target_normalized):
pred1, pred2 = pred_normalized.chunk(2, dim=1)
tar1, tar2 = target_normalized.chunk(2, dim=1)
cpred_1 = _get_critic_prediction(pred1, tar1, model.D1)
cpred_2 = _get_critic_prediction(pred2, tar2, model.D2)
return cpred_1, cpred_2


def get_mmse_dict(model, x_normalized, target_normalized, mmse_count, model_type, psnr_type='range_invariant',
compute_kl_loss=False):
assert psnr_type in ['simple', 'range_invariant']
if psnr_type == 'simple':
psnr_fn = PSNR
else:
psnr_fn = RangeInvariantPsnr

img_mmse = 0
avg_logvar = None
assert mmse_count >= 1
for _ in range(mmse_count):
recon_normalized, td_data = model(x_normalized)
ll, dic = model.likelihood(recon_normalized, target_normalized)
recon_img = dic['mean']
img_mmse += recon_img / mmse_count
if model.predict_logvar:
if avg_logvar is None:
avg_logvar = 0
avg_logvar += dic['logvar'] / mmse_count

ll, dic = model.likelihood(recon_normalized, target_normalized)
mse = (img_mmse - target_normalized) ** 2
# batch and the two channels
N = np.prod(mse.shape[:2])
rmse = torch.sqrt(torch.mean(mse.view(N, -1), dim=1))
rmse = rmse.view(mse.shape[:2])
loss_mmse = model.likelihood.log_likelihood(target_normalized, {'mean': img_mmse, 'logvar': avg_logvar})
kl_loss = None
kl_loss_channelwise = None
if compute_kl_loss:
kl_loss = model.get_kl_divergence_loss(td_data).cpu().numpy()
resN = len(td_data['kl_channelwise'])
kl_loss_channelwise = [td_data['kl_channelwise'][i].detach().cpu().numpy() for i in range(resN)]

psnrl1 = psnr_fn(target_normalized[:, 0], img_mmse[:, 0]).cpu().numpy()
psnrl2 = psnr_fn(target_normalized[:, 1], img_mmse[:, 1]).cpu().numpy()

output = {
'mmse_img': img_mmse,
'mmse_rec_loss': loss_mmse,
'img': recon_img,
'rec_loss': ll,
'rmse': rmse,
'psnr_l1': psnrl1,
'psnr_l2': psnrl2,
'kl_loss': kl_loss,
'kl_loss_channelwise': kl_loss_channelwise,
}
if model_type == ModelType.LadderVAECritic:
D_loss = model.get_critic_loss_stats(recon_img, target_normalized)['loss'].cpu().item()
cpred_1, cpred_2 = get_critic_prediction(model, recon_img, target_normalized)
critic = {
'label1': cpred_1,
'label2': cpred_2,
'D_loss': D_loss,
}
output['critic'] = critic
return output


def get_label_separated_loss(loss_tensor):
assert loss_tensor.shape[1] == 2
return -1 * loss_tensor[:, 0].mean(dim=(1, 2)).cpu().numpy(), -1 * loss_tensor[:, 1].mean(dim=(1, 2)).cpu().numpy()
35 changes: 35 additions & 0 deletions denoisplit/analysis/denoiser_splitter_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
This is specific to the HDN => uSplit pipeline.
"""
import os

from denoisplit.config_utils import get_configdir_from_saved_predictionfile, load_config


def get_source_channel(pred_fname):
den_config_dir1 = get_configdir_from_saved_predictionfile(pred_fname)
config_temp = load_config(den_config_dir1)
print(pred_fname, config_temp.model.denoise_channel, config_temp.data.ch1_fname, config_temp.data.ch2_fname)
if config_temp.model.denoise_channel == 'Ch1':
ch1 = config_temp.data.ch1_fname
elif config_temp.model.denoise_channel == 'Ch2':
ch1 = config_temp.data.ch2_fname
else:
raise ValueError('Unhandled channel', config_temp.model.denoise_channel)
return ch1


def whether_to_flip(ch1_fname, ch2_fname, reference_config):
"""
When one wants to get the highsnr data, then one does not know if the order of the channels is same as what uSplit predicts.
If not, then one needs to flip the channels.
"""
ch1 = get_source_channel(ch1_fname)
ch2 = get_source_channel(ch2_fname)
channels = [reference_config.data.ch1_fname, reference_config.data.ch2_fname]
assert ch1 in channels, f'{ch1} not in {channels}'
assert ch2 in channels, f'{ch2} not in {channels}'
assert ch1 != ch2, f'{ch1} and {ch2} are same'
if ch1 == reference_config.data.ch2_fname:
return True
return False
69 changes: 69 additions & 0 deletions denoisplit/analysis/double_dip_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os

import matplotlib.pyplot as plt
import numpy as np

from denoisplit.analysis.plot_utils import clean_ax
from denoisplit.core.psnr import RangeInvariantPsnr


def get_psnr(gt, pred):
"""
Order in the prediction is not fixed. So, we compute the psnr of each ground truth with both predictions
and then pick the correct ordering based on the psnr value.
"""
psnr0_0 = RangeInvariantPsnr(gt[0], pred[0])
psnr0_1 = RangeInvariantPsnr(gt[0], pred[1])

psnr1_0 = RangeInvariantPsnr(gt[1], pred[0])
psnr1_1 = RangeInvariantPsnr(gt[1], pred[1])
if psnr0_0 + psnr1_1 > psnr0_1 + psnr1_0:
return psnr0_0, psnr1_1
else:
return psnr0_1, psnr1_0


def step_num(fname: str) -> int:
"""
sum1_499.jpg => 499
"""
return int(fname.split('.')[0].split('_')[-1])


def get_fpath_sequence(prefix, rootdir, extension=None):
"""
Args:
prefix: file name should start with prefix
rootdir:
extension:str
"""
output = []
for fname in os.listdir(rootdir):
if prefix == fname[:len(prefix)]:
if extension is not None:
if fname[-1 * len(extension):] != extension:
continue

output.append(os.path.join(rootdir, fname))

return sorted(output, key=lambda x: step_num(os.path.basename(x)))


def show_imgs_from_np_fpaths(fpath_list, ncols=4, img_sz=5, title_list=None, preprocessing_fn=None):
nrows = int(np.ceil(len(fpath_list) / ncols))
_, ax = plt.subplots(figsize=(img_sz * ncols, nrows * img_sz), ncols=ncols, nrows=nrows)
clean_ax(ax)
if len(ax.shape) == 1:
ax = ax.reshape(1, -1)
for ridx in range(nrows):
for cidx in range(ncols):
fpath_idx = ridx * nrows + cidx
fpath = fpath_list[fpath_idx]
img = np.load(fpath)
if preprocessing_fn is not None:
img = preprocessing_fn(img)

ax[ridx, cidx].imshow(img[0])
if isinstance(title_list, list):
title = title_list[fpath_idx]
ax[ridx, cidx].set_title(title)
114 changes: 114 additions & 0 deletions denoisplit/analysis/grad_viewer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
This module computes the gradients and stores them so that next access is fast.
This can be used to compute gradients of arbitrary order on images.
Last two dimensions of the data are assumed to be x & y dimension.
grads = GradientFetcher(imgs)
To get d/dx2y3,
grad_x2_y3 = grads[2,3]
"""
import numpy as np
from typing import List, Tuple
import seaborn as sns


class GradientFetcher:
def __init__(self, data) -> None:
self._data = data

self._grad_data = {0: {0: self._data}}

@staticmethod
def apply_x_grad(data):
grad = np.empty(data.shape)
grad[:] = np.nan
grad[..., :, 1:] = data[..., :, 1:] - data[..., :, :-1]
return grad

@staticmethod
def apply_y_grad(data):
grad = np.empty(data.shape)
grad[:] = np.nan
grad[..., 1:, :] = data[..., 1:, :] - data[..., :-1, :]
return grad

def __getitem__(self, order):
order_x, order_y = order
if order_x in self._grad_data and order_y in self._grad_data[order_x]:
return self._grad_data[order_x][order_y]

self.compute(order_x, order_y)
return self._grad_data[order_x][order_y]

def compute(self, order_x, order_y):
assert order_y >= 0 and order_x >= 0
if order_x in self._grad_data:
if order_y in self._grad_data[order_x]:
return self._grad_data[order_x][order_y]
if order_y - 1 not in self._grad_data[order_x]:
self.compute(order_x, order_y - 1)

self._grad_data[order_x][order_y] = self.apply_y_grad(self._grad_data[order_x][order_y - 1])
return self._grad_data[order_x][order_y]

self._grad_data[order_x] = {}
self.compute(order_x - 1, order_y)
self._grad_data[order_x][order_y] = self.apply_x_grad(self._grad_data[order_x - 1][order_y])
return self._grad_data[order_x][order_y]


class GradientViewer:
def __init__(self, data) -> None:
self._data = data
self._grad = GradientFetcher(data)

def plot(self,
ax,
gradorder_list: List[Tuple[int, int]],
x_start=0,
x_end=None,
y_start=0,
y_end=None,
subsample=1,
reduce_x=False,
reduce_y=False):
if x_end is None:
x_end = self._data.shape[-1]

if y_end is None:
y_end = self._data.shape[-2]

if isinstance(reduce_x, bool):
reduce_x = [reduce_x] * len(gradorder_list)
if isinstance(reduce_y, bool):
reduce_y = [reduce_y] * len(gradorder_list)

all_plots_data = []
for idx, order in enumerate(gradorder_list):
grad_data = self._grad[order]
grad_data = grad_data[y_start:y_end:subsample, x_start:x_end:subsample]
if reduce_x[idx]:
grad_data = grad_data.mean(axis=1)
sns.lineplot(data=grad_data, ax=ax[idx])
all_plots_data.append(grad_data)
elif reduce_y[idx]:
grad_data = grad_data.mean(axis=0)
sns.lineplot(data=grad_data, ax=ax[idx])
all_plots_data.append(grad_data)
else:
sns.heatmap(grad_data, ax=ax[idx])
all_plots_data.append(grad_data)
return all_plots_data


if __name__ == '__main__':
import matplotlib.pyplot as plt
imgs = np.arange(1024).reshape(1, 1, 32, 32)
plt.imshow(imgs[0, 0])
grads = GradientFetcher(imgs)
gradx = grads[1, 0]
print('next')
grady = grads[0, 1]
print('next')
gradxy = grads[1, 1]
29 changes: 29 additions & 0 deletions denoisplit/analysis/lvae_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import numpy as np
import torch

from denoisplit.core.data_utils import crop_img_tensor


def get_img_from_forward_output(out, model):
recons_img = model.likelihood.get_mean_lv(out)[0]
recons_img = recons_img * model.data_std + model.data_mean
return recons_img


def get_z(img, model):
with torch.no_grad():
img = torch.Tensor(img[None]).cuda()
x_normalized = model.normalize(img)
recons_img_latent, td_data = model(x_normalized)
q_mu = td_data['q_mu']
recons_img = get_img_from_forward_output(recons_img_latent, model)
return recons_img, q_mu


def get_recons_with_latent(img_shape, z, model):
# Top-down inference/generation
out, td_data = model.topdown_pass(None, forced_latent=z, n_img_prior=1)
# Restore original image size
out = crop_img_tensor(out, img_shape)

return get_img_from_forward_output(out, model)
Loading

0 comments on commit c7b50ea

Please sign in to comment.