diff --git a/CMakeLists.txt b/CMakeLists.txt index 4dfeb060..fc8d17c2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,9 @@ cmake_minimum_required(VERSION 3.18) project(TorchCodec) +option(ENABLE_CUDA "Enable CUDA decoding using NVDEC" OFF) +option(ENABLE_NVTX "Enable NVTX annotations for profiling" OFF) + add_subdirectory(src/torchcodec/decoders/_core) diff --git a/README.md b/README.md index 18e2ec68..377d6325 100644 --- a/README.md +++ b/README.md @@ -127,3 +127,7 @@ guide](CONTRIBUTING.md) for more details. ## License TorchCodec is released under the [BSD 3 license](./LICENSE). + + +If you are building with ENABLE_CUDA and/or ENABLE_NVTX please review +[Nvidia licenses](https://docs.nvidia.com/cuda/eula/index.html). diff --git a/benchmarks/decoders/BenchmarkDecodersMain.cpp b/benchmarks/decoders/BenchmarkDecodersMain.cpp index a9762a0b..5be64e6d 100644 --- a/benchmarks/decoders/BenchmarkDecodersMain.cpp +++ b/benchmarks/decoders/BenchmarkDecodersMain.cpp @@ -145,7 +145,8 @@ void runNDecodeIterationsWithCustomOps( /*height=*/std::nullopt, /*thread_count=*/std::nullopt, /*dimension_order=*/std::nullopt, - /*stream_index=*/std::nullopt); + /*stream_index=*/std::nullopt, + /*device_string=*/std::nullopt); for (double pts : ptsList) { seekFrameOp.call(decoderTensor, pts); diff --git a/benchmarks/decoders/gpu_benchmark.py b/benchmarks/decoders/gpu_benchmark.py new file mode 100644 index 00000000..a19c1d43 --- /dev/null +++ b/benchmarks/decoders/gpu_benchmark.py @@ -0,0 +1,101 @@ +import argparse +import os +import time + +import torch.utils.benchmark as benchmark + +import torchcodec +from torchvision.transforms import Resize + + +def transfer_and_resize_frame(frame, device): + # This should be a no-op if the frame is already on the device. + frame = frame.to(device) + frame = Resize((256, 256))(frame) + return frame + + +def decode_full_video(video_path, decode_device): + decoder = torchcodec.decoders._core.create_from_file(video_path) + num_threads = None + if "cuda" in decode_device: + num_threads = 1 + torchcodec.decoders._core.add_video_stream( + decoder, stream_index=0, device_string=decode_device, num_threads=num_threads + ) + start_time = time.time() + frame_count = 0 + while True: + try: + frame, *_ = torchcodec.decoders._core.get_next_frame(decoder) + # You can do a resize to simulate extra preproc work that happens + # on the GPU by uncommenting the following line: + # frame = transfer_and_resize_frame(frame, decode_device) + + frame_count += 1 + except Exception as e: + print("EXCEPTION", e) + break + # print(f"current {frame_count=}", flush=True) + end_time = time.time() + elapsed = end_time - start_time + fps = frame_count / (end_time - start_time) + print( + f"****** DECODED full video {decode_device=} {frame_count=} {elapsed=} {fps=}" + ) + return frame_count, end_time - start_time + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--devices", + default="cuda:0,cpu", + type=str, + help="Comma-separated devices to test decoding on.", + ) + parser.add_argument( + "--video", + type=str, + default=os.path.dirname(__file__) + "/../../test/resources/nasa_13013.mp4", + ) + parser.add_argument( + "--use_torch_benchmark", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Use pytorch benchmark to measure decode time with warmup and " + "autorange. Without this we just run one iteration without warmup " + "to measure the cold start time." 
+ ), + ) + args = parser.parse_args() + video_path = args.video + + if not args.use_torch_benchmark: + for device in args.devices.split(","): + print("Testing on", device) + decode_full_video(video_path, device) + return + + results = [] + for device in args.devices.split(","): + print("device", device) + t = benchmark.Timer( + stmt="decode_full_video(video_path, device)", + globals={ + "device": device, + "video_path": video_path, + "decode_full_video": decode_full_video, + }, + label="Decode+Resize Time", + sub_label=f"video={os.path.basename(video_path)}", + description=f"decode_device={device}", + ).blocked_autorange() + results.append(t) + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index fb5d0278..75310cea 100644 --- a/setup.py +++ b/setup.py @@ -112,12 +112,16 @@ def _build_all_extensions_with_cmake(self): torch_dir = Path(torch.utils.cmake_prefix_path) / "Torch" cmake_build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release") python_version = sys.version_info + enable_cuda = os.environ.get("ENABLE_CUDA", "") + enable_nvtx = os.environ.get("ENABLE_NVTX", "") cmake_args = [ f"-DCMAKE_INSTALL_PREFIX={self._install_prefix}", f"-DTorch_DIR={torch_dir}", "-DCMAKE_VERBOSE_MAKEFILE=ON", f"-DCMAKE_BUILD_TYPE={cmake_build_type}", f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}", + f"-DENABLE_CUDA={enable_cuda}", + f"-DENABLE_NVTX={enable_nvtx}", ] Path(self.build_temp).mkdir(parents=True, exist_ok=True) diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index 0fe8f243..d8554bda 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -4,6 +4,28 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Torch REQUIRED) + +if(ENABLE_CUDA) + find_package(CUDA REQUIRED) + + if(ENABLE_NVTX) + # We only need CPM for NVTX: + # https://github.com/NVIDIA/NVTX#cmake + file( + DOWNLOAD + https://github.com/cpm-cmake/CPM.cmake/releases/download/v0.38.3/CPM.cmake + ${CMAKE_CURRENT_BINARY_DIR}/cmake/CPM.cmake + EXPECTED_HASH SHA256=cc155ce02e7945e7b8967ddfaff0b050e958a723ef7aad3766d368940cb15494 + ) + include(${CMAKE_CURRENT_BINARY_DIR}/cmake/CPM.cmake) + CPMAddPackage( + NAME NVTX + GITHUB_REPOSITORY NVIDIA/NVTX + GIT_TAG v3.1.0-c-cpp + GIT_SHALLOW TRUE) + endif() +endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) @@ -19,6 +41,12 @@ function(make_torchcodec_library library_name ffmpeg_target) ) add_library(${library_name} SHARED ${sources}) set_property(TARGET ${library_name} PROPERTY CXX_STANDARD 17) + if(ENABLE_CUDA) + target_compile_definitions(${library_name} PRIVATE ENABLE_CUDA=1) + endif() + if(ENABLE_NVTX) + target_compile_definitions(${library_name} PRIVATE ENABLE_NVTX=1) + endif() target_include_directories( ${library_name} @@ -28,12 +56,17 @@ function(make_torchcodec_library library_name ffmpeg_target) ${Python3_INCLUDE_DIRS} ) + set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES} ${Python3_LIBRARIES}) + if(ENABLE_CUDA) + list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY}) + endif() + if(ENABLE_NVTX) + list(APPEND NEEDED_LIBRARIES nvtx3-cpp) + endif() target_link_libraries( ${library_name} PUBLIC - ${ffmpeg_target} - ${TORCH_LIBRARIES} - ${Python3_LIBRARIES} + ${NEEDED_LIBRARIES} ) # We already set the library_name to be libtorchcodecN, so we don't want diff --git 
a/src/torchcodec/decoders/_core/FFMPEGCommon.h b/src/torchcodec/decoders/_core/FFMPEGCommon.h index 7bb61cef..b5ad4e03 100644 --- a/src/torchcodec/decoders/_core/FFMPEGCommon.h +++ b/src/torchcodec/decoders/_core/FFMPEGCommon.h @@ -57,6 +57,8 @@ using UniqueAVFilterInOut = std::unique_ptr< Deleterp>; using UniqueAVIOContext = std:: unique_ptr>; +using UniqueAVBufferRef = + std::unique_ptr>; // av_find_best_stream is not const-correct before commit: // https://github.com/FFmpeg/FFmpeg/commit/46dac8cf3d250184ab4247809bc03f60e14f4c0c diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 0674a51b..28781490 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -5,11 +5,22 @@ // LICENSE file in the root directory of this source tree. #include "src/torchcodec/decoders/_core/VideoDecoder.h" +#include +#include #include #include +#include #include #include -#include "torch/types.h" + +#ifdef ENABLE_CUDA +#include +#include +#include +#ifdef ENABLE_NVTX +#include +#endif +#endif extern "C" { #include @@ -18,6 +29,9 @@ extern "C" { #include #include #include +#ifdef ENABLE_CUDA +#include +#endif } #include "src/torchcodec/decoders/_core/FFMPEGCommon.h" @@ -95,6 +109,87 @@ std::vector splitStringWithDelimiters( return result; } +#ifdef ENABLE_CUDA + +AVBufferRef* getCudaContext() { + enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); + TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); + int err = 0; + AVBufferRef* hw_device_ctx; + err = av_hwdevice_ctx_create( + &hw_device_ctx, + type, + nullptr, + nullptr, + // Introduced in 58.26.100: + // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265 +#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) + AV_CUDA_USE_CURRENT_CONTEXT +#else + 0 +#endif + ); + if (err < 0) { + TORCH_CHECK( + false, + "Failed to create specified HW device", + getFFMPEGErrorStringFromErrorCode(err)); + } + return hw_device_ctx; +} + +torch::Tensor allocateDeviceTensor( + at::IntArrayRef shape, + torch::Device device, + const torch::Dtype dtype = torch::kUInt8) { + return torch::empty( + shape, + torch::TensorOptions() + .dtype(dtype) + .layout(torch::kStrided) + .device(device)); +} + +torch::Tensor convertFrameToTensorUsingCUDA( + const AVCodecContext* codecContext, + const VideoDecoder::VideoStreamDecoderOptions& options, + const AVFrame* src) { + TORCH_CHECK( + src->format == AV_PIX_FMT_CUDA, + "Expected format to be AV_PIX_FMT_CUDA, got " + + std::string(av_get_pix_fmt_name((AVPixelFormat)src->format))); + int width = options.width.value_or(codecContext->width); + int height = options.height.value_or(codecContext->height); + NppStatus status; + NppiSize oSizeROI; + oSizeROI.width = width; + oSizeROI.height = height; + Npp8u* input[2]; + input[0] = (Npp8u*)src->data[0]; + input[1] = (Npp8u*)src->data[1]; + torch::Tensor dst = allocateDeviceTensor({height, width, 3}, options.device); + auto start = std::chrono::high_resolution_clock::now(); + status = nppiNV12ToRGB_8u_P2C3R( + input, + src->linesize[0], + static_cast(dst.data_ptr()), + dst.stride(0), + oSizeROI); + TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration duration = end - start; + VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width + << " took: " << duration.count() << "us" << std::endl; 
+ if (options.dimensionOrder == "NCHW") { + // The docs guarantee this returns a view: + // https://pytorch.org/docs/stable/generated/torch.permute.html + dst = dst.permute({2, 0, 1}); + } + return dst; +} + +#endif + } // namespace VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions( @@ -340,13 +435,13 @@ void VideoDecoder::initializeFilterGraphForStream( inputs.reset(inputsTmp); if (ffmpegStatus < 0) { throw std::runtime_error( - "Failed to parse filter description: " + - getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); + "Failed to parse filter description: " + std::string(description) + + "; " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); } ffmpegStatus = avfilter_graph_config(filterState.filterGraph.get(), nullptr); if (ffmpegStatus < 0) { throw std::runtime_error( - "Failed to configure filter graph: " + + "Failed to configure filter graph: " + std::string(description) + "; " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); } } @@ -395,15 +490,37 @@ void VideoDecoder::addVideoStreamDecoder( int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); TORCH_CHECK_EQ(retVal, AVSUCCESS); + + if (options.device.type() == torch::DeviceType::CUDA) { +#ifdef ENABLE_CUDA + // We create a small tensor using PyTorch to initialize the CUDA context. + torch::Tensor dummyTensorForCudaInitialization = torch::zeros( + {1}, + torch::TensorOptions().dtype(torch::kUInt8).device(options.device)); + codecContext->hw_device_ctx = av_buffer_ref(getCudaContext()); + + TORCH_INTERNAL_ASSERT( + codecContext->hw_device_ctx, + "Failed to create/reference the CUDA HW device context for index=" + + std::to_string(options.device.index()) + "."); +#else + throw std::runtime_error( + "CUDA support is not enabled in this build of TorchCodec."); +#endif + } + retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); if (retVal < AVSUCCESS) { throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal)); } + codecContext->time_base = streamInfo.stream->time_base; activeStreamIndices_.insert(streamNumber); updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext); streamInfo.options = options; - initializeFilterGraphForStream(streamNumber, options); + if (options.device.is_cpu()) { + initializeFilterGraphForStream(streamNumber, options); + } } void VideoDecoder::updateMetadataWithCodecContext( @@ -632,6 +749,9 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { VideoDecoder::DecodedOutput VideoDecoder::getDecodedOutputWithFilter( std::function<bool(int)> filterFunction) { +#ifdef ENABLE_NVTX + nvtx3::scoped_range loop{"decodeOneFrame"}; +#endif if (activeStreamIndices_.size() == 0) { throw std::runtime_error("No active streams configured."); } @@ -719,8 +839,13 @@ VideoDecoder::DecodedOutput VideoDecoder::getDecodedOutputWithFilter( // This packet is not for any of the active streams. continue; } - ffmpegStatus = avcodec_send_packet( - streams_[packet->stream_index].codecContext.get(), packet.get()); + { +#ifdef ENABLE_NVTX + nvtx3::scoped_range loop{"avcodec_send_packet"}; +#endif + ffmpegStatus = avcodec_send_packet( + streams_[packet->stream_index].codecContext.get(), packet.get()); + } decodeStats_.numPacketsSentToDecoder++; if (ffmpegStatus < AVSUCCESS) { throw std::runtime_error( @@ -757,8 +882,9 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( UniqueAVFrame frame) { // Convert the frame to tensor.
DecodedOutput output; + auto& streamInfo = streams_[streamIndex]; output.streamIndex = streamIndex; - output.streamType = streams_[streamIndex].stream->codecpar->codec_type; + output.streamType = streamInfo.stream->codecpar->codec_type; output.pts = frame->pts; output.ptsSeconds = ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base); @@ -766,8 +892,22 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( output.durationSeconds = ptsToSeconds( getDuration(frame), formatContext_->streams[streamIndex]->time_base); if (output.streamType == AVMEDIA_TYPE_VIDEO) { - output.frame = - convertFrameToTensorUsingFilterGraph(streamIndex, frame.get()); + if (streamInfo.options.device.is_cpu()) { + output.frame = + convertFrameToTensorUsingFilterGraph(streamIndex, frame.get()); + } else if (streamInfo.options.device.is_cuda()) { +#ifdef ENABLE_CUDA + { +#ifdef ENABLE_NVTX + nvtx3::scoped_range loop{"convertFrameUsingCuda"}; +#endif + output.frame = convertFrameToTensorUsingCUDA( + streamInfo.codecContext.get(), streamInfo.options, frame.get()); + } +#else + throw std::runtime_error("CUDA is not enabled in this build."); +#endif // ENABLE_CUDA + } } else if (output.streamType == AVMEDIA_TYPE_AUDIO) { // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement // audio decoding. diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 41a1dd27..f864dc0f 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -12,6 +12,7 @@ #include #include +#include "c10/core/Device.h" #include "src/torchcodec/decoders/_core/FFMPEGCommon.h" namespace facebook::torchcodec { @@ -139,6 +140,8 @@ class VideoDecoder { // is the same as the original video. std::optional<int> width; std::optional<int> height; + // Set the device to torch::kCUDA for GPU decoding. + torch::Device device = torch::kCPU; }; struct AudioStreamDecoderOptions {}; void addVideoStreamDecoder( @@ -253,6 +256,8 @@ class VideoDecoder { FilterState filterState; std::vector keyFrames; std::vector allFrames; + AVPixelFormat hwPixelFormat = AV_PIX_FMT_NONE; + UniqueAVBufferRef hwDeviceContext; }; VideoDecoder(); // Returns the key frame index of the presentation timestamp using FFMPEG's diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 6f05299f..5f8ced3d 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -29,7 +29,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def("create_from_file(str filename) -> Tensor"); m.def("create_from_tensor(Tensor video_tensor) -> Tensor"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device_string=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); m.def("get_next_frame(Tensor(a!)
decoder) -> (Tensor, Tensor, Tensor)"); m.def( @@ -113,7 +113,8 @@ void add_video_stream( std::optional height, std::optional num_threads, std::optional dimension_order, - std::optional stream_index) { + std::optional stream_index, + std::optional device_string) { VideoDecoder::VideoStreamDecoderOptions options; options.width = width; options.height = height; @@ -124,6 +125,10 @@ void add_video_stream( TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW"); options.dimensionOrder = stdDimensionOrder; } + if (device_string.has_value()) { + std::string deviceString{device_string.value()}; + options.device = torch::Device(deviceString); + } auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStreamDecoder(stream_index.value_or(-1), options); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 1da65f6d..f79a1b04 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -35,7 +35,8 @@ void add_video_stream( std::optional height = std::nullopt, std::optional num_threads = std::nullopt, std::optional dimension_order = std::nullopt, - std::optional stream_index = std::nullopt); + std::optional stream_index = std::nullopt, + std::optional device_string = std::nullopt); // Seek to a particular presentation timestamp in the video in seconds. void seek_to_pts(at::Tensor& decoder, double seconds); diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 6a22c098..b7d9b60d 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -114,6 +114,7 @@ def add_video_stream_abstract( num_threads: Optional[int] = None, dimension_order: Optional[str] = None, stream_index: Optional[int] = None, + device_string: Optional[str] = None, ) -> None: return diff --git a/test/decoders/CMakeLists.txt b/test/decoders/CMakeLists.txt index 21791dde..a5a26704 100644 --- a/test/decoders/CMakeLists.txt +++ b/test/decoders/CMakeLists.txt @@ -26,6 +26,10 @@ add_executable( VideoDecoderOpsTest.cpp ) +if(ENABLE_CUDA) + target_compile_definitions(VideoDecoderTest PRIVATE ENABLE_CUDA=1) +endif() + target_include_directories(VideoDecoderTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) target_include_directories(VideoDecoderTest PRIVATE ../../) target_include_directories(VideoDecoderOpsTest SYSTEM PRIVATE ${TORCH_INCLUDE_DIRS}) diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 057148b3..1fe19316 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -17,6 +17,10 @@ #include "tools/cxx/Resources.h" #endif +#ifdef ENABLE_CUDA +#include +#endif + using namespace ::testing; C10_DEFINE_bool( @@ -201,6 +205,54 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) { } } +#ifdef ENABLE_CUDA +TEST(GPUVideoDecoderTest, ReturnsFirstTwoFramesOfVideo) { + if (!torch::cuda::is_available()) { + return; + } + at::cuda::getDefaultCUDAStream(); + std::string path = getResourcePath("nasa_13013.mp4"); + std::unique_ptr ourDecoder = + VideoDecoder::createFromFilePath(path); + VideoDecoder::VideoStreamDecoderOptions streamOptions; + streamOptions.device = torch::Device("cuda"); + ASSERT_TRUE(streamOptions.device.is_cuda()); + ASSERT_EQ(streamOptions.device.type(), torch::DeviceType::CUDA); + ourDecoder->addVideoStreamDecoder(-1, streamOptions); + auto output = 
ourDecoder->getNextDecodedOutput(); + torch::Tensor tensor1FromOurDecoder = output.frame; + EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({3, 270, 480})); + EXPECT_EQ(output.ptsSeconds, 0.0); + EXPECT_EQ(output.pts, 0); + output = ourDecoder->getNextDecodedOutput(); + torch::Tensor tensor2FromOurDecoder = output.frame; + EXPECT_EQ(tensor2FromOurDecoder.sizes(), std::vector<long>({3, 270, 480})); + EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000); + EXPECT_EQ(output.pts, 1001); + + torch::Tensor tensor1FromFFMPEG = + readTensorFromDisk("nasa_13013.mp4.frame000001.cuda.pt"); + torch::Tensor tensor2FromFFMPEG = + readTensorFromDisk("nasa_13013.mp4.frame000002.cuda.pt"); + + EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector<long>({3, 270, 480})); + EXPECT_EQ(tensor2FromFFMPEG.sizes(), std::vector<long>({3, 270, 480})); + EXPECT_EQ(tensor1FromOurDecoder.device().type(), torch::DeviceType::CUDA); + EXPECT_EQ(tensor2FromOurDecoder.device().type(), torch::DeviceType::CUDA); + torch::Tensor tensor1FromOurDecoderCPU = tensor1FromOurDecoder.cpu(); + torch::Tensor tensor2FromOurDecoderCPU = tensor2FromOurDecoder.cpu(); + EXPECT_TRUE(torch::equal(tensor1FromOurDecoderCPU, tensor1FromFFMPEG)); + EXPECT_TRUE(torch::equal(tensor2FromOurDecoderCPU, tensor2FromFFMPEG)); + + if (FLAGS_dump_frames_for_debugging) { + dumpTensorToDisk(tensor1FromFFMPEG, "tensor1FromFFMPEG.pt"); + dumpTensorToDisk(tensor2FromFFMPEG, "tensor2FromFFMPEG.pt"); + dumpTensorToDisk(tensor1FromOurDecoderCPU, "tensor1FromOurDecoder.pt"); + dumpTensorToDisk(tensor2FromOurDecoderCPU, "tensor2FromOurDecoder.pt"); + } +} +#endif + TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr<VideoDecoder> ourDecoder = diff --git a/test/decoders/manual_smoke_test.py b/test/decoders/manual_smoke_test.py index 7351155c..389aa5f4 100644 --- a/test/decoders/manual_smoke_test.py +++ b/test/decoders/manual_smoke_test.py @@ -4,17 +4,36 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree.
+import argparse import os import torchcodec from torchvision.io.image import write_png -decoder = torchcodec.decoders._core.create_from_file( - os.path.dirname(__file__) + "/../resources/nasa_13013.mp4" -) -torchcodec.decoders._core.scan_all_streams_to_update_metadata(decoder) -torchcodec.decoders._core.add_video_stream(decoder, stream_index=3) -frame, _, _ = torchcodec.decoders._core.get_frame_at_index( - decoder, stream_index=3, frame_index=180 -) -write_png(frame, "frame180.png") + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--device", default="cpu", type=str, help="Specify 'cuda:0' for CUDA decoding" + ) + args = parser.parse_args() + + decoder = torchcodec.decoders._core.create_from_file( + os.path.dirname(__file__) + "/../resources/nasa_13013.mp4" + ) + torchcodec.decoders._core.scan_all_streams_to_update_metadata(decoder) + torchcodec.decoders._core.add_video_stream( + decoder, stream_index=3, device_string=args.device + ) + frame, _, _ = torchcodec.decoders._core.get_frame_at_index( + decoder, stream_index=3, frame_index=180 + ) + if "cuda" in args.device: + output_name = "frame180.cuda.png" + else: + output_name = "frame180.cpu.png" + write_png(frame.cpu(), output_name) + + +if __name__ == "__main__": + main() diff --git a/test/generate_reference_resources.sh b/test/generate_reference_resources.sh index be77ffae..7f998509 100755 --- a/test/generate_reference_resources.sh +++ b/test/generate_reference_resources.sh @@ -40,6 +40,8 @@ ffmpeg -y -ss 12.979633 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time12.979633. # Audio generation in the form of an mp3. ffmpeg -y -i "$VIDEO_PATH" -b:a 192K -vn "$VIDEO_PATH.audio.mp3" +# TODO: Add frames decoded by Nvidia's NVDEC. + for bmp in "$RESOURCES_DIR"/*.bmp do python3 convert_image_to_tensor.py "$bmp" diff --git a/test/resources/nasa_13013.mp4.frame000001.cuda.pt b/test/resources/nasa_13013.mp4.frame000001.cuda.pt new file mode 100644 index 00000000..17c59fd4 Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000001.cuda.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000002.cuda.pt b/test/resources/nasa_13013.mp4.frame000002.cuda.pt new file mode 100644 index 00000000..17c59fd4 Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000002.cuda.pt differ
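For a quick end-to-end check of the new device_string argument, here is a minimal Python sketch that mirrors the calls made in gpu_benchmark.py and manual_smoke_test.py from this diff; it assumes a build with ENABLE_CUDA=1, a CUDA-capable machine, and uses the same test resource as the benchmark.

import torchcodec

# Create a decoder and attach video stream 0 on the GPU (NVDEC decode,
# NPP NV12 -> RGB conversion). num_threads=1 matches the benchmark's CUDA setup.
decoder = torchcodec.decoders._core.create_from_file("test/resources/nasa_13013.mp4")
torchcodec.decoders._core.add_video_stream(
    decoder, stream_index=0, device_string="cuda:0", num_threads=1
)

# Frames come back as CUDA tensors; move them to the CPU before saving or comparing.
frame, pts, duration = torchcodec.decoders._core.get_next_frame(decoder)
print(frame.shape, frame.device)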