Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first merge with juglab HDN #5

Open
wants to merge 1 commit into
base: Vary_nFilters
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
checkpoints*/
data/
wandb/
__pycache__
hpc*
*.pt
root*/
results*_
*.png
*.tif
*.tiff
*.sbatch
*.npy
Trained_model/
*.out
Trained_Models/*
*.sbatch
sbatches_logs/*
180 changes: 150 additions & 30 deletions boilerplate/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchvision.utils import save_image
from torch.nn import init
from torch.optim.optimizer import Optimizer
from torch.cuda.amp import autocast
import os
import glob
import random
Expand All @@ -19,6 +20,7 @@
from models.lvae import LadderVAE
import lib.utils as utils


def _make_datamanager(train_images, val_images, test_images, batch_size, test_batch_size):

"""Create data loaders for training, validation and test sets during training.
Expand Down Expand Up @@ -51,6 +53,8 @@ def _make_datamanager(train_images, val_images, test_images, batch_size, test_ba
train_images = torch.from_numpy(train_images)
train_labels = torch.zeros(len(train_images),).fill_(float('nan'))
train_set = TensorDataset(train_images, train_labels)

# TODO add normal dataloader

val_images = (val_images-data_mean)/data_std
val_images = torch.from_numpy(val_images)
Expand All @@ -69,6 +73,7 @@ def _make_datamanager(train_images, val_images, test_images, batch_size, test_ba

return train_loader, val_loader, test_loader, data_mean, data_std


def _make_optimizer_and_scheduler(model, lr, weight_decay) -> Optimizer:
"""
Implements Adamax optimizer and learning rate scheduler.
Expand All @@ -88,23 +93,29 @@ def _make_optimizer_and_scheduler(model, lr, weight_decay) -> Optimizer:
verbose=True)
return optimizer, scheduler

def forward_pass(x, y, device, model, gaussian_noise_std)-> dict:

def forward_pass(x, y, device, model, gaussian_noise_std, amp=True, stochasticity=True)-> dict:
x = x.to(device, non_blocking=True)
model_out = model(x)
y = y.to(device, non_blocking=True)
with autocast(enabled=amp):
model_out = model(x,y)

if model.mode_pred is False:

recons_sep = -model_out['ll']

recons_sep = -model_out['ll'] # negative log likelihood
kl_sep = model_out['kl_sep']
kl = model_out['kl']
kl_loss = model_out['kl_loss']/float(x.shape[2]*x.shape[3])

if gaussian_noise_std is None:
recons_loss = recons_sep.mean()
if stochasticity == True:
kl_loss = model_out['kl_loss']/float(x.shape[2]*x.shape[3])
else:
recons_loss = recons_sep.mean()/ ((gaussian_noise_std/model.data_std)**2)

kl_loss = None
recons_loss = recons_sep.mean()

output = {
'top_bu': model_out['top_bu'],
'z': model_out['z'],
'out_img': model_out['out_sample'],
'recons_loss': recons_loss,
'kl_loss': kl_loss,
'out_mean': model_out['out_mean'],
Expand All @@ -113,6 +124,9 @@ def forward_pass(x, y, device, model, gaussian_noise_std)-> dict:

else:
output = {
'top_bu': model_out['top_bu'],
'z': model_out['z'],
'out_img': model_out['out_sample'],
'recons_loss': None,
'kl_loss': None,
'out_mean': model_out['out_mean'],
Expand All @@ -124,6 +138,7 @@ def forward_pass(x, y, device, model, gaussian_noise_std)-> dict:

return output


def img_grid_pad_value(imgs, thresh = .2) -> float:
"""Returns padding value (black or white) for a grid of images.
Hack to visualize boundaries between images with torchvision's
Expand All @@ -135,7 +150,6 @@ def img_grid_pad_value(imgs, thresh = .2) -> float:
Returns:
pad_value (float): The padding value
"""

assert imgs.dim() == 4
imgs = imgs.clamp(min=0., max=1.)
assert 0. < thresh < 1.
Expand All @@ -153,6 +167,7 @@ def img_grid_pad_value(imgs, thresh = .2) -> float:
return 1.0
return 0.0


def save_image_grid(images,filename,nrows):
"""Saves images on disk.
Args:
Expand All @@ -163,28 +178,36 @@ def save_image_grid(images,filename,nrows):
pad = img_grid_pad_value(images)
save_image(images, filename, nrow=nrows, pad_value=pad, normalize=True)

def generate_and_save_samples(model, filename, nrows = 4) -> None:

def generate_and_save_samples(model, filename, nrows=4, amp=True) -> None:
"""Save generated images at intermediate training steps.
Args:
model: instance of LadderVAE class
filename (str): filename where to save denoised images
nrows (int): Number of rows in which to arrange denoised/generated images.

"""
samples = model.sample_prior(nrows**2)
with autocast(enabled=amp):
samples = model.sample_prior(nrows**2)
if samples.dim() == 5:
samples = samples[0, ...].permute(1, 0, 2, 3)
save_image_grid(samples, filename, nrows=nrows)
return samples


def save_image_grid_reconstructions(inputs,recons,filename):
assert inputs.shape == recons.shape
n_img = inputs.shape[0]
n = int(np.sqrt(2 * n_img))
imgs = torch.stack([inputs.cpu(), recons.cpu()])
imgs = imgs.permute(1, 0, 2, 3, 4)
imgs = imgs.reshape(n**2, inputs.size(1), inputs.size(2), inputs.size(3))
imgs = imgs.permute(1, 0, *list(range(2, imgs.dim())))
imgs = imgs.reshape(n**2, *inputs.shape[1:])
if imgs.dim() == 5:
imgs = imgs[0, ...].permute(1, 0, 2, 3)
save_image_grid(imgs, filename, nrows=n)

def generate_and_save_reconstructions(x,filename,device,model,gaussian_noise_std,data_std,nrows) -> None:

def generate_and_save_reconstructions(x,filename,device,model,gaussian_noise_std,data_std,nrows, amp=True) -> None:
"""Save denoised images at intermediate training steps.
Args:
x (Torch.tensor): Batch of images from test set
Expand All @@ -197,13 +220,14 @@ def generate_and_save_reconstructions(x,filename,device,model,gaussian_noise_std

"""
n_img = nrows**2 // 2
if x.shape[0] < n_img:
#print(x.shape)
if x.shape[2] < n_img:
msg = ("{} data points required, but given batch has size {}. "
"Please use a larger batch.".format(n_img, x.shape[0]))
raise RuntimeError(msg)
x = x.to(device)
outputs = forward_pass(x, x, device, model, gaussian_noise_std)
# x = x.to(device)

outputs = forward_pass(x, x, device, model, gaussian_noise_std, amp=amp)

# Try to get reconstruction from different sources in order
recons = None
Expand All @@ -221,13 +245,15 @@ def generate_and_save_reconstructions(x,filename,device,model,gaussian_noise_std
raise RuntimeError(msg)

# Pick required number of images

x = x[:n_img]
recons = recons[:n_img]

# Save inputs and reconstructions in a grid
save_image_grid_reconstructions(x, recons, filename)

def save_images(img_folder, device, model, test_loader, gaussian_noise_std, data_std, nrows) -> None:

def save_images(img_folder, device, model, test_loader, gaussian_noise_std, data_std, nrows, amp=True) -> None:
"""Save generated images and denoised images at intermediate training steps.
Args:
img_folder (str): Folder where to save images
Expand All @@ -246,15 +272,18 @@ def save_images(img_folder, device, model, test_loader, gaussian_noise_std, dat
generate_and_save_samples(model, fname, nrows)

# Get first test batch
(x, _) = next(iter(test_loader))
x = next(iter(test_loader))
x = x.unsqueeze(1)

x = x.to(device=device, dtype=torch.float)

# Save model original/reconstructions
fname = os.path.join(img_folder, 'reconstruction_' + str(step) + '.png')

generate_and_save_reconstructions(x, fname, device, model, gaussian_noise_std, data_std, nrows)
generate_and_save_reconstructions(x, fname, device, model, gaussian_noise_std, data_std, nrows, amp)

def _test(epoch, img_folder, device, model, test_loader, gaussian_noise_std, data_std, nrows):

def _test(epoch, img_folder, device, model, test_loader, gaussian_noise_std, data_std, nrows, amp=True):
"""Perform a test step at intermediate training steps.
Args:
epoch (int): Current training epoch
Expand All @@ -270,7 +299,7 @@ def _test(epoch, img_folder, device, model, test_loader, gaussian_noise_std, da
# Evaluation mode
model.eval()
# Save images
save_images(img_folder, device, model, test_loader, gaussian_noise_std, data_std, nrows)
save_images(img_folder, device, model, test_loader, gaussian_noise_std, data_std, nrows, amp)


def get_normalized_tensor(img,model,device):
Expand All @@ -290,6 +319,75 @@ def get_normalized_tensor(img,model,device):
return test_images


def predcit_tiled(img, model, device, patch_size, overlap=0, num_samples=10, tta=False,
                  gaussian_noise_std=None, predict_fn=None) -> np.ndarray:
    """
    Predicts an image tile by tile and stitches the tiles back together.

    Square tiles of side `patch_size` are taken along the first two axes of
    `img` with a stride of `patch_size - overlap`. A tile that would run past
    the image border is shifted back so that it stays inside the image, and
    only the part that is still missing is copied from it. At every seam
    between neighbouring tiles, the first `overlap // 2` rows/columns of the
    newer tile are discarded in favour of the previously written tile.

    Parameters
    ----------
    img: array of shape (H,W)
        Image to predict. Tiling is done over axes 0 and 1.
    model: Hierarchical DivNoising model
    device: GPU device.
    patch_size: int
        Side length of the square tiles.
    overlap: int
        Overlap between neighbouring tiles; must satisfy 0 <= overlap < patch_size.
    num_samples: int
        Number of samples to use for prediction.
    tta: bool
        Whether to use test time augmentation.
    gaussian_noise_std: float
        std of Gaussian noise used to corrupt the data. For intrinsically noisy data, set to None.
    predict_fn: callable, optional
        Called as `predict_fn(tile, num_samples, model, gaussian_noise_std, device, tta)`
        and expected to return a pair whose first element is the prediction
        for `tile` (same shape as `tile`). Defaults to `boilerplate.predict`.

    Returns
    -------
    pred: np.ndarray
        Array with the same shape as `img` holding the stitched prediction.

    Raises
    ------
    ValueError
        If `overlap` is negative or not smaller than `patch_size`.
    """
    if overlap < 0:
        raise ValueError(f'Overlap {overlap} must not be negative.')
    if patch_size <= overlap:
        # A non-positive stride would never advance through the image.
        raise ValueError(f'Patch size {patch_size} must be larger than overlap {overlap}.')

    if predict_fn is None:
        # NOTE(review): `boilerplate` has to be resolvable in this module's
        # namespace; if `predict` lives in this same module, call it
        # directly instead -- confirm.
        predict_fn = boilerplate.predict

    height, width = img.shape[0], img.shape[1]
    stride = patch_size - overlap
    half_overlap = overlap // 2
    pred = np.zeros(img.shape)

    for xmin in range(0, width, stride):
        xmax = min(width, xmin + patch_size)
        # Shift the tile back inside the image (never before index 0).
        tile_xmin = max(0, xmax - patch_size)
        shift_x = xmin - tile_xmin
        ov_left = half_overlap if xmin > 0 else 0

        for ymin in range(0, height, stride):
            ymax = min(height, ymin + patch_size)
            tile_ymin = max(0, ymax - patch_size)
            shift_y = ymin - tile_ymin
            ov_top = half_overlap if ymin > 0 else 0

            img_mmse, _ = predict_fn(img[tile_ymin:ymax, tile_xmin:xmax],
                                     num_samples,
                                     model,
                                     gaussian_noise_std,
                                     device,
                                     tta)

            # Skip the rows/columns that belong to the previous tile (shift)
            # and the leading half of the overlap region.
            pred[ymin + ov_top:ymax, xmin + ov_left:xmax] = \
                img_mmse[shift_y + ov_top:, shift_x + ov_left:]

            if ymax == height:
                # This tile reached the bottom border; further tiles would
                # only repeat the same region.
                break

        if xmax == width:
            break

    return pred


def predict_sample(img, model, gaussian_noise_std, device):
"""
Predicts a sample.
Expand All @@ -299,7 +397,7 @@ def predict_sample(img, model, gaussian_noise_std, device):
Image for which denoised MMSE estimate needs to be computed.
model: Ladder VAE object
Hierarchical DivNoising model.
gaussian_noise_std: float
gaussian_noise_std: float
std of Gaussian noise used to corrupt the data. For intrinsically noisy data, set to None.
device: GPU device
"""
Expand All @@ -326,13 +424,12 @@ def predict_mmse(img_n, num_samples, model, gaussian_noise_std, device, return_s
std of Gaussian noise used to corrupt the data. For intrinsically noisy data, set to None.
device: GPU device
"""
img_height,img_width=img_n.shape[0],img_n.shape[1]
img_t = get_normalized_tensor(img_n,model,device)
image_sample = img_t.view(1,1,img_height,img_width)
image_sample = img_t.view(1, 1, *img_n.shape)
image_sample = image_sample.to(device=device, dtype=torch.float)
samples = []

for j in tqdm(range(num_samples)):
for j in range(num_samples):
sample = predict_sample(image_sample, model, gaussian_noise_std, device=device)
samples.append(np.squeeze(sample))

Expand Down Expand Up @@ -424,7 +521,7 @@ def generate_arbitrary_sized_samples(sample_shape, num_samples, model, save_path
torch.set_grad_enabled(False)
n = num_samples
model.eval()
orig_model_img_shape = model.img_shape
orig_model_img_shape = (128,128) #model.img_shape
fname = os.path.join(save_path, "samples_"+str(sample_shape[0])+"x"+str(sample_shape[1])+".png")

if sample_shape[0]<orig_model_img_shape[0] or sample_shape[1]<orig_model_img_shape[1]:
Expand All @@ -448,4 +545,27 @@ def generate_arbitrary_sized_samples(sample_shape, num_samples, model, save_path
samples_numpy = samples_denormalized.detach().cpu().numpy()
samples_numpy = samples_numpy[:,0,:,:]

return samples_numpy


def generate_samples(sample_shape, num_samples, model, top_bu_value, save_path):
    """Sample from the model with a caller-supplied top-level prior and save the grid.

    The top layer's `top_prior_params` is temporarily replaced by
    `top_bu_value` and `model.img_shape` by `sample_shape`; both are put back
    once sampling has finished (also when sampling raises), so the model can
    be reused afterwards.

    Args:
        sample_shape: Sequence whose first two entries are ints; assigned to
            `model.img_shape` and used in the output file name.
        num_samples (int): Passed as `nrows` to `generate_and_save_samples`,
            i.e. the number of rows of the saved image grid.
        model: Model exposing `top_down_layers`, `data_mean` and `data_std`;
            its last top-down layer must have `is_top_layer` set.
        top_bu_value (torch.Tensor): Tensor wrapped in a `torch.nn.Parameter`
            and used as the top layer's prior parameters while sampling.
        save_path (str): Folder in which `samples_<H>x<W>.png` is written.

    Returns:
        numpy array with the de-normalised samples, restricted to index 0 of
        axis 1.

    Side effects:
        Disables autograd globally via `torch.set_grad_enabled(False)` and
        switches `model` to eval mode; neither is reverted here.
    """
    assert isinstance(sample_shape[0], int) and isinstance(sample_shape[1], int)
    torch.set_grad_enabled(False)
    model.eval()
    fname = os.path.join(save_path, "samples_"+str(sample_shape[0])+"x"+str(sample_shape[1])+".png")

    top_layer = model.top_down_layers[-1]
    assert top_layer.is_top_layer

    # Remember the state that is about to be overwritten. The model may not
    # define `img_shape` at all, hence the sentinel.
    not_set = object()
    orig_img_shape = getattr(model, 'img_shape', not_set)
    orig_top_prior_params = top_layer.top_prior_params

    model.img_shape = sample_shape
    top_layer.top_prior_params = torch.nn.Parameter(top_bu_value)
    try:
        samples = generate_and_save_samples(model, fname, nrows=num_samples)
    finally:
        # Undo the temporary overrides so later use of the model is unaffected.
        top_layer.top_prior_params = orig_top_prior_params
        if orig_img_shape is not not_set:
            model.img_shape = orig_img_shape

    samples_denormalized = (samples*model.data_std)+model.data_mean
    samples_numpy = samples_denormalized.detach().cpu().numpy()
    samples_numpy = samples_numpy[:,0,:,:]

    return samples_numpy
Loading