Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a decode+resize benchmark and cuda decoder #378

Merged
merged 6 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 96 additions & 1 deletion benchmarks/decoders/benchmark_decoders_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def __init__(self):
def get_frames_from_video(self, video_file, pts_list):
pass

@abc.abstractmethod
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
pass

@abc.abstractmethod
def decode_and_transform(self, video_file, pts_list, height, width, device):
pass


class DecordAccurate(AbstractDecoder):
def __init__(self):
Expand Down Expand Up @@ -89,8 +97,10 @@ def __init__(self, backend):
self._backend = backend
self._print_each_iteration_time = False
import torchvision # noqa: F401
from torchvision.transforms import v2 as transforms_v2

self.torchvision = torchvision
self.transforms_v2 = transforms_v2

def get_frames_from_video(self, video_file, pts_list):
self.torchvision.set_video_backend(self._backend)
Expand All @@ -111,6 +121,18 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
frames.append(frame["data"].permute(1, 2, 0))
return frames

def decode_and_transform(self, video_file, pts_list, height, width, device):
self.torchvision.set_video_backend(self._backend)
reader = self.torchvision.io.VideoReader(video_file, "video")
frames = []
for pts in pts_list:
reader.seek(pts)
frame = next(reader)
frames.append(frame["data"].permute(1, 2, 0))
frames = [frame.to(device) for frame in frames]
frames = self.transforms_v2.functional.resize(frames, (height, width))
Copy link
Contributor

@scotts scotts Nov 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naive question that applies to all implementations that use the transformation: how do we ensure it's done on the GPU?

Realized after walking away from my laptop that it's controlled by where the data lives, not by some parameter on the transform. So to answer my question: in the frame.to(device) call.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct

return frames


class TorchCodecCore(AbstractDecoder):
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
Expand Down Expand Up @@ -239,6 +261,12 @@ def __init__(self, num_ffmpeg_threads=None, device="cpu"):
)
self._device = device

import torchvision # noqa: F401
from torchvision.transforms import v2 as transforms_v2
scotts marked this conversation as resolved.
Show resolved Hide resolved

self.torchvision = torchvision
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we actually use self.torchvision? We should be able to remove it.

Copy link
Contributor Author

@ahmadsharif1 ahmadsharif1 Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. done.

self.transforms_v2 = transforms_v2

def get_frames_from_video(self, video_file, pts_list):
decoder = VideoDecoder(
video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device
Expand All @@ -258,6 +286,14 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
break
return frames

def decode_and_transform(self, video_file, pts_list, height, width, device):
decoder = VideoDecoder(
video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device
)
frames = decoder.get_frames_played_at(pts_list)
frames = self.transforms_v2.functional.resize(frames.data, (height, width))
return frames


@torch.compile(fullgraph=True, backend="eager")
def compiled_seek_and_next(decoder, pts):
Expand Down Expand Up @@ -299,7 +335,9 @@ def __init__(self):

self.torchaudio = torchaudio

pass
from torchvision.transforms import v2 as transforms_v2

self.transforms_v2 = transforms_v2

def get_frames_from_video(self, video_file, pts_list):
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
Expand All @@ -325,6 +363,19 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):

return frames

def decode_and_transform(self, video_file, pts_list, height, width, device):
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
stream_reader.add_basic_video_stream(frames_per_chunk=1)
frames = []
for pts in pts_list:
stream_reader.seek(pts)
stream_reader.fill_buffer()
clip = stream_reader.pop_chunks()
frames.append(clip[0][0])
frames = [frame.to(device) for frame in frames]
frames = self.transforms_v2.functional.resize(frames, (height, width))
return frames


def create_torchcodec_decoder_from_file(video_file):
video_decoder = create_from_file(video_file)
Expand Down Expand Up @@ -486,6 +537,14 @@ class BatchParameters:
batch_size: int


@dataclass
class DataLoaderInspiredWorkloadParameters:
batch_parameters: BatchParameters
resize_height: int
resize_width: int
resize_device: str


def run_batch_using_threads(
function,
*args,
Expand Down Expand Up @@ -525,6 +584,7 @@ def run_benchmarks(
num_sequential_frames_from_start: list[int],
min_runtime_seconds: float,
benchmark_video_creation: bool,
dataloader_parameters: DataLoaderInspiredWorkloadParameters = None,
batch_parameters: BatchParameters = None,
) -> list[dict[str, str | float | int]]:
# Ensure that we have the same seed across benchmark runs.
Expand All @@ -535,6 +595,8 @@ def run_benchmarks(
results = []
df_data = []
verbose = False
# TODO: change this back before landing.
min_runtime_seconds = 0.1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for video_file_path in video_files_paths:
metadata = get_metadata(video_file_path)
metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps"
Expand All @@ -550,6 +612,39 @@ def run_benchmarks(
for decoder_name, decoder in decoder_dict.items():
print(f"video={video_file_path}, decoder={decoder_name}")

if dataloader_parameters:
bp = dataloader_parameters.batch_parameters
dataloader_result = benchmark.Timer(
stmt="run_batch_using_threads(decoder.decode_and_transform, video_file, pts_list, height, width, device, batch_parameters=batch_parameters)",
globals={
"video_file": str(video_file_path),
"pts_list": uniform_pts_list,
"decoder": decoder,
"run_batch_using_threads": run_batch_using_threads,
"batch_parameters": dataloader_parameters.batch_parameters,
"height": dataloader_parameters.resize_height,
"width": dataloader_parameters.resize_width,
"device": dataloader_parameters.resize_device,
},
label=f"video={video_file_path} {metadata_label}",
sub_label=decoder_name,
description=f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} decode_and_transform()",
)
results.append(
dataloader_result.blocked_autorange(
min_run_time=min_runtime_seconds
)
)
df_data.append(
convert_result_to_df_item(
results[-1],
decoder_name,
video_file_path,
num_samples * dataloader_parameters.batch_parameters.batch_size,
f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} x decode_and_transform()",
)
)

for kind, pts_list in [
("uniform", uniform_pts_list),
("random", random_pts_list),
Expand Down
Loading
Loading