Skip to content

Commit

Permalink
[torchcodec] Add support for Nvidia GPU Decoding (#137)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #137

Pull Request resolved: #58

X-link: #58

1. Add CUDA support to VideoDecoder.cpp. This is done by checking what device is passed into the options and using CUDA if the device type is cuda.
2. Add -DENABLE_CUDA flag in cmake.
3. Check ENABLE_CUDA environment variable in setup.py and pass it down to cmake if it is present.
4. Add a unit test to demonstrate that CUDA decoding does work. This uses a different tensor than the one from CPU decoding because hardware decoding is intrinsically a bit inaccurate. I generated the reference tensor by dumping the tensor from the GPU on my devVM. It is possible that different Nvidia hardware produces different outputs. How to test this in a more robust way is TBD.
5. Added a new parameter for cuda device index for `add_video_stream`. If this is present, we will use it to do hardware decoding on a CUDA device.

There is a whole bunch of TODOs:
1. Currently GPU utilization is only 7-8% when decoding the video. We need to get this higher.
2. Speed it up compared to CPU implementation. Currently this is slower than CPU decoding even for HD videos (probably because we can't hide the CPU to GPU memcpy). However, decode+resize is faster as the benchmark says.

Reviewed By: scotts

Differential Revision: D59121006
  • Loading branch information
ahmadsharif1 authored and facebook-github-bot committed Jul 31, 2024
1 parent 361968f commit 7826e2d
Show file tree
Hide file tree
Showing 18 changed files with 400 additions and 23 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
cmake_minimum_required(VERSION 3.18)
project(TorchCodec)

option(ENABLE_CUDA "Enable CUDA decoding using NVDEC" OFF)
option(ENABLE_NVTX "Enable NVTX annotations for profiling" OFF)

add_subdirectory(src/torchcodec/decoders/_core)


Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,7 @@ guide](CONTRIBUTING.md) for more details.
## License

TorchCodec is released under the [BSD 3 license](./LICENSE).


If you are building with ENABLE_CUDA and/or ENABLE_NVTX, please review the
[Nvidia licenses](https://docs.nvidia.com/cuda/eula/index.html).
3 changes: 2 additions & 1 deletion benchmarks/decoders/BenchmarkDecodersMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ void runNDecodeIterationsWithCustomOps(
/*height=*/std::nullopt,
/*thread_count=*/std::nullopt,
/*dimension_order=*/std::nullopt,
/*stream_index=*/std::nullopt);
/*stream_index=*/std::nullopt,
/*device_string=*/std::nullopt);

for (double pts : ptsList) {
seekFrameOp.call(decoderTensor, pts);
Expand Down
101 changes: 101 additions & 0 deletions benchmarks/decoders/manual_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import argparse
import os
import time

import torch.utils.benchmark as benchmark

import torchcodec
from torchvision.transforms import Resize


def transfer_and_resize_frame(frame, device):
    """Move ``frame`` onto ``device`` and resize it to 256x256.

    The transfer is a no-op when the frame already lives on ``device``.
    """
    moved = frame.to(device)
    return Resize((256, 256))(moved)


def decode_full_video(video_path, decode_device):
    """Decode every frame of ``video_path`` on ``decode_device``.

    Prints throughput stats and returns ``(frame_count, elapsed_seconds)``.

    Args:
        video_path: Path to the video file to decode.
        decode_device: Device string, e.g. ``"cpu"`` or ``"cuda:0"``.
    """
    decoder = torchcodec.decoders._core.create_from_file(video_path)
    num_threads = None
    if "cuda" in decode_device:
        # NOTE(review): presumably a single feeder thread is enough when NVDEC
        # does the decoding — confirm against the core decoder's threading model.
        num_threads = 1
    torchcodec.decoders._core.add_video_stream(
        decoder, stream_index=0, device_string=decode_device, num_threads=num_threads
    )
    # perf_counter() is monotonic and meant for interval timing, unlike time().
    start_time = time.perf_counter()
    frame_count = 0
    while True:
        try:
            frame, *_ = torchcodec.decoders._core.get_next_frame(decoder)
            # You can do a resize to simulate extra preproc work that happens
            # on the GPU by uncommenting the following line:
            # frame = transfer_and_resize_frame(frame, decode_device)

            frame_count += 1
        except Exception as e:
            # get_next_frame raises at end-of-stream; treat any failure as "done".
            print("EXCEPTION", e)
            break
        # print(f"current {frame_count=}", flush=True)
    # Compute the elapsed interval once and reuse it everywhere (the original
    # recomputed end_time - start_time three times).
    elapsed = time.perf_counter() - start_time
    # Guard against a zero-length interval (e.g. nothing decoded instantly).
    fps = frame_count / elapsed if elapsed > 0 else float("inf")
    print(
        f"****** DECODED full video {decode_device=} {frame_count=} {elapsed=} {fps=}"
    )
    return frame_count, elapsed


def main():
    """Parse CLI args and benchmark full-video decode on each requested device."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--devices",
        default="cuda:0,cpu",
        type=str,
        help="Comma-separated devices to test decoding on.",
    )
    parser.add_argument(
        "--video",
        type=str,
        # Default to the small test asset shipped with the repo; os.path.join
        # keeps the path portable (the original concatenated "/" by hand).
        default=os.path.join(
            os.path.dirname(__file__),
            "..",
            "..",
            "test",
            "resources",
            "nasa_13013.mp4",
        ),
    )
    parser.add_argument(
        "--use_torch_benchmark",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Use pytorch benchmark to measure decode time with warmup and "
            "autorange. Without this we just run one iteration without warmup "
            "to measure the cold start time."
        ),
    )
    args = parser.parse_args()
    video_path = args.video
    # Tolerate stray whitespace and empty entries (e.g. "cuda:0, cpu" or a
    # trailing comma) in the device list.
    devices = [d.strip() for d in args.devices.split(",") if d.strip()]

    if not args.use_torch_benchmark:
        # Cold-start mode: a single un-warmed iteration per device.
        for device in devices:
            print("Testing on", device)
            decode_full_video(video_path, device)
        return

    results = []
    for device in devices:
        print("device", device)
        t = benchmark.Timer(
            stmt="decode_full_video(video_path, device)",
            globals={
                "device": device,
                "video_path": video_path,
                "decode_full_video": decode_full_video,
            },
            label="Decode+Resize Time",
            sub_label=f"video={os.path.basename(video_path)}",
            description=f"decode_device={device}",
        ).blocked_autorange()
        results.append(t)
    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,16 @@ def _build_all_extensions_with_cmake(self):
torch_dir = Path(torch.utils.cmake_prefix_path) / "Torch"
cmake_build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release")
python_version = sys.version_info
enable_cuda = os.environ.get("ENABLE_CUDA", "")
enable_nvtx = os.environ.get("ENABLE_NVTX", "")
cmake_args = [
f"-DCMAKE_INSTALL_PREFIX={self._install_prefix}",
f"-DTorch_DIR={torch_dir}",
"-DCMAKE_VERBOSE_MAKEFILE=ON",
f"-DCMAKE_BUILD_TYPE={cmake_build_type}",
f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}",
f"-DENABLE_CUDA={enable_cuda}",
f"-DENABLE_NVTX={enable_nvtx}",
]

Path(self.build_temp).mkdir(parents=True, exist_ok=True)
Expand Down
33 changes: 33 additions & 0 deletions src/torchcodec/decoders/_core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,26 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(Torch REQUIRED)

if(ENABLE_CUDA)
find_package(CUDA REQUIRED)
endif()

if(ENABLE_NVTX)
file(
DOWNLOAD
https://github.com/cpm-cmake/CPM.cmake/releases/download/v0.38.3/CPM.cmake
${CMAKE_CURRENT_BINARY_DIR}/cmake/CPM.cmake
EXPECTED_HASH SHA256=cc155ce02e7945e7b8967ddfaff0b050e958a723ef7aad3766d368940cb15494
)
include(${CMAKE_CURRENT_BINARY_DIR}/cmake/CPM.cmake)
CPMAddPackage(
NAME NVTX
GITHUB_REPOSITORY NVIDIA/NVTX
GIT_TAG v3.1.0-c-cpp
GIT_SHALLOW TRUE)
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)

Expand All @@ -19,6 +39,12 @@ function(make_torchcodec_library library_name ffmpeg_target)
)
add_library(${library_name} SHARED ${sources})
set_property(TARGET ${library_name} PROPERTY CXX_STANDARD 17)
if(ENABLE_CUDA)
target_compile_definitions(${library_name} PRIVATE ENABLE_CUDA=1)
endif()
if(ENABLE_NVTX)
target_compile_definitions(${library_name} PRIVATE ENABLE_NVTX=1)
endif()

target_include_directories(
${library_name}
Expand All @@ -34,6 +60,13 @@ function(make_torchcodec_library library_name ffmpeg_target)
${ffmpeg_target}
${TORCH_LIBRARIES}
${Python3_LIBRARIES}
if(ENABLE_CUDA)
${CUDA_CUDA_LIBRARY}
endif()

if(ENABLE_NVTX)
nvtx3-cpp
endif()
)

# We already set the library_name to be libtorchcodecN, so we don't want
Expand Down
2 changes: 2 additions & 0 deletions src/torchcodec/decoders/_core/FFMPEGCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ using UniqueAVFilterInOut = std::unique_ptr<
Deleterp<AVFilterInOut, void, avfilter_inout_free>>;
using UniqueAVIOContext = std::
unique_ptr<AVIOContext, Deleterp<AVIOContext, void, avio_context_free>>;
using UniqueAVBufferRef =
std::unique_ptr<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>;

// av_find_best_stream is not const-correct before commit:
// https://github.com/FFmpeg/FFmpeg/commit/46dac8cf3d250184ab4247809bc03f60e14f4c0c
Expand Down
Loading

0 comments on commit 7826e2d

Please sign in to comment.