diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index bf3f9145..0494feec 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -43,11 +43,11 @@ class VideoDecoder: cheap no-copy operation that allows these frames to be transformed using the `torchvision transforms `_. - num_ffmpeg_threads (int or None, optional): The number of threads to use for decoding. + num_ffmpeg_threads (int or str, optional): The number of threads to use for decoding. Use 1 for single-threaded decoding which may be best if you are running multiple instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded decoding which is best if you are running a single instance of ``VideoDecoder``. - ``None`` is equivalent to passing 0 and lets FFmpeg automatically decide. + "auto" is equivalent to passing 0 and lets FFmpeg automatically decide. Default: 1. device (str or torch.device, optional): The device to use for decoding. Default: "cpu". @@ -65,7 +65,7 @@ def __init__( *, stream_index: Optional[int] = None, dimension_order: Literal["NCHW", "NHWC"] = "NCHW", - num_ffmpeg_threads: Optional[int] = 1, + num_ffmpeg_threads: Union[int, str] = 1, device: Optional[Union[str, device]] = "cpu", ): if isinstance(source, str): @@ -89,6 +89,14 @@ def __init__( f"Supported values are {', '.join(allowed_dimension_orders)}." ) + if isinstance(num_ffmpeg_threads, str): + if num_ffmpeg_threads != "auto": + raise ValueError( + "num_ffmpeg_threads should be 'auto' when it's a string. " + f"Got {num_ffmpeg_threads}" + ) + num_ffmpeg_threads = 0 + core.scan_all_streams_to_update_metadata(self._decoder) core.add_video_stream( self._decoder,