
Commit

Add files via upload
LorenzoAgnolucci authored Feb 21, 2023
1 parent 866b60c commit f4c8599
Showing 10 changed files with 830 additions and 83 deletions.
76 changes: 76 additions & 0 deletions src/losses.py
@@ -0,0 +1,76 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC

from vgg_feature_extractor import VGGFeatureExtractor


def l1_loss(pred, target):
return F.l1_loss(pred, target, reduction="none")


def charbonnier_loss(pred, target, eps=1e-12):
return torch.sqrt((pred - target) ** 2 + eps).mean()


class CharbonnierLoss(nn.Module):
"""Charbonnier loss (one variant of Robust L1Loss, a differentiable
variant of L1Loss).
Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
Super-Resolution".
Args:
eps (float): A value used to control the curvature near zero. Default: 1e-12.
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
"""

def __init__(self, eps=1e-12, loss_weight=1.0):
super(CharbonnierLoss, self).__init__()
self.loss_weight = loss_weight
self.eps = eps

def forward(self, pred, target):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
return self.loss_weight * charbonnier_loss(pred, target, eps=self.eps)


class PerceptualLoss(nn.Module):
"""VGG Perceptual loss
"""

def __init__(self, layer_weights, use_input_norm=True, use_range_norm=False, criterion='l2', loss_weight=1.0):
super(PerceptualLoss, self).__init__()
self.layer_weights = layer_weights
self.vgg = VGGFeatureExtractor(layer_name_list=list(layer_weights.keys()),
use_input_norm=use_input_norm,
use_range_norm=use_range_norm)
self.criterion_type = criterion
if self.criterion_type == 'l1':
self.criterion = torch.nn.L1Loss()
elif self.criterion_type == 'l2':
self.criterion = torch.nn.MSELoss()
else:
raise NotImplementedError(f'{criterion} criterion is not supported.')
self.loss_weight = loss_weight

def forward(self, pred, target):
"""Forward function.
Args:
pred (Tensor): Input tensor with shape (n, c, h, w).
target (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
pred_feat = self.vgg(pred)
target_feat = self.vgg(target.detach())

loss = 0.0
for i in pred_feat.keys():
loss += self.criterion(pred_feat[i], target_feat[i]) * self.layer_weights[i]
loss *= self.loss_weight
return loss
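
A minimal usage sketch of the two losses above (not part of this commit): it assumes VGGFeatureExtractor resolves the listed layer names, and it simply mirrors how recurrent_cnn_pl_module.py combines the pixel and perceptual terms.

import torch
from losses import CharbonnierLoss, PerceptualLoss

# Charbonnier term: mean(sqrt((pred - target)^2 + eps)), scaled by loss_weight.
pixel_criterion = CharbonnierLoss(eps=1e-12, loss_weight=200)
# Perceptual term: weighted distance between VGG features of pred and target.
perceptual_criterion = PerceptualLoss(layer_weights={'conv5_4': 1, 'relu4_4': 1, 'relu3_4': 1, 'relu2_2': 1},
                                      criterion='l2', loss_weight=1)

pred = torch.rand(4, 3, 128, 128)     # hypothetical restored frames in [0, 1]
target = torch.rand(4, 3, 128, 128)   # hypothetical ground-truth frames in [0, 1]
total_loss = pixel_criterion(pred, target) + perceptual_criterion(pred, target)
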
13 changes: 7 additions & 6 deletions src/real_world_test.py
@@ -17,7 +17,7 @@

def main(args):

checkpoints_path = Path("../pretrained_models") / args.experiment_name
checkpoints_path = Path("pretrained_models") / args.experiment_name
checkpoint_file = os.listdir(checkpoints_path)[-1]
checkpoint = checkpoints_path / checkpoint_file

@@ -35,7 +35,7 @@ def main(args):
patch_size=args.patch_size,
crop_mode="center")

model = VideoSwinEncoderDecoder(use_checkpoint=args.use_checkpoint, depths=[2, 2, 6, 2], embed_dim=96)
model = VideoSwinEncoderDecoder(use_checkpoint=True, depths=[2, 2, 6, 2], embed_dim=96)
state_dict = torch.load(checkpoint)["state_dict"]
state_dict = dict([(k[len("net_g."):], v) for k, v in state_dict.items() if k.startswith("net_g.")])
model.load_state_dict(state_dict)
@@ -69,7 +69,8 @@ def main(args):

for i in range(output.shape[0]):

video_clip = osp.dirname(img_name[i])
video_clip = img_name[i].split("/")[0]
print("Video clip: ", video_clip)
(results_path / Path(video_clip)).mkdir(parents=True, exist_ok=True)

restored = (output[i] * 255).astype(np.uint8)
@@ -86,11 +87,11 @@ def main(args):
combined_video_writer.release()
last_clip = video_clip
restored_video_writer = cv2.VideoWriter(f"{restored_video_path}/{video_clip}.mp4",
cv2.VideoWriter_fourcc(*'mp4v'), 25,
cv2.VideoWriter_fourcc(*'mp4v'), args.fps,
restored.shape[0:2])
combined_shape = (restored.shape[0] * 2, restored.shape[1])
combined_video_writer = cv2.VideoWriter(f"{combined_video_path}/{video_clip}.mp4",
cv2.VideoWriter_fourcc(*'mp4v'), 25,
cv2.VideoWriter_fourcc(*'mp4v'), args.fps,
combined_shape)

restored_video_writer.write(restored)
Expand All @@ -108,7 +109,7 @@ def main(args):
parser.add_argument("--data-base-path", type=str)
parser.add_argument("--results-path", type=str, default="results")
parser.add_argument("--patch-size", type=int, default=768)
parser.add_argument("--no-checkpoint", default=True, action="store_false")
parser.add_argument("--fps", type=int, default=60)
args = parser.parse_args()

main(args)
102 changes: 102 additions & 0 deletions src/recurrent_cnn_pl_module.py
@@ -0,0 +1,102 @@
import torch
import pytorch_lightning as pl
import torchmetrics.image
import torchmetrics
from torchvision.transforms.functional import to_pil_image
from torchmetrics.functional.image.ssim import structural_similarity_index_measure
import os.path as osp
from einops import rearrange

from losses import CharbonnierLoss, PerceptualLoss


class RecurrentCNNModule(pl.LightningModule):

def __init__(self, opt, generator=None, window_size=5, pixel_loss_weight=200, perceptual_loss_weight=1):
super(RecurrentCNNModule, self).__init__()
self.save_hyperparameters(ignore=["generator"])
self.opt = opt
self.window_size = window_size

self.net_g = generator

self.lr = 1.9e-5
weight_pixel_criterion = pixel_loss_weight
self.pixel_criterion = CharbonnierLoss(loss_weight=weight_pixel_criterion)

vgg_layer_weights = {'conv5_4': 1, 'relu4_4': 1, 'relu3_4': 1, 'relu2_2': 1}
weight_perceptual_criterion = perceptual_loss_weight
self.perceptual_criterion = PerceptualLoss(layer_weights=vgg_layer_weights, loss_weight=weight_perceptual_criterion)

self.psnr = torchmetrics.PeakSignalNoiseRatio(data_range=1.0)
self.ssim = structural_similarity_index_measure
self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type="alex")

def forward(self, *x):
return self.net_g(*x)

def training_step(self, batch, batch_idx):
imgs_lq = batch["imgs_lq"]
imgs_gt = batch["imgs_gt"]
outputs_g = self.net_g(imgs_lq)

B, T, C, H, W = outputs_g.shape
outputs_g = rearrange(outputs_g, 'b t c h w -> (b t) c h w', b=B, t=T)
imgs_gt = rearrange(imgs_gt, 'b t c h w -> (b t) c h w', b=B, t=T)

pixel_loss_g = self.pixel_criterion(outputs_g, imgs_gt)
perceptual_loss_g = self.perceptual_criterion(outputs_g, imgs_gt)

total_loss_g = pixel_loss_g + perceptual_loss_g

log_loss_g = {"total_g": total_loss_g,
"pixel_g": pixel_loss_g,
"perceptual_g": perceptual_loss_g}
self.log_dict(log_loss_g, on_epoch=True, on_step=True, prog_bar=True, logger=True, sync_dist=True, batch_size=imgs_gt.shape[0])
return total_loss_g

def validation_step(self, batch, batch_idx):
imgs_lq = batch["imgs_lq"]
imgs_gt = batch["imgs_gt"]
outputs_g = self.net_g(imgs_lq).to(torch.float32)

psnr, ssim, lpips = 0., 0., 0.
for i, output_g in enumerate(outputs_g):
output_g = torch.clamp(output_g, 0, 1)
img_gt = imgs_gt[i]
psnr += self.psnr(output_g, img_gt)
ssim += self.ssim(output_g, img_gt, data_range=1.)
with torch.no_grad():
lpips += self.lpips(output_g * 2 - 1, img_gt * 2 - 1) # Input must be in [-1, 1] range
psnr /= len(outputs_g)
ssim /= len(outputs_g)
lpips /= len(outputs_g)

log_metrics = {"psnr": psnr,
"ssim": ssim,
"lpips": lpips}

self.log_dict(log_metrics, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=imgs_gt.shape[0])

imgs_name = batch["img_name"]
if batch_idx == 0:
imgs_name = [imgs_name[0]]
for i, img_name in enumerate(imgs_name):
img_num = int(osp.basename(img_name)[:-4])
if img_num % 100 == 0 or batch_idx == 0:
single_img_lq = imgs_lq[0, self.window_size // 2]
single_img_gt = imgs_gt[0, self.window_size // 2]
single_img_output = torch.clamp(outputs_g[0, self.window_size // 2], 0., 1.)
concatenated_img = torch.cat((single_img_lq, single_img_output, single_img_gt), -1)
self.logger.experiment.log_image(to_pil_image(concatenated_img.cpu()), str(img_num), step=self.current_epoch)

def test_step(self, batch, batch_idx):
return self.validation_step(batch, batch_idx)

def on_validation_epoch_end(self) -> None:
self.psnr.reset()
self.lpips.reset()

def configure_optimizers(self):
optimizer_g = torch.optim.AdamW(self.net_g.parameters(), lr=self.lr, weight_decay=0.01, betas=(0.9, 0.99))
return optimizer_g
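
A self-contained sketch (hypothetical shapes, not part of this commit) of the rearrange step used in training_step above: the time dimension is folded into the batch so that the per-image Charbonnier and perceptual losses can be applied to 5-D video tensors.

import torch
from einops import rearrange

outputs_g = torch.rand(2, 5, 3, 128, 128)                # (B, T, C, H, W) video batch
flat = rearrange(outputs_g, 'b t c h w -> (b t) c h w')  # (B*T, C, H, W) image batch
print(flat.shape)  # torch.Size([10, 3, 128, 128])
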
122 changes: 122 additions & 0 deletions src/train.py
@@ -0,0 +1,122 @@
import comet_ml
import os
import pytorch_lightning as pl
import torch
from pathlib import Path
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from argparse import ArgumentParser

from video_swin_unet import VideoSwinEncoderDecoder
from video_recurrent_dataset import VideoRecurrentDataset
from video_data_pl_module import VideoDataModule
from recurrent_cnn_pl_module import RecurrentCNNModule
from utils import init_logger


def main(args):
training_params = args.training_params
data_params = args.data_params
model_params = args.model_params

pl.seed_everything(42, workers=True)
os.environ['PYTHONHASHSEED'] = str(42)

experiment_name = args.experiment_name
experiment_key = args.experiment_key
online_logger = not args.offline

data_base_path = Path(args.data_base_path)

training_input_path = data_base_path / "train" / "input"
training_gt_path = data_base_path / "train" / "gt"
val_input_path = data_base_path / "val" / "input"
val_gt_path = data_base_path / "val" / "gt"

checkpoints_path = Path(args.checkpoints_path) / experiment_name
logger = init_logger(api_key=args.api_key, experiment_name=experiment_name, experiment_key=experiment_key, online=online_logger)
args.training_params["logger"] = logger
logger.experiment.log_parameters(training_params, prefix="training")
logger.experiment.log_parameters(data_params, prefix="data")

train_dataset = VideoRecurrentDataset(training_input_path, training_gt_path,
window_size=data_params["window_size"],
frame_offset=data_params["frame_offset"],
gt_patch_size=data_params["gt_patch_size"],
crop_mode="random")
val_dataset = VideoRecurrentDataset(val_input_path, val_gt_path,
window_size=data_params["window_size"],
frame_offset=data_params["frame_offset"],
gt_patch_size=data_params["gt_patch_size"],
crop_mode="center")
data_module = VideoDataModule(train_dataset, val_dataset, batch_size=data_params["batch_size"],
num_workers=data_params["num_workers"])

checkpoint_callback = ModelCheckpoint(dirpath=checkpoints_path,
filename="{epoch}-{step}-{lpips:.3f}",
save_weights_only=False,
monitor="lpips",
save_top_k=1,
save_last=True)

generator = VideoSwinEncoderDecoder(use_checkpoint=True, depths=[2, 2, 6, 2], embed_dim=96)
model = RecurrentCNNModule(opt=None, generator=generator,
window_size=data_params["window_size"],
pixel_loss_weight=model_params["pixel_loss_weight"],
perceptual_loss_weight=model_params["perceptual_loss_weight"])

trainer = Trainer(**training_params, callbacks=[checkpoint_callback])

trainer.fit(model, data_module.train_dataloader(), data_module.val_dataloader())


if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--experiment-name", type=str, default="video_swin_unet")
parser.add_argument("--data-base-path", type=str)
parser.add_argument("--checkpoints-path", type=str, default="pretrained_models")
parser.add_argument("--devices", type=int, nargs="+", default=[0])
parser.add_argument("--resume-from-checkpoint", default=False, action="store_true")
parser.add_argument("--resume-checkpoint-filename", type=str, default="")
parser.add_argument("--api-key", type=str, default="")
parser.add_argument("--offline", default=False, action="store_true")
parser.add_argument("--experiment-key", type=str, default=None)
parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--num-epochs", type=int, default=100)
parser.add_argument("--num-workers", type=int, default=20)
parser.add_argument("--pixel-loss-weight", type=float, default=200)
parser.add_argument("--perceptual-loss-weight", type=float, default=1)
parser.add_argument("--no-ddp-strategy", default=False, action="store_true")
args = parser.parse_args()

training_params = {
"benchmark": True,
"precision": 16,
"log_every_n_steps": 50,
"accelerator": "gpu",
"devices": args.devices,
"strategy": None if args.no_ddp_strategy else DDPStrategy(find_unused_parameters=True, static_graph=True),
"max_epochs": args.num_epochs,
"resume_from_checkpoint": None if not args.resume_from_checkpoint else Path(args.checkpoints_path) / args.experiment_name / args.resume_checkpoint_filename
}

data_params = {
"window_size": 5,
"frame_offset": 1,
"batch_size": args.batch_size,
"num_workers": args.num_workers,
"gt_patch_size": 128
}

model_params = {
"pixel_loss_weight": args.pixel_loss_weight,
"perceptual_loss_weight": args.perceptual_loss_weight,
}

args.training_params = training_params
args.data_params = data_params
args.model_params = model_params

main(args)
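
For reference, a minimal sketch (not part of this commit; the checkpoint path is hypothetical, built from the default --checkpoints-path and --experiment-name values) of loading a checkpoint produced by this script back into the bare generator, mirroring the "net_g." prefix stripping done in real_world_test.py.

import torch
from video_swin_unet import VideoSwinEncoderDecoder

generator = VideoSwinEncoderDecoder(use_checkpoint=True, depths=[2, 2, 6, 2], embed_dim=96)
# ModelCheckpoint(save_last=True) also writes a last.ckpt alongside the best-LPIPS checkpoint.
ckpt = torch.load("pretrained_models/video_swin_unet/last.ckpt", map_location="cpu")
# The Lightning module stores the generator as self.net_g, so its keys carry a "net_g." prefix.
state_dict = {k[len("net_g."):]: v for k, v in ckpt["state_dict"].items() if k.startswith("net_g.")}
generator.load_state_dict(state_dict)
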
