From 2146cc8b8fc132899a9c6f19fc52a0b5caa11fb4 Mon Sep 17 00:00:00 2001
From: Alex Hughes
Date: Thu, 14 Oct 2021 13:46:48 -0700
Subject: [PATCH 1/2] Allow providing a background image to the video

---
 inference.py       | 13 +++++++++++--
 inference_utils.py | 17 +++++++++++++++--
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/inference.py b/inference.py
index a116754..6f66320 100644
--- a/inference.py
+++ b/inference.py
@@ -20,6 +20,8 @@
 from tqdm.auto import tqdm
 from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
 
+from inference_utils import ImageReader
+
 
 def convert_video(model,
                   input_source: str,
@@ -27,6 +29,7 @@ def convert_video(model,
                   downsample_ratio: Optional[float] = None,
                   output_type: str = 'video',
                   output_composition: Optional[str] = None,
+                  bgr_source: Optional[str] = None,
                   output_alpha: Optional[str] = None,
                   output_foreground: Optional[str] = None,
                   output_video_mbps: Optional[float] = None,
@@ -46,6 +49,8 @@ def convert_video(model,
             The composition output path. File path if output_type == 'video'. Directory path if output_type == 'png_sequence'.
             If output_type == 'video', the composition has green screen background.
             If output_type == 'png_sequence'. the composition is RGBA png images.
+        bgr_source: A video file, image sequence directory, or an individual image to composite behind the foreground.
+            This is only applicable if output_type == 'video'.
         output_alpha: The alpha output from the model.
         output_foreground: The foreground output from the model.
         seq_chunk: Number of frames to process at once. Increase it for better parallelism.
@@ -112,8 +117,12 @@ def convert_video(model,
         device = param.device
     
     if (output_composition is not None) and (output_type == 'video'):
-        bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)
-    
+        if bgr_source is not None and os.path.isfile(bgr_source):
+            bgr = ImageReader(bgr_source, transform=transform).data()
+            bgr = bgr.to(device, dtype, non_blocking=True).unsqueeze(0) # [B, T, C, H, W]
+        else:
+            bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)
+    
     try:
         with torch.no_grad():
             bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
diff --git a/inference_utils.py b/inference_utils.py
index 1e92fd4..60f6ab2 100644
--- a/inference_utils.py
+++ b/inference_utils.py
@@ -3,7 +3,7 @@
 import pims
 import numpy as np
 from torch.utils.data import Dataset
-from torchvision.transforms.functional import to_pil_image
+from torchvision.transforms.functional import to_pil_image, pil_to_tensor
 from PIL import Image
 
 
@@ -52,6 +52,20 @@ def close(self):
         self.container.close()
 
 
+class ImageReader:
+    def __init__(self, path, transform=None):
+        self.path = path
+        self.transform = transform
+
+    def data(self):
+        with Image.open(self.path) as img:
+            img.load()
+
+        if self.transform is not None:
+            return self.transform(img)
+        return img
+
+
 class ImageSequenceReader(Dataset):
     def __init__(self, path, transform=None):
         self.path = path
@@ -85,4 +99,3 @@ def write(self, frames):
 
     def close(self):
         pass
-    
\ No newline at end of file

From 1e97adf385df766fd4ba351f5abd6ec20ff2cc82 Mon Sep 17 00:00:00 2001
From: Alex Hughes
Date: Fri, 15 Oct 2021 10:24:31 -0700
Subject: [PATCH 2/2] Support different background source types

This allows you to provide Video, ImageSequence, Image, or ConstantColour
background sources.
---
 inference.py       | 21 +++++++++++++--------
 inference_utils.py | 24 +++++++++++++++++++-----
 2 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/inference.py b/inference.py
index 6f66320..c47e4bb 100644
--- a/inference.py
+++ b/inference.py
@@ -20,7 +20,7 @@
 from tqdm.auto import tqdm
 from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
 
-from inference_utils import ImageReader
+from inference_utils import ImageReader, ConstantImage
 
 
 def convert_video(model,
@@ -115,20 +115,24 @@ def convert_video(model,
         param = next(model.parameters())
         dtype = param.dtype
         device = param.device
-    
+
     if (output_composition is not None) and (output_type == 'video'):
-        if bgr_source is not None and os.path.isfile(bgr_source):
-            bgr = ImageReader(bgr_source, transform=transform).data()
-            bgr = bgr.to(device, dtype, non_blocking=True).unsqueeze(0) # [B, T, C, H, W]
+        if bgr_source is not None:
+            if os.path.isfile(bgr_source):
+                if os.path.splitext(bgr_source)[-1].lower() in [".png", ".jpg"]:
+                    bgr_raw = ImageReader(bgr_source, transform=transform)
+                else:
+                    bgr_raw = VideoReader(bgr_source, transform)
+            else:
+                bgr_raw = ImageSequenceReader(bgr_source, transform)
         else:
-            bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)
+            bgr_raw = ConstantImage(120, 255, 155, device=device, dtype=dtype)
     
     try:
         with torch.no_grad():
             bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
             rec = [None] * 4
-            for src in reader:
-                
+            for index, src in enumerate(reader):
                 if downsample_ratio is None:
                     downsample_ratio = auto_downsample_ratio(*src.shape[2:])
 
@@ -141,6 +145,7 @@ def convert_video(model,
                     writer_pha.write(pha[0])
                 if output_composition is not None:
                     if output_type == 'video':
+                        bgr = bgr_raw[index].to(device, dtype, non_blocking=True).unsqueeze(0) # [1, C, H, W], broadcasts against [B, T, C, H, W]
                         com = fgr * pha + bgr * (1 - pha)
                     else:
                         fgr = fgr * pha.gt(0)
diff --git a/inference_utils.py b/inference_utils.py
index 60f6ab2..68227a3 100644
--- a/inference_utils.py
+++ b/inference_utils.py
@@ -2,6 +2,7 @@
 import os
 import pims
 import numpy as np
+import torch
 from torch.utils.data import Dataset
 from torchvision.transforms.functional import to_pil_image, pil_to_tensor
 from PIL import Image
@@ -34,7 +35,7 @@ def __init__(self, path, frame_rate, bit_rate=1000000):
         self.stream = self.container.add_stream('h264', rate=round(frame_rate))
         self.stream.pix_fmt = 'yuv420p'
         self.stream.bit_rate = bit_rate
-    
+
     def write(self, frames):
         # frames: [T, C, H, W]
         self.stream.width = frames.size(3)
@@ -46,26 +47,39 @@ def write(self, frames):
             frame = frames[t]
             frame = av.VideoFrame.from_ndarray(frame, format='rgb24')
             self.container.mux(self.stream.encode(frame))
-    
+
     def close(self):
         self.container.mux(self.stream.encode())
         self.container.close()
 
 
-class ImageReader:
+class ImageReader(Dataset):
     def __init__(self, path, transform=None):
         self.path = path
         self.transform = transform
 
-    def data(self):
+    def __len__(self):
+        return 1
+
+    def __getitem__(self, idx):
         with Image.open(self.path) as img:
             img.load()
-
         if self.transform is not None:
             return self.transform(img)
         return img
 
 
+class ConstantImage(Dataset):
+    def __init__(self, r, g, b, device=None, dtype=None):
+        self.tensor = torch.tensor([r, g, b], device=device, dtype=dtype).div(255).view(3, 1, 1)
+
+    def __len__(self):
+        return 1
+
+    def __getitem__(self, idx):
+        return self.tensor
+
+
 class ImageSequenceReader(Dataset):
     def __init__(self, path, transform=None):
         self.path = path
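
The sketch below is not part of either patch; it is one way the new bgr_source parameter could be exercised once both commits are applied. Model loading follows the upstream RobustVideoMatting README, and all weight and media paths are placeholders.

# Usage sketch (assumptions: checkpoint and media paths are placeholders).
import torch
from model import MattingNetwork
from inference import convert_video

model = MattingNetwork('mobilenetv3').eval().cuda()
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))

# No bgr_source: falls back to the constant green-screen background (ConstantImage).
convert_video(model, input_source='input.mp4', output_type='video',
              output_composition='com_green.mp4')

# A .png/.jpg file is composited via ImageReader.
convert_video(model, input_source='input.mp4', output_type='video',
              output_composition='com_image.mp4', bgr_source='background.png')

# Any other file extension is treated as a video and read via VideoReader.
convert_video(model, input_source='input.mp4', output_type='video',
              output_composition='com_video.mp4', bgr_source='background.mp4')

# A directory is treated as an image sequence and read via ImageSequenceReader.
# Note: video and sequence backgrounds are indexed once per processed chunk of seq_chunk frames.
convert_video(model, input_source='input.mp4', output_type='video',
              output_composition='com_sequence.mp4', bgr_source='background_frames')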