Commit f4c8599 (parent 866b60c): 10 changed files with 830 additions and 83 deletions.
losses.py (new file, 76 lines; module name taken from the `from losses import ...` statement in the next file)
import torch
import torch.nn.functional as F
import torch.nn as nn

from vgg_feature_extractor import VGGFeatureExtractor


def l1_loss(pred, target):
    return F.l1_loss(pred, target, reduction="none")


def charbonnier_loss(pred, target, eps=1e-12):
    return torch.sqrt((pred - target) ** 2 + eps).mean()


class CharbonnierLoss(nn.Module):
    """Charbonnier loss, a differentiable, robust variant of L1 loss.

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        eps (float): Controls the curvature near zero. Default: 1e-12.
        loss_weight (float): Weight applied to the loss. Default: 1.0.
    """

    def __init__(self, eps=1e-12, loss_weight=1.0):
        super(CharbonnierLoss, self).__init__()
        self.loss_weight = loss_weight
        self.eps = eps

    def forward(self, pred, target):
        """
        Args:
            pred (Tensor): Predicted tensor of shape (N, C, H, W).
            target (Tensor): Ground-truth tensor of shape (N, C, H, W).
        """
        return self.loss_weight * charbonnier_loss(pred, target, eps=self.eps)


class PerceptualLoss(nn.Module):
    """VGG perceptual loss: compares pred and target in VGG feature space.

    Args:
        layer_weights (dict): Maps VGG layer names to per-layer loss weights.
        use_input_norm (bool): Forwarded to VGGFeatureExtractor. Default: True.
        use_range_norm (bool): Forwarded to VGGFeatureExtractor. Default: False.
        criterion (str): Distance between feature maps, 'l1' or 'l2'. Default: 'l2'.
        loss_weight (float): Weight applied to the total loss. Default: 1.0.
    """

    def __init__(self, layer_weights, use_input_norm=True, use_range_norm=False, criterion='l2', loss_weight=1.0):
        super(PerceptualLoss, self).__init__()
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(layer_name_list=list(layer_weights.keys()),
                                       use_input_norm=use_input_norm,
                                       use_range_norm=use_range_norm)
        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            self.criterion = torch.nn.MSELoss()
        else:
            raise NotImplementedError(f'{criterion} criterion is not supported.')
        self.loss_weight = loss_weight

    def forward(self, pred, target):
        """Forward function.

        Args:
            pred (Tensor): Input tensor of shape (N, C, H, W).
            target (Tensor): Ground-truth tensor of shape (N, C, H, W).

        Returns:
            Tensor: Weighted sum of per-layer feature distances.
        """
        pred_feat = self.vgg(pred)
        # Ground truth needs no gradients through the feature extractor.
        target_feat = self.vgg(target.detach())

        loss = 0.0
        for layer in pred_feat.keys():
            loss += self.criterion(pred_feat[layer], target_feat[layer]) * self.layer_weights[layer]
        loss *= self.loss_weight
        return loss
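For reference, the Charbonnier term is sqrt((pred - target)^2 + eps) averaged over all elements, so its gradient stays bounded where plain L1 has a kink at zero. Below is a minimal usage sketch of both criteria, not part of the commit; it assumes the local `vgg_feature_extractor` module is importable and that `VGGFeatureExtractor` returns a dict of feature maps keyed by layer name. The layer names and weights mirror those used in the Lightning module later in this commit.

# Hedged smoke test for the two loss classes above.
import torch
from losses import CharbonnierLoss, PerceptualLoss

pred = torch.rand(2, 3, 64, 64)    # (N, C, H, W), values in [0, 1]
target = torch.rand(2, 3, 64, 64)

pixel_criterion = CharbonnierLoss(eps=1e-12, loss_weight=200)
perceptual_criterion = PerceptualLoss(
    layer_weights={'conv5_4': 1, 'relu4_4': 1, 'relu3_4': 1, 'relu2_2': 1},
    criterion='l2',
)

total = pixel_criterion(pred, target) + perceptual_criterion(pred, target)
print(float(total))  # both terms reduce to scalar means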
recurrent_cnn_pl_module.py (new file, 102 lines; module name taken from the import in the training script below)
import torch
import pytorch_lightning as pl
import torchmetrics
import torchmetrics.image
from torchvision.transforms.functional import to_pil_image
from torchmetrics.functional import structural_similarity_index_measure
import os.path as osp
from einops import rearrange

from losses import CharbonnierLoss, PerceptualLoss


class RecurrentCNNModule(pl.LightningModule):

    def __init__(self, opt, generator=None, window_size=5, pixel_loss_weight=200, perceptual_loss_weight=1):
        super(RecurrentCNNModule, self).__init__()
        self.save_hyperparameters(ignore=["generator"])
        self.opt = opt
        self.window_size = window_size

        self.net_g = generator

        self.lr = 1.9e-5
        self.pixel_criterion = CharbonnierLoss(loss_weight=pixel_loss_weight)

        vgg_layer_weights = {'conv5_4': 1, 'relu4_4': 1, 'relu3_4': 1, 'relu2_2': 1}
        self.perceptual_criterion = PerceptualLoss(layer_weights=vgg_layer_weights,
                                                   loss_weight=perceptual_loss_weight)

        self.psnr = torchmetrics.PeakSignalNoiseRatio(data_range=1.0)
        self.ssim = structural_similarity_index_measure
        self.lpips = torchmetrics.image.LearnedPerceptualImagePatchSimilarity(net_type="alex")

    def forward(self, *x):
        return self.net_g(*x)

    def training_step(self, batch, batch_idx):
        imgs_lq = batch["imgs_lq"]
        imgs_gt = batch["imgs_gt"]
        outputs_g = self.net_g(imgs_lq)

        # Fold the temporal dimension into the batch so the image-space
        # losses see plain (N, C, H, W) tensors.
        B, T, C, H, W = outputs_g.shape
        outputs_g = rearrange(outputs_g, 'b t c h w -> (b t) c h w', b=B, t=T)
        imgs_gt = rearrange(imgs_gt, 'b t c h w -> (b t) c h w', b=B, t=T)

        pixel_loss_g = self.pixel_criterion(outputs_g, imgs_gt)
        perceptual_loss_g = self.perceptual_criterion(outputs_g, imgs_gt)

        total_loss_g = pixel_loss_g + perceptual_loss_g

        log_loss_g = {"total_g": total_loss_g,
                      "pixel_g": pixel_loss_g,
                      "perceptual_g": perceptual_loss_g}
        self.log_dict(log_loss_g, on_epoch=True, on_step=True, prog_bar=True, logger=True,
                      sync_dist=True, batch_size=imgs_gt.shape[0])
        return total_loss_g

    def validation_step(self, batch, batch_idx):
        imgs_lq = batch["imgs_lq"]
        imgs_gt = batch["imgs_gt"]
        outputs_g = self.net_g(imgs_lq).to(torch.float32)

        # Average the metrics over the clips in the batch; each clip of T
        # frames is scored as a batch of T images.
        psnr, ssim, lpips = 0., 0., 0.
        for i, output_g in enumerate(outputs_g):
            output_g = torch.clamp(output_g, 0, 1)
            img_gt = imgs_gt[i]
            psnr += self.psnr(output_g, img_gt)
            ssim += self.ssim(output_g, img_gt, data_range=1.)
            with torch.no_grad():
                lpips += self.lpips(output_g * 2 - 1, img_gt * 2 - 1)  # LPIPS expects inputs in [-1, 1]
        psnr /= len(outputs_g)
        ssim /= len(outputs_g)
        lpips /= len(outputs_g)

        log_metrics = {"psnr": psnr,
                       "ssim": ssim,
                       "lpips": lpips}

        self.log_dict(log_metrics, on_epoch=True, prog_bar=True, logger=True,
                      sync_dist=True, batch_size=imgs_gt.shape[0])

        # Periodically log an input / output / ground-truth triptych of the
        # centre frame for visual inspection.
        imgs_name = batch["img_name"]
        if batch_idx == 0:
            imgs_name = [imgs_name[0]]
        for i, img_name in enumerate(imgs_name):
            img_num = int(osp.basename(img_name)[:-4])  # strip the 4-char extension
            if img_num % 100 == 0 or batch_idx == 0:
                single_img_lq = imgs_lq[0, self.window_size // 2]
                single_img_gt = imgs_gt[0, self.window_size // 2]
                single_img_output = torch.clamp(outputs_g[0, self.window_size // 2], 0., 1.)
                concatenated_img = torch.cat((single_img_lq, single_img_output, single_img_gt), -1)
                self.logger.experiment.log_image(to_pil_image(concatenated_img.cpu()), str(img_num), step=self.current_epoch)

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def on_validation_epoch_end(self) -> None:
        self.psnr.reset()
        self.lpips.reset()

    def configure_optimizers(self):
        optimizer_g = torch.optim.AdamW(self.net_g.parameters(), lr=self.lr, weight_decay=0.01, betas=(0.9, 0.99))
        return optimizer_g
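The module is generator-agnostic: it only assumes that `generator` maps a (B, T, C, H, W) clip to a tensor of the same shape. Here is a hedged sketch, not part of the commit, that exercises the forward pass and the two training losses outside a Trainer, using a toy stand-in generator and random data. Constructing the module instantiates the VGG and LPIPS backbones, so cached or downloadable pretrained weights are assumed.

# Hypothetical smoke test; IdentityGenerator stands in for the real
# VideoSwinEncoderDecoder used in the training script below.
import torch
import torch.nn as nn
from einops import rearrange
from recurrent_cnn_pl_module import RecurrentCNNModule

class IdentityGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)  # one layer so configure_optimizers has parameters

    def forward(self, x):  # x: (B, T, C, H, W)
        b, t, c, h, w = x.shape
        y = self.conv(x.view(b * t, c, h, w))
        return y.view(b, t, c, h, w)

module = RecurrentCNNModule(opt=None, generator=IdentityGenerator(), window_size=5)
clip = torch.rand(1, 5, 3, 64, 64)
out = module(clip)  # forward() delegates to net_g
assert out.shape == clip.shape

# Mirror training_step: fold time into the batch, then apply both criteria.
flat_out = rearrange(out, 'b t c h w -> (b t) c h w')
flat_gt = rearrange(torch.rand_like(out), 'b t c h w -> (b t) c h w')
loss = module.pixel_criterion(flat_out, flat_gt) + module.perceptual_criterion(flat_out, flat_gt)
print(float(loss))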
Training script (new file, 122 lines)
import comet_ml  # imported before torch so Comet can hook the training frameworks
import os
import pytorch_lightning as pl
import torch
from pathlib import Path
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from argparse import ArgumentParser

from video_swin_unet import VideoSwinEncoderDecoder
from video_recurrent_dataset import VideoRecurrentDataset
from video_data_pl_module import VideoDataModule
from recurrent_cnn_pl_module import RecurrentCNNModule
from utils import init_logger


def main(args):
    training_params = args.training_params
    data_params = args.data_params
    model_params = args.model_params

    pl.seed_everything(42, workers=True)
    os.environ['PYTHONHASHSEED'] = str(42)

    experiment_name = args.experiment_name
    experiment_key = args.experiment_key
    online_logger = not args.offline

    data_base_path = Path(args.data_base_path)

    training_input_path = data_base_path / "train" / "input"
    training_gt_path = data_base_path / "train" / "gt"
    val_input_path = data_base_path / "val" / "input"
    val_gt_path = data_base_path / "val" / "gt"

    checkpoints_path = Path(args.checkpoints_path) / experiment_name
    logger = init_logger(api_key=args.api_key, experiment_name=experiment_name,
                         experiment_key=experiment_key, online=online_logger)
    args.training_params["logger"] = logger
    logger.experiment.log_parameters(training_params, prefix="training")
    logger.experiment.log_parameters(data_params, prefix="data")

    train_dataset = VideoRecurrentDataset(training_input_path, training_gt_path,
                                          window_size=data_params["window_size"],
                                          frame_offset=data_params["frame_offset"],
                                          gt_patch_size=data_params["gt_patch_size"],
                                          crop_mode="random")
    val_dataset = VideoRecurrentDataset(val_input_path, val_gt_path,
                                        window_size=data_params["window_size"],
                                        frame_offset=data_params["frame_offset"],
                                        gt_patch_size=data_params["gt_patch_size"],
                                        crop_mode="center")
    data_module = VideoDataModule(train_dataset, val_dataset, batch_size=data_params["batch_size"],
                                  num_workers=data_params["num_workers"])

    # Keep the checkpoint with the best (lowest) validation LPIPS, plus the last one.
    checkpoint_callback = ModelCheckpoint(dirpath=checkpoints_path,
                                          filename="{epoch}-{step}-{lpips:.3f}",
                                          save_weights_only=False,
                                          monitor="lpips",
                                          save_top_k=1,
                                          save_last=True)

    generator = VideoSwinEncoderDecoder(use_checkpoint=True, depths=[2, 2, 6, 2], embed_dim=96)
    model = RecurrentCNNModule(opt=None, generator=generator,
                               window_size=data_params["window_size"],
                               pixel_loss_weight=model_params["pixel_loss_weight"],
                               perceptual_loss_weight=model_params["perceptual_loss_weight"])

    trainer = Trainer(**training_params, callbacks=[checkpoint_callback])

    trainer.fit(model, data_module.train_dataloader(), data_module.val_dataloader())


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--experiment-name", type=str, default="video_swin_unet")
    parser.add_argument("--data-base-path", type=str)
    parser.add_argument("--checkpoints-path", type=str, default="pretrained_models")
    parser.add_argument("--devices", type=int, nargs="+", default=[0])
    parser.add_argument("--resume-from-checkpoint", default=False, action="store_true")
    parser.add_argument("--resume-checkpoint-filename", type=str, default="")
    parser.add_argument("--api-key", type=str, default="")
    parser.add_argument("--offline", default=False, action="store_true")
    parser.add_argument("--experiment-key", type=str, default=None)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--num-epochs", type=int, default=100)
    parser.add_argument("--num-workers", type=int, default=20)
    parser.add_argument("--pixel-loss-weight", type=float, default=200)
    parser.add_argument("--perceptual-loss-weight", type=float, default=1)
    parser.add_argument("--no-ddp-strategy", default=False, action="store_true")
    args = parser.parse_args()

    # Trainer arguments follow the PyTorch Lightning 1.x API
    # (resume_from_checkpoint was removed in Lightning 2.0).
    training_params = {
        "benchmark": True,
        "precision": 16,  # mixed-precision training
        "log_every_n_steps": 50,
        "accelerator": "gpu",
        "devices": args.devices,
        "strategy": None if args.no_ddp_strategy else DDPStrategy(find_unused_parameters=True, static_graph=True),
        "max_epochs": args.num_epochs,
        "resume_from_checkpoint": None if not args.resume_from_checkpoint else Path(args.checkpoints_path) / args.experiment_name / args.resume_checkpoint_filename
    }

    data_params = {
        "window_size": 5,
        "frame_offset": 1,
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "gt_patch_size": 128
    }

    model_params = {
        "pixel_loss_weight": args.pixel_loss_weight,
        "perceptual_loss_weight": args.perceptual_loss_weight,
    }

    args.training_params = training_params
    args.data_params = data_params
    args.model_params = model_params

    main(args)
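From the path construction in main(), --data-base-path is expected to contain paired input and ground-truth frames in the layout below. Frame filenames should be numeric with a three-character extension (e.g. 00100.png, an assumed example), since the validation step parses them with int(osp.basename(img_name)[:-4]). A typical launch would be along the lines of `python <script> --data-base-path /path/to/data --api-key <comet-api-key> --devices 0`, with everything else falling back to the defaults above.

data_base_path/
├── train/
│   ├── input/   # degraded frames (imgs_lq)
│   └── gt/      # ground-truth frames (imgs_gt)
└── val/
    ├── input/
    └── gt/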