-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
350 changed files
with
47,778 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import glob | ||
|
||
|
||
def get_best_checkpoint(ckpt_dir): | ||
output = [] | ||
for filename in glob.glob(ckpt_dir + "/*_best.ckpt"): | ||
output.append(filename) | ||
assert len(output) == 1, '\n'.join(output) | ||
return output[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
""" | ||
Functions used in Critic notebooks | ||
""" | ||
import numpy as np | ||
import torch | ||
|
||
from denoisplit.core.model_type import ModelType | ||
from denoisplit.core.psnr import PSNR, RangeInvariantPsnr | ||
|
||
|
||
def _get_critic_prediction(pred: torch.Tensor, tar: torch.Tensor, D) -> dict: | ||
""" | ||
Given a predicted image and a target image, here we return a per sample prediction of | ||
the critic regarding whether they belong to real or predicted images. | ||
Args: | ||
pred: predicted image | ||
tar: target image | ||
D: discriminator model | ||
""" | ||
pred_label = D(pred) | ||
tar_label = D(tar) | ||
pred_label = torch.sigmoid(pred_label) | ||
tar_label = torch.sigmoid(tar_label) | ||
N = len(pred_label) | ||
pred_label = pred_label.view(N, -1) | ||
tar_label = tar_label.view(N, -1) | ||
return { | ||
'generated': { | ||
'mu': pred_label.mean(dim=1), | ||
'std': pred_label.std(dim=1) | ||
}, | ||
'target': { | ||
'mu': tar_label.mean(dim=1), | ||
'std': tar_label.std(dim=1) | ||
} | ||
} | ||
|
||
|
||
def get_critic_prediction(model, pred_normalized, target_normalized): | ||
pred1, pred2 = pred_normalized.chunk(2, dim=1) | ||
tar1, tar2 = target_normalized.chunk(2, dim=1) | ||
cpred_1 = _get_critic_prediction(pred1, tar1, model.D1) | ||
cpred_2 = _get_critic_prediction(pred2, tar2, model.D2) | ||
return cpred_1, cpred_2 | ||
|
||
|
||
def get_mmse_dict(model, x_normalized, target_normalized, mmse_count, model_type, psnr_type='range_invariant', | ||
compute_kl_loss=False): | ||
assert psnr_type in ['simple', 'range_invariant'] | ||
if psnr_type == 'simple': | ||
psnr_fn = PSNR | ||
else: | ||
psnr_fn = RangeInvariantPsnr | ||
|
||
img_mmse = 0 | ||
avg_logvar = None | ||
assert mmse_count >= 1 | ||
for _ in range(mmse_count): | ||
recon_normalized, td_data = model(x_normalized) | ||
ll, dic = model.likelihood(recon_normalized, target_normalized) | ||
recon_img = dic['mean'] | ||
img_mmse += recon_img / mmse_count | ||
if model.predict_logvar: | ||
if avg_logvar is None: | ||
avg_logvar = 0 | ||
avg_logvar += dic['logvar'] / mmse_count | ||
|
||
ll, dic = model.likelihood(recon_normalized, target_normalized) | ||
mse = (img_mmse - target_normalized) ** 2 | ||
# batch and the two channels | ||
N = np.prod(mse.shape[:2]) | ||
rmse = torch.sqrt(torch.mean(mse.view(N, -1), dim=1)) | ||
rmse = rmse.view(mse.shape[:2]) | ||
loss_mmse = model.likelihood.log_likelihood(target_normalized, {'mean': img_mmse, 'logvar': avg_logvar}) | ||
kl_loss = None | ||
kl_loss_channelwise = None | ||
if compute_kl_loss: | ||
kl_loss = model.get_kl_divergence_loss(td_data).cpu().numpy() | ||
resN = len(td_data['kl_channelwise']) | ||
kl_loss_channelwise = [td_data['kl_channelwise'][i].detach().cpu().numpy() for i in range(resN)] | ||
|
||
psnrl1 = psnr_fn(target_normalized[:, 0], img_mmse[:, 0]).cpu().numpy() | ||
psnrl2 = psnr_fn(target_normalized[:, 1], img_mmse[:, 1]).cpu().numpy() | ||
|
||
output = { | ||
'mmse_img': img_mmse, | ||
'mmse_rec_loss': loss_mmse, | ||
'img': recon_img, | ||
'rec_loss': ll, | ||
'rmse': rmse, | ||
'psnr_l1': psnrl1, | ||
'psnr_l2': psnrl2, | ||
'kl_loss': kl_loss, | ||
'kl_loss_channelwise': kl_loss_channelwise, | ||
} | ||
if model_type == ModelType.LadderVAECritic: | ||
D_loss = model.get_critic_loss_stats(recon_img, target_normalized)['loss'].cpu().item() | ||
cpred_1, cpred_2 = get_critic_prediction(model, recon_img, target_normalized) | ||
critic = { | ||
'label1': cpred_1, | ||
'label2': cpred_2, | ||
'D_loss': D_loss, | ||
} | ||
output['critic'] = critic | ||
return output | ||
|
||
|
||
def get_label_separated_loss(loss_tensor): | ||
assert loss_tensor.shape[1] == 2 | ||
return -1 * loss_tensor[:, 0].mean(dim=(1, 2)).cpu().numpy(), -1 * loss_tensor[:, 1].mean(dim=(1, 2)).cpu().numpy() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
""" | ||
This is specific to the HDN => uSplit pipeline. | ||
""" | ||
import os | ||
|
||
from denoisplit.config_utils import get_configdir_from_saved_predictionfile, load_config | ||
|
||
|
||
def get_source_channel(pred_fname): | ||
den_config_dir1 = get_configdir_from_saved_predictionfile(pred_fname) | ||
config_temp = load_config(den_config_dir1) | ||
print(pred_fname, config_temp.model.denoise_channel, config_temp.data.ch1_fname, config_temp.data.ch2_fname) | ||
if config_temp.model.denoise_channel == 'Ch1': | ||
ch1 = config_temp.data.ch1_fname | ||
elif config_temp.model.denoise_channel == 'Ch2': | ||
ch1 = config_temp.data.ch2_fname | ||
else: | ||
raise ValueError('Unhandled channel', config_temp.model.denoise_channel) | ||
return ch1 | ||
|
||
|
||
def whether_to_flip(ch1_fname, ch2_fname, reference_config): | ||
""" | ||
When one wants to get the highsnr data, then one does not know if the order of the channels is same as what uSplit predicts. | ||
If not, then one needs to flip the channels. | ||
""" | ||
ch1 = get_source_channel(ch1_fname) | ||
ch2 = get_source_channel(ch2_fname) | ||
channels = [reference_config.data.ch1_fname, reference_config.data.ch2_fname] | ||
assert ch1 in channels, f'{ch1} not in {channels}' | ||
assert ch2 in channels, f'{ch2} not in {channels}' | ||
assert ch1 != ch2, f'{ch1} and {ch2} are same' | ||
if ch1 == reference_config.data.ch2_fname: | ||
return True | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import os | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from denoisplit.analysis.plot_utils import clean_ax | ||
from denoisplit.core.psnr import RangeInvariantPsnr | ||
|
||
|
||
def get_psnr(gt, pred): | ||
""" | ||
Order in the prediction is not fixed. So, we compute the psnr of each ground truth with both predictions | ||
and then pick the correct ordering based on the psnr value. | ||
""" | ||
psnr0_0 = RangeInvariantPsnr(gt[0], pred[0]) | ||
psnr0_1 = RangeInvariantPsnr(gt[0], pred[1]) | ||
|
||
psnr1_0 = RangeInvariantPsnr(gt[1], pred[0]) | ||
psnr1_1 = RangeInvariantPsnr(gt[1], pred[1]) | ||
if psnr0_0 + psnr1_1 > psnr0_1 + psnr1_0: | ||
return psnr0_0, psnr1_1 | ||
else: | ||
return psnr0_1, psnr1_0 | ||
|
||
|
||
def step_num(fname: str) -> int: | ||
""" | ||
sum1_499.jpg => 499 | ||
""" | ||
return int(fname.split('.')[0].split('_')[-1]) | ||
|
||
|
||
def get_fpath_sequence(prefix, rootdir, extension=None): | ||
""" | ||
Args: | ||
prefix: file name should start with prefix | ||
rootdir: | ||
extension:str | ||
""" | ||
output = [] | ||
for fname in os.listdir(rootdir): | ||
if prefix == fname[:len(prefix)]: | ||
if extension is not None: | ||
if fname[-1 * len(extension):] != extension: | ||
continue | ||
|
||
output.append(os.path.join(rootdir, fname)) | ||
|
||
return sorted(output, key=lambda x: step_num(os.path.basename(x))) | ||
|
||
|
||
def show_imgs_from_np_fpaths(fpath_list, ncols=4, img_sz=5, title_list=None, preprocessing_fn=None): | ||
nrows = int(np.ceil(len(fpath_list) / ncols)) | ||
_, ax = plt.subplots(figsize=(img_sz * ncols, nrows * img_sz), ncols=ncols, nrows=nrows) | ||
clean_ax(ax) | ||
if len(ax.shape) == 1: | ||
ax = ax.reshape(1, -1) | ||
for ridx in range(nrows): | ||
for cidx in range(ncols): | ||
fpath_idx = ridx * nrows + cidx | ||
fpath = fpath_list[fpath_idx] | ||
img = np.load(fpath) | ||
if preprocessing_fn is not None: | ||
img = preprocessing_fn(img) | ||
|
||
ax[ridx, cidx].imshow(img[0]) | ||
if isinstance(title_list, list): | ||
title = title_list[fpath_idx] | ||
ax[ridx, cidx].set_title(title) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
""" | ||
This module computes the gradients and stores them so that next access is fast. | ||
This can be used to compute gradients of arbitrary order on images. | ||
Last two dimensions of the data are assumed to be x & y dimension. | ||
grads = GradientFetcher(imgs) | ||
To get d/dx2y3, | ||
grad_x2_y3 = grads[2,3] | ||
""" | ||
import numpy as np | ||
from typing import List, Tuple | ||
import seaborn as sns | ||
|
||
|
||
class GradientFetcher: | ||
def __init__(self, data) -> None: | ||
self._data = data | ||
|
||
self._grad_data = {0: {0: self._data}} | ||
|
||
@staticmethod | ||
def apply_x_grad(data): | ||
grad = np.empty(data.shape) | ||
grad[:] = np.nan | ||
grad[..., :, 1:] = data[..., :, 1:] - data[..., :, :-1] | ||
return grad | ||
|
||
@staticmethod | ||
def apply_y_grad(data): | ||
grad = np.empty(data.shape) | ||
grad[:] = np.nan | ||
grad[..., 1:, :] = data[..., 1:, :] - data[..., :-1, :] | ||
return grad | ||
|
||
def __getitem__(self, order): | ||
order_x, order_y = order | ||
if order_x in self._grad_data and order_y in self._grad_data[order_x]: | ||
return self._grad_data[order_x][order_y] | ||
|
||
self.compute(order_x, order_y) | ||
return self._grad_data[order_x][order_y] | ||
|
||
def compute(self, order_x, order_y): | ||
assert order_y >= 0 and order_x >= 0 | ||
if order_x in self._grad_data: | ||
if order_y in self._grad_data[order_x]: | ||
return self._grad_data[order_x][order_y] | ||
if order_y - 1 not in self._grad_data[order_x]: | ||
self.compute(order_x, order_y - 1) | ||
|
||
self._grad_data[order_x][order_y] = self.apply_y_grad(self._grad_data[order_x][order_y - 1]) | ||
return self._grad_data[order_x][order_y] | ||
|
||
self._grad_data[order_x] = {} | ||
self.compute(order_x - 1, order_y) | ||
self._grad_data[order_x][order_y] = self.apply_x_grad(self._grad_data[order_x - 1][order_y]) | ||
return self._grad_data[order_x][order_y] | ||
|
||
|
||
class GradientViewer: | ||
def __init__(self, data) -> None: | ||
self._data = data | ||
self._grad = GradientFetcher(data) | ||
|
||
def plot(self, | ||
ax, | ||
gradorder_list: List[Tuple[int, int]], | ||
x_start=0, | ||
x_end=None, | ||
y_start=0, | ||
y_end=None, | ||
subsample=1, | ||
reduce_x=False, | ||
reduce_y=False): | ||
if x_end is None: | ||
x_end = self._data.shape[-1] | ||
|
||
if y_end is None: | ||
y_end = self._data.shape[-2] | ||
|
||
if isinstance(reduce_x, bool): | ||
reduce_x = [reduce_x] * len(gradorder_list) | ||
if isinstance(reduce_y, bool): | ||
reduce_y = [reduce_y] * len(gradorder_list) | ||
|
||
all_plots_data = [] | ||
for idx, order in enumerate(gradorder_list): | ||
grad_data = self._grad[order] | ||
grad_data = grad_data[y_start:y_end:subsample, x_start:x_end:subsample] | ||
if reduce_x[idx]: | ||
grad_data = grad_data.mean(axis=1) | ||
sns.lineplot(data=grad_data, ax=ax[idx]) | ||
all_plots_data.append(grad_data) | ||
elif reduce_y[idx]: | ||
grad_data = grad_data.mean(axis=0) | ||
sns.lineplot(data=grad_data, ax=ax[idx]) | ||
all_plots_data.append(grad_data) | ||
else: | ||
sns.heatmap(grad_data, ax=ax[idx]) | ||
all_plots_data.append(grad_data) | ||
return all_plots_data | ||
|
||
|
||
if __name__ == '__main__': | ||
import matplotlib.pyplot as plt | ||
imgs = np.arange(1024).reshape(1, 1, 32, 32) | ||
plt.imshow(imgs[0, 0]) | ||
grads = GradientFetcher(imgs) | ||
gradx = grads[1, 0] | ||
print('next') | ||
grady = grads[0, 1] | ||
print('next') | ||
gradxy = grads[1, 1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import numpy as np | ||
import torch | ||
|
||
from denoisplit.core.data_utils import crop_img_tensor | ||
|
||
|
||
def get_img_from_forward_output(out, model): | ||
recons_img = model.likelihood.get_mean_lv(out)[0] | ||
recons_img = recons_img * model.data_std + model.data_mean | ||
return recons_img | ||
|
||
|
||
def get_z(img, model): | ||
with torch.no_grad(): | ||
img = torch.Tensor(img[None]).cuda() | ||
x_normalized = model.normalize(img) | ||
recons_img_latent, td_data = model(x_normalized) | ||
q_mu = td_data['q_mu'] | ||
recons_img = get_img_from_forward_output(recons_img_latent, model) | ||
return recons_img, q_mu | ||
|
||
|
||
def get_recons_with_latent(img_shape, z, model): | ||
# Top-down inference/generation | ||
out, td_data = model.topdown_pass(None, forced_latent=z, n_img_prior=1) | ||
# Restore original image size | ||
out = crop_img_tensor(out, img_shape) | ||
|
||
return get_img_from_forward_output(out, model) |
Oops, something went wrong.