From 36c574a1b7bf253a8f5be85402ba40563eac9d8d Mon Sep 17 00:00:00 2001
From: Alex Hughes
Date: Fri, 15 Oct 2021 10:27:49 -0700
Subject: [PATCH] Support writing the audio stream back into video

When output_type == 'video', remux the audio stream from the original
video input into the output file. Controlled by a new passthrough_audio
flag (default True). A usage sketch follows the diff.
---
 inference.py       | 81 +++++++++++++++++++++++++++++++++-----------
 inference_utils.py | 43 ++++++++++++++++++++-----
 2 files changed, 92 insertions(+), 32 deletions(-)

diff --git a/inference.py b/inference.py
index a116754..0dd26ae 100644
--- a/inference.py
+++ b/inference.py
@@ -12,6 +12,7 @@
     --seq-chunk 1
 """
 
+import av
 import torch
 import os
 from torch.utils.data import DataLoader
@@ -20,6 +21,8 @@
 from tqdm.auto import tqdm
 
 from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
+from inference_utils import AudioVideoWriter
+
 
 def convert_video(model,
                   input_source: str,
@@ -33,6 +36,7 @@ def convert_video(model,
                   seq_chunk: int = 1,
                   num_workers: int = 0,
                   progress: bool = True,
+                  passthrough_audio: bool = True,
                   device: Optional[str] = None,
                   dtype: Optional[torch.dtype] = None):
     
@@ -51,10 +55,11 @@ def convert_video(model,
         seq_chunk: Number of frames to process at once. Increase it for better parallelism.
         num_workers: PyTorch's DataLoader workers. Only use >0 for image input.
         progress: Show progress bar.
+        passthrough_audio: Whether to pass the audio stream of the input video through to the output.
         device: Only need to manually provide if model is a TorchScript freezed model.
         dtype: Only need to manually provide if model is a TorchScript freezed model.
     """
-    
+
     assert downsample_ratio is None or (downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).'
     assert any([output_composition, output_alpha, output_foreground]), 'Must provide at least one output.'
     assert output_type in ['video', 'png_sequence'], 'Only support "video" and "png_sequence" output modes.'
@@ -76,26 +81,54 @@ def convert_video(model,
     else:
         source = ImageSequenceReader(input_source, transform)
     reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers)
-    
+
+    audio_source = None
+    if os.path.isfile(input_source):
+        container = av.open(input_source)
+        # streams.get(audio=0) raises IndexError when the file has no audio,
+        # so check the tuple of audio streams instead.
+        if len(container.streams.audio) > 0:
+            audio_source = container.streams.audio[0]
+
     # Initialize writers
     if output_type == 'video':
         frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
         output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
-        if output_composition is not None:
-            writer_com = VideoWriter(
-                path=output_composition,
-                frame_rate=frame_rate,
-                bit_rate=int(output_video_mbps * 1000000))
-        if output_alpha is not None:
-            writer_pha = VideoWriter(
-                path=output_alpha,
-                frame_rate=frame_rate,
-                bit_rate=int(output_video_mbps * 1000000))
-        if output_foreground is not None:
-            writer_fgr = VideoWriter(
-                path=output_foreground,
-                frame_rate=frame_rate,
-                bit_rate=int(output_video_mbps * 1000000))
+        if passthrough_audio and audio_source:
+            if output_composition is not None:
+                writer_com = AudioVideoWriter(
+                    path=output_composition,
+                    frame_rate=frame_rate,
+                    audio_stream=audio_source,
+                    bit_rate=int(output_video_mbps * 1000000))
+            if output_alpha is not None:
+                writer_pha = AudioVideoWriter(
+                    path=output_alpha,
+                    frame_rate=frame_rate,
+                    audio_stream=audio_source,
+                    bit_rate=int(output_video_mbps * 1000000))
+            if output_foreground is not None:
+                writer_fgr = AudioVideoWriter(
+                    path=output_foreground,
+                    frame_rate=frame_rate,
+                    audio_stream=audio_source,
+                    bit_rate=int(output_video_mbps * 1000000))
+        else:
+            if output_composition is not None:
+                writer_com = VideoWriter(
+                    path=output_composition,
+                    frame_rate=frame_rate,
+                    bit_rate=int(output_video_mbps * 1000000))
+            if output_alpha is not None:
+                writer_pha = VideoWriter(
+                    path=output_alpha,
+                    frame_rate=frame_rate,
+                    bit_rate=int(output_video_mbps * 1000000))
+            if output_foreground is not None:
+                writer_fgr = VideoWriter(
+                    path=output_foreground,
+                    frame_rate=frame_rate,
+                    bit_rate=int(output_video_mbps * 1000000))
     else:
         if output_composition is not None:
             writer_com = ImageSequenceWriter(output_composition, 'png')
@@ -113,7 +146,7 @@ def convert_video(model,
 
     if (output_composition is not None) and (output_type == 'video'):
         bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)
-    
+
     try:
         with torch.no_grad():
             bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
@@ -137,7 +170,7 @@ def convert_video(model,
                     fgr = fgr * pha.gt(0)
                     com = torch.cat([fgr, pha], dim=-3)
                     writer_com.write(com[0])
-                
+
                 bar.update(src.size(1))
 
     finally:
@@ -167,11 +200,12 @@ def __init__(self, variant: str, checkpoint: str, device: str):
     def convert(self, *args, **kwargs):
         convert_video(self.model, device=self.device, dtype=torch.float32, *args, **kwargs)
-    
+
+
 
 if __name__ == '__main__':
     import argparse
     from model import MattingNetwork
-    
+
     parser = argparse.ArgumentParser()
     parser.add_argument('--variant', type=str, required=True, choices=['mobilenetv3', 'resnet50'])
     parser.add_argument('--checkpoint', type=str, required=True)
@@ -188,7 +222,7 @@ def convert(self, *args, **kwargs):
     parser.add_argument('--num-workers', type=int, default=0)
     parser.add_argument('--disable-progress', action='store_true')
     args = parser.parse_args()
-    
+
     converter = Converter(args.variant, args.checkpoint, args.device)
     converter.convert(
         input_source=args.input_source,
@@ -203,5 +237,4 @@ def convert(self, *args, **kwargs):
         num_workers=args.num_workers,
         progress=not args.disable_progress
     )
-    
-    
+
diff --git a/inference_utils.py b/inference_utils.py
index 1e92fd4..9da7c99 100644
--- a/inference_utils.py
+++ b/inference_utils.py
@@ -12,14 +12,14 @@ def __init__(self, path, transform=None):
         self.video = pims.PyAVVideoReader(path)
         self.rate = self.video.frame_rate
         self.transform = transform
-    
+
     @property
     def frame_rate(self):
         return self.rate
-    
+
     def __len__(self):
         return len(self.video)
-    
+
     def __getitem__(self, idx):
         frame = self.video[idx]
         frame = Image.fromarray(np.asarray(frame))
@@ -57,10 +57,10 @@ def __init__(self, path, transform=None):
         self.path = path
         self.files = sorted(os.listdir(path))
         self.transform = transform
-    
+
     def __len__(self):
         return len(self.files)
-    
+
     def __getitem__(self, idx):
         with Image.open(os.path.join(self.path, self.files[idx])) as img:
             img.load()
@@ -75,14 +75,41 @@ def __init__(self, path, extension='jpg'):
         self.extension = extension
         self.counter = 0
         os.makedirs(path, exist_ok=True)
-    
+
     def write(self, frames):
         # frames: [T, C, H, W]
         for t in range(frames.shape[0]):
             to_pil_image(frames[t]).save(os.path.join(
                 self.path, str(self.counter).zfill(4) + '.' + self.extension))
             self.counter += 1
-    
+
     def close(self):
         pass
-        
\ No newline at end of file
+
+
+class AudioVideoWriter(VideoWriter):
+    def __init__(self, path, frame_rate, audio_stream=None, bit_rate=1000000):
+        super().__init__(
+            path=path,
+            frame_rate=frame_rate,
+            bit_rate=bit_rate
+        )
+        self.source_audio_stream = audio_stream
+        # Create the output stream from the source stream's codec parameters
+        # so its packets can be remuxed without re-encoding.
+        self.output_audio_stream = self.container.add_stream(
+            template=self.source_audio_stream)
+
+    def remux_audio(self):
+        input_audio_container = self.source_audio_stream.container
+        for packet in input_audio_container.demux(self.source_audio_stream):
+            # demux() emits a final flushing packet with no dts; skip it.
+            if packet.dts is None:
+                continue
+            packet.stream = self.output_audio_stream
+            self.container.mux(packet)
+
+    def close(self):
+        # Remux the audio after the last video frame, then finalize the file.
+        self.remux_audio()
+        super().close()
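
Usage note (not part of the patch): a minimal sketch of the new option,
assuming the released rvm_mobilenetv3 checkpoint, a CUDA device, and a
local input.mp4 that carries an audio track. passthrough_audio defaults
to True, so audio is copied whenever output_type is 'video' and the
source file has an audio stream:

    import torch
    from model import MattingNetwork
    from inference import convert_video

    # Load a matting model (checkpoint path is an example).
    model = MattingNetwork('mobilenetv3').eval().cuda()
    model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))

    convert_video(
        model,
        input_source='input.mp4',        # source video with audio
        output_type='video',
        output_composition='com.mp4',    # output keeps the source audio
        output_video_mbps=4,
        seq_chunk=12,
        passthrough_audio=True)          # new in this patch; default True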