diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 690c3a19601..295974c3d20 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -72,6 +72,8 @@
       title: Semantic segmentation
     - local: object_detection
       title: Object detection
+    - local: video_load
+      title: Load video data
     title: "Vision"
 - sections:
   - local: nlp_load
diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx
index c08d555cd29..402e291be38 100644
--- a/docs/source/package_reference/main_classes.mdx
+++ b/docs/source/package_reference/main_classes.mdx
@@ -245,6 +245,10 @@ Dictionary with split names as keys ('train', 'test' for example), and `Iterable
 
 [[autodoc]] datasets.Image
 
+### Video
+
+[[autodoc]] datasets.Video
+
 ## Filesystems
 
 [[autodoc]] datasets.filesystems.is_remote_filesystem
diff --git a/docs/source/video_load.mdx b/docs/source/video_load.mdx
new file mode 100644
index 00000000000..b30f0f4aeff
--- /dev/null
+++ b/docs/source/video_load.mdx
@@ -0,0 +1,109 @@
+# Load video data
+
+<Tip warning={true}>
+
+Video support is experimental and is subject to change.
+
+</Tip>
+
+Video datasets have [`Video`] type columns, which contain `decord` objects.
+
+<Tip>
+
+To work with video datasets, you need to have the `decord` package installed. Check out the [installation](./installation) guide to learn how to install it.
+
+</Tip>
+
+When you load a video dataset and call the video column, the videos are decoded as `decord` `VideoReader` objects:
+
+```py
+>>> from datasets import load_dataset, Video
+
+>>> dataset = load_dataset("path/to/video/folder", split="train")
+>>> dataset[0]["video"]
+<decord.video_reader.VideoReader at 0x...>
+```
+
+<Tip warning={true}>
+
+Index into a video dataset using the row index first and then the `video` column - `dataset[0]["video"]` - to avoid decoding all the video objects in the dataset. Otherwise, this can be a slow and time-consuming process if you have a large dataset.
+
+</Tip>
+
+For a guide on how to load any type of dataset, take a look at the [general loading guide](./loading).
+
+## Local files
+
+You can load a dataset from the video path. Use the [`~Dataset.cast_column`] function to accept a column of video file paths, and decode it into a `decord` video with the [`Video`] feature:
+
+```py
+>>> from datasets import Dataset, Video
+
+>>> dataset = Dataset.from_dict({"video": ["path/to/video_1", "path/to/video_2", ..., "path/to/video_n"]}).cast_column("video", Video())
+>>> dataset[0]["video"]
+<decord.video_reader.VideoReader at 0x...>
+```
+
+If you only want to load the underlying path to the video dataset without decoding the video object, set `decode=False` in the [`Video`] feature:
+
+```py
+>>> dataset = dataset.cast_column("video", Video(decode=False))
+>>> dataset[0]["video"]
+{'bytes': None,
+ 'path': 'path/to/video/folder/video0.mp4'}
+```
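+Casting back to `Video(decode=True)` gives you a `decord.VideoReader` again, which supports random access to the frames. Here is a minimal sketch (assuming `decord` is installed and the paths above point to real video files); `__getitem__`, `get_batch`, and `get_avg_fps` are part of the decord `VideoReader` API:
+
+```py
+>>> dataset = dataset.cast_column("video", Video())  # re-enable decoding
+>>> video = dataset[0]["video"]
+>>> first_frame = video[0]               # one frame, an array of shape (height, width, channels)
+>>> clip = video.get_batch([0, 10, 20])  # several frames at once
+>>> fps = video.get_avg_fps()            # average frames per second
+```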
+## VideoFolder
+
+You can also load a dataset with a `VideoFolder` dataset builder which does not require writing a custom dataloader. This makes `VideoFolder` ideal for quickly creating and loading video datasets with several thousand videos for different vision tasks.
+Your video dataset structure should look like this:
+
+```
+folder/train/dog/golden_retriever.mp4
+folder/train/dog/german_shepherd.mp4
+folder/train/dog/chihuahua.mp4
+
+folder/train/cat/maine_coon.mp4
+folder/train/cat/bengal.mp4
+folder/train/cat/birman.mp4
+```
+
+Load your dataset by specifying `videofolder` and the directory of your dataset in `data_dir`:
+
+```py
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder")
+>>> dataset["train"][0]
+{"video": <decord.video_reader.VideoReader at 0x...>, "label": 0}
+
+>>> dataset["train"][-1]
+{"video": <decord.video_reader.VideoReader at 0x...>, "label": 1}
+```
+
+Load remote datasets from their URLs with the `data_files` parameter:
+
+```py
+>>> dataset = load_dataset("videofolder", data_files="https://foo.bar/videos.zip", split="train")
+```
+
+Some datasets have a metadata file (`metadata.csv`/`metadata.jsonl`) associated with them, containing other information about the data like bounding boxes, text captions, and labels. The metadata is automatically loaded when you call [`load_dataset`] and specify `videofolder`.
+
+To ignore the information in the metadata file, set `drop_labels=False` in [`load_dataset`], and allow `VideoFolder` to automatically infer the label name from the directory name:
+
+```py
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder", drop_labels=False)
+```
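+For illustration, a hypothetical `metadata.jsonl` placed in `folder/train/` could look like this; the `file_name` column is what links each row to a video file (relative to the metadata file), and any other column (here `text`) is loaded as an extra feature:
+
+```jsonl
+{"file_name": "dog/golden_retriever.mp4", "text": "A golden retriever fetching a ball"}
+{"file_name": "cat/maine_coon.mp4", "text": "A maine coon napping on a sofa"}
+```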
+## WebDataset
+
+The [WebDataset](https://github.com/webdataset/webdataset) format is based on a folder of TAR archives and is suitable for big video datasets.
+Because of their size, WebDatasets are generally loaded in streaming mode (using `streaming=True`).
+
+You can load a WebDataset like this:
+
+```python
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", streaming=True)
+```
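+A streaming dataset is an [`IterableDataset`], so you iterate over it instead of indexing into it. As a minimal sketch (assuming the TAR shards contain `.mp4` files, in which case the WebDataset builder exposes each video under an `mp4` field):
+
+```py
+>>> example = next(iter(dataset))  # fetch the first example from the stream
+>>> video = example["mp4"]         # decoded as a video
+```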
diff --git a/setup.py b/setup.py
index 9f926e4f7b6..2970ba00299 100644
--- a/setup.py
+++ b/setup.py
@@ -187,6 +187,7 @@
     "transformers>=4.42.0",  # Pins numpy < 2
     "zstandard",
     "polars[timezone]>=0.20.0",
+    "decord==0.6.0",
 ]
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index f65995f93d9..eb8fedce996 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -77,7 +77,7 @@
 from .arrow_writer import ArrowWriter, OptimizedTypedSequence
 from .data_files import sanitize_patterns
 from .download.streaming_download_manager import xgetsize
-from .features import Audio, ClassLabel, Features, Image, Sequence, Value
+from .features import Audio, ClassLabel, Features, Image, Sequence, Value, Video
 from .features.features import (
     FeatureType,
     _align_features,
@@ -1416,9 +1416,9 @@ def save_to_disk(
         """
         Saves a dataset to a dataset directory, or in a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.
 
-        For [`Image`] and [`Audio`] data:
+        For [`Image`], [`Audio`] and [`Video`] data:
 
-        All the Image() and Audio() data are stored in the arrow files.
+        All the Image(), Audio() and Video() data are stored in the arrow files.
         If you want to store paths or urls, please use the Value("string") type.
 
         Args:
@@ -5065,7 +5065,7 @@ def _estimate_nbytes(self) -> int:
 
         def extra_nbytes_visitor(array, feature):
             nonlocal extra_nbytes
-            if isinstance(feature, (Audio, Image)):
+            if isinstance(feature, (Audio, Image, Video)):
                 for x in array.to_pylist():
                     if x is not None and x["bytes"] is None and x["path"] is not None:
                         size = xgetsize(x["path"])
@@ -5249,15 +5249,16 @@ def _push_parquet_shards_to_hub(
         shards = (self.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards))
         if decodable_columns:
+            from .io.parquet import get_writer_batch_size
 
-            def shards_with_embedded_external_files(shards):
+            def shards_with_embedded_external_files(shards: Iterator[Dataset]) -> Iterator[Dataset]:
                 for shard in shards:
                     format = shard.format
                     shard = shard.with_format("arrow")
                     shard = shard.map(
                         embed_table_storage,
                         batched=True,
-                        batch_size=1000,
+                        batch_size=get_writer_batch_size(shard.features),
                         keep_in_memory=True,
                     )
                     shard = shard.with_format(**format)
@@ -5310,7 +5311,7 @@ def push_to_hub(
         """Pushes the dataset to the hub as a Parquet dataset.
         The dataset is pushed using HTTP requests and does not need to have neither git or git-lfs installed.
 
-        The resulting Parquet files are self-contained by default. If your dataset contains [`Image`] or [`Audio`]
+        The resulting Parquet files are self-contained by default. If your dataset contains [`Image`], [`Audio`] or [`Video`]
         data, the Parquet files will store the bytes of your images or audio files.
         You can disable this by setting `embed_external_files` to `False`.
diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py
index 3b9993736e4..763ad8b7a41 100644
--- a/src/datasets/arrow_writer.py
+++ b/src/datasets/arrow_writer.py
@@ -340,7 +340,7 @@ def __init__(
         self.fingerprint = fingerprint
         self.disable_nullable = disable_nullable
-        self.writer_batch_size = writer_batch_size or config.DEFAULT_MAX_BATCH_SIZE
+        self.writer_batch_size = writer_batch_size
         self.update_features = update_features
         self.with_metadata = with_metadata
         self.unit = unit
@@ -353,6 +353,11 @@ def __init__(
         self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None
         self.hkey_record = []
 
+        if self.writer_batch_size is None and self._features is not None:
+            from .io.parquet import get_writer_batch_size
+
+            self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
+
     def __len__(self):
         """Return the number of written and staged examples"""
         return self._num_examples + len(self.current_examples) + len(self.current_rows)
@@ -397,6 +402,10 @@ def _build_writer(self, inferred_schema: pa.Schema):
             schema = schema.with_metadata({})
         self._schema = schema
         self.pa_writer = self._WRITER_CLASS(self.stream, schema)
+        if self.writer_batch_size is None:
+            from .io.parquet import get_writer_batch_size
+
+            self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
 
     @property
     def schema(self):
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 7328b90cbca..c3eee41c6e0 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -930,7 +930,8 @@ def incomplete_dir(dirname):
             # Sync info
             self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values())
             self.info.download_checksums = dl_manager.get_recorded_sizes_checksums()
-            self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
+            if self.info.download_size is not None:
+                self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
 
             # Save info
             self._save_info()
diff --git a/src/datasets/config.py b/src/datasets/config.py
index e2de170bcba..e2de31b337a 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -129,6 +129,7 @@
 IS_MP3_SUPPORTED = importlib.util.find_spec("soundfile") is not None and version.parse(
     importlib.import_module("soundfile").__libsndfile_version__
 ) >= version.parse("1.1.0")
+DECORD_AVAILABLE = importlib.util.find_spec("decord") is not None
 
 # Optional compression tools
 RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None
@@ -192,6 +193,7 @@
 PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS = 100
 PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS = 100
 PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS = 100
+PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS = 10
 
 # Offline mode
 _offline = os.environ.get("HF_DATASETS_OFFLINE")
diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index f92a1a8afda..40ca0cd7312 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -1230,9 +1230,9 @@ def save_to_disk(
         """
         Saves a dataset dict to a filesystem using `fsspec.spec.AbstractFileSystem`.
 
-        For [`Image`] and [`Audio`] data:
+        For [`Image`], [`Audio`] and [`Video`] data:
 
-        All the Image() and Audio() data are stored in the arrow files.
+        All the Image(), Audio() and Video() data are stored in the arrow files.
         If you want to store paths or urls, please use the Value("string") type.
 
         Args:
diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py
index 9a1d1e3a53c..800f6821443 100644
--- a/src/datasets/download/streaming_download_manager.py
+++ b/src/datasets/download/streaming_download_manager.py
@@ -64,6 +64,8 @@ def __init__(
         self._data_dir = data_dir
         self._base_path = base_path or os.path.abspath(".")
         self.download_config = download_config or DownloadConfig()
+        self.downloaded_size = None
+        self.record_checksums = False
 
     @property
     def manual_dir(self):
@@ -208,3 +210,9 @@ def iter_files(self, urlpaths: Union[str, List[str]]) -> Iterable[str]:
         ```
         """
         return FilesIterable.from_urlpaths(urlpaths, download_config=self.download_config)
+
+    def manage_extracted_files(self):
+        pass
+
+    def get_recorded_sizes_checksums(self):
+        pass
diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py
index 35ebfb4ac0c..bf38042eb81 100644
--- a/src/datasets/features/__init__.py
+++ b/src/datasets/features/__init__.py
@@ -12,8 +12,10 @@
     "Image",
     "Translation",
     "TranslationVariableLanguages",
+    "Video",
 ]
 from .audio import Audio
 from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value
 from .image import Image
 from .translation import Translation, TranslationVariableLanguages
+from .video import Video
diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py
index 1d241e0b7b7..34622cd94d9 100644
--- a/src/datasets/features/features.py
+++ b/src/datasets/features/features.py
@@ -43,6 +43,7 @@
 from .audio import Audio
 from .image import Image, encode_pil_image
 from .translation import Translation, TranslationVariableLanguages
+from .video import Video
 
 
 logger = logging.get_logger(__name__)
@@ -1202,6 +1203,7 @@ class LargeList:
     Array5D,
     Audio,
     Image,
+    Video,
 ]
 
 
@@ -1346,7 +1348,7 @@ def encode_nested_example(schema, obj, level=0):
         return list(obj)
     # Object with special encoding:
     # ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
-    elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD)):
+    elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)):
         return schema.encode_example(obj) if obj is not None else None
     # Other object should be directly convertible to a native Arrow type (like Translation and Translation)
     return obj
@@ -1397,7 +1399,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
         else:
             return decode_nested_example([schema.feature], obj)
     # Object with special decoding:
-    elif isinstance(schema, (Audio, Image)):
+    elif isinstance(schema, (Audio, Image, Video)):
         # we pass the token to read and decode files from private repositories in streaming mode
         if obj is not None and schema.decode:
             return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
@@ -1417,6 +1419,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
     Array5D.__name__: Array5D,
     Audio.__name__: Audio,
     Image.__name__: Image,
+    Video.__name__: Video,
 }
diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py
new file mode 100644
index 00000000000..a9967f3be45
--- /dev/null
+++ b/src/datasets/features/video.py
@@ -0,0 +1,340 @@
+import os
+from dataclasses import dataclass, field
+from io import BytesIO
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union
+
+import numpy as np
+import pyarrow as pa
+
+from .. import config
+from ..download.download_config import DownloadConfig
+from ..table import array_cast
+from ..utils.file_utils import is_local_path, xopen
+from ..utils.py_utils import no_op_if_value_is_null, string_to_dict
+
+
+if TYPE_CHECKING:
+    from decord import VideoReader
+
+    from .features import FeatureType
+
+
+@dataclass
+class Video:
+    """Video [`Feature`] to read video data from a video file.
+
+    Input: The Video feature accepts as input:
+    - A `str`: Absolute path to the video file (i.e. random access is allowed).
+    - A `dict` with the keys:
+
+        - `path`: String with relative path of the video file in a dataset repository.
+        - `bytes`: Bytes of the video file.
+
+      This is useful for archived files with sequential access.
+
+    - An `np.ndarray`: NumPy array representing a video.
+    - A `decord.VideoReader`: decord video reader object.
+
+    Args:
+        decode (`bool`, defaults to `True`):
+            Whether to decode the video data. If `False`,
+            returns the underlying dictionary in the format `{"path": video_path, "bytes": video_bytes}`.
+
+    Examples:
+
+    ```py
+    >>> from datasets import Dataset, Video
+    >>> ds = Dataset.from_dict({"video":["path/to/Screen Recording.mov"]}).cast_column("video", Video())
+    >>> ds.features["video"]
+    Video(decode=True, id=None)
+    >>> ds[0]["video"]
+    <decord.video_reader.VideoReader at 0x...>
+    >>> ds = ds.cast_column('video', Video(decode=False))
+    >>> ds[0]["video"]
+    {'bytes': None,
+     'path': 'path/to/Screen Recording.mov'}
+    ```
+    """
+
+    decode: bool = True
+    id: Optional[str] = None
+    # Automatically constructed
+    dtype: ClassVar[str] = "decord.VideoReader"
+    pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()})
+    _type: str = field(default="Video", init=False, repr=False)
+
+    def __call__(self):
+        return self.pa_type
+
+    def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader"]) -> dict:
+        """Encode example into a format for Arrow.
+
+        Args:
+            value (`str`, `np.ndarray`, `VideoReader` or `dict`):
+                Data passed as input to Video feature.
+
+        Returns:
+            `dict` with "path" and "bytes" fields
+        """
+        if config.DECORD_AVAILABLE:
+            # We need to import torch first, otherwise later it can cause issues
+            # e.g. "RuntimeError: random_device could not be read"
+            # when running `torch.tensor(value).share_memory_()`
+            if config.TORCH_AVAILABLE:
+                import torch  # noqa
+            from decord import VideoReader
+        else:
+            raise ImportError("To support encoding videos, please install 'decord'.")
+
+        if isinstance(value, list):
+            value = np.array(value)
+
+        if isinstance(value, str):
+            return {"path": value, "bytes": None}
+        elif isinstance(value, bytes):
+            return {"path": None, "bytes": value}
+        elif isinstance(value, np.ndarray):
+            # convert the video array to bytes
+            return encode_np_array(value)
+        elif isinstance(value, VideoReader):
+            # convert the decord video reader to bytes
+            return encode_decord_video(value)
+        elif value.get("path") is not None and os.path.isfile(value["path"]):
+            # we set "bytes": None to not duplicate the data if they're already available locally
+            return {"bytes": None, "path": value.get("path")}
+        elif value.get("bytes") is not None or value.get("path") is not None:
+            # store the video bytes, and path is used to infer the video format using the file extension
+            return {"bytes": value.get("bytes"), "path": value.get("path")}
+        else:
+            raise ValueError(
+                f"A video sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
+            )
+
+    def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader":
+        """Decode example video file into video data.
+
+        Args:
+            value (`str` or `dict`):
+                A string with the absolute video file path or a dictionary with
+                keys:
+
+                - `path`: String with absolute or relative video file path.
+                - `bytes`: The bytes of the video file.
+            token_per_repo_id (`dict`, *optional*):
+                To access and decode
+                video files from private repositories on the Hub, you can pass
+                a dictionary repo_id (`str`) -> token (`bool` or `str`).
+
+        Returns:
+            `decord.VideoReader`
+        """
+        if not self.decode:
+            raise RuntimeError("Decoding is disabled for this feature. Please use Video(decode=True) instead.")
+
+        if config.DECORD_AVAILABLE:
+            # We need to import torch first, otherwise later it can cause issues
"RuntimeError: random_device could not be read" + # when running `torch.tensor(value).share_memory_()` + if config.TORCH_AVAILABLE: + import torch # noqa + from decord import VideoReader + else: + raise ImportError("To support decoding videos, please install 'decord'.") + + if token_per_repo_id is None: + token_per_repo_id = {} + + path, bytes_ = value["path"], value["bytes"] + if bytes_ is None: + if path is None: + raise ValueError(f"A video should have one of 'path' or 'bytes' but both are None in {value}.") + else: + if is_local_path(path): + video = VideoReader(path) + else: + source_url = path.split("::")[-1] + pattern = ( + config.HUB_DATASETS_URL + if source_url.startswith(config.HF_ENDPOINT) + else config.HUB_DATASETS_HFFS_URL + ) + try: + repo_id = string_to_dict(source_url, pattern)["repo_id"] + token = token_per_repo_id.get(repo_id) + except ValueError: + token = None + download_config = DownloadConfig(token=token) + with xopen(path, "rb", download_config=download_config) as f: + bytes_ = BytesIO(f.read()) + video = VideoReader(bytes_) + else: + video = VideoReader(BytesIO(bytes_)) + return video + + def flatten(self) -> Union["FeatureType", Dict[str, "FeatureType"]]: + """If in the decodable state, return the feature itself, otherwise flatten the feature into a dictionary.""" + from .features import Value + + return ( + self + if self.decode + else { + "bytes": Value("binary"), + "path": Value("string"), + } + ) + + def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArray]) -> pa.StructArray: + """Cast an Arrow array to the Video arrow storage type. + The Arrow types that can be converted to the Video pyarrow storage type are: + + - `pa.string()` - it must contain the "path" data + - `pa.binary()` - it must contain the video bytes + - `pa.struct({"bytes": pa.binary()})` + - `pa.struct({"path": pa.string()})` + - `pa.struct({"bytes": pa.binary(), "path": pa.string()})` - order doesn't matter + - `pa.list(*)` - it must contain the video array data + + Args: + storage (`Union[pa.StringArray, pa.StructArray, pa.ListArray]`): + PyArrow array to cast. + + Returns: + `pa.StructArray`: Array in the Video arrow storage type, that is + `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. 
+ """ + if pa.types.is_string(storage.type): + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_binary(storage.type): + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_struct(storage.type): + if storage.type.get_field_index("bytes") >= 0: + bytes_array = storage.field("bytes") + else: + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + if storage.type.get_field_index("path") >= 0: + path_array = storage.field("path") + else: + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_list(storage.type): + bytes_array = pa.array( + [encode_np_array(np.array(arr))["bytes"] if arr is not None else None for arr in storage.to_pylist()], + type=pa.binary(), + ) + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays( + [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null() + ) + return array_cast(storage, self.pa_type) + + def embed_storage(self, storage: pa.StructArray) -> pa.StructArray: + """Embed video files into the Arrow array. + + Args: + storage (`pa.StructArray`): + PyArrow array to embed. + + Returns: + `pa.StructArray`: Array in the Video arrow storage type, that is + `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. + """ + + @no_op_if_value_is_null + def path_to_bytes(path): + with xopen(path, "rb") as f: + bytes_ = f.read() + return bytes_ + + bytes_array = pa.array( + [ + (path_to_bytes(x["path"]) if x["bytes"] is None else x["bytes"]) if x is not None else None + for x in storage.to_pylist() + ], + type=pa.binary(), + ) + path_array = pa.array( + [ + (os.path.basename(path) if os.path.isfile(path) else path) if path is not None else None + for path in storage.field("path").to_pylist() + ], + type=pa.string(), + ) + storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()) + return array_cast(storage, self.pa_type) + + +def video_to_bytes(video: "VideoReader") -> bytes: + """Convert a decord Video object to bytes using native compression if possible""" + raise NotImplementedError() + + +def encode_decord_video(video: "VideoReader") -> dict: + if hasattr(video, "_hf_encoded"): + return video._hf_encoded + else: + raise NotImplementedError( + "Encoding a decord video is not implemented. " + "Please call `datasets.features.video.patch_decord()` before loading videos to enable this." + ) + + +def encode_np_array(array: np.ndarray) -> dict: + raise NotImplementedError() + + +# Patching decord a little bit to: +# 1. store the encoded video data {"path": ..., "bytes": ...} in `video._hf_encoded`` +# 2. set the decord bridge to numpy/torch/tf/jax using `video._hf_bridge_out` (per video instance) instead of decord.bridge.bridge_out (global) +# This doesn't affect the normal usage of decord. 
+
+
+def _patched_init(self: "VideoReader", uri: Union[str, BytesIO], *args, **kwargs) -> None:
+    from decord.bridge import bridge_out
+
+    if hasattr(uri, "read"):
+        self._hf_encoded = {"bytes": uri.read(), "path": None}
+        uri.seek(0)
+    elif isinstance(uri, str):
+        self._hf_encoded = {"bytes": None, "path": uri}
+    self._hf_bridge_out = bridge_out
+    self._original_init(uri, *args, **kwargs)
+
+
+def _patched_next(self: "VideoReader", *args, **kwargs):
+    return self._hf_bridge_out(self._original_next(*args, **kwargs))
+
+
+def _patched_get_batch(self: "VideoReader", *args, **kwargs):
+    return self._hf_bridge_out(self._original_get_batch(*args, **kwargs))
+
+
+def patch_decord():
+    if config.DECORD_AVAILABLE:
+        # We need to import torch first, otherwise later it can cause issues
+        # e.g. "RuntimeError: random_device could not be read"
+        # when running `torch.tensor(value).share_memory_()`
+        if config.TORCH_AVAILABLE:
+            import torch  # noqa
+        import decord.video_reader
+        from decord import VideoReader
+
+        if not hasattr(VideoReader, "_hf_patched"):
+            decord.video_reader.bridge_out = lambda x: x
+            VideoReader._original_init = VideoReader.__init__
+            VideoReader.__init__ = _patched_init
+            VideoReader._original_next = VideoReader.next
+            VideoReader.next = _patched_next
+            VideoReader._original_get_batch = VideoReader.get_batch
+            VideoReader.get_batch = _patched_get_batch
+            VideoReader._hf_patched = True
+    else:
+        raise ImportError("To support decoding videos, please install 'decord'.")
+
+
+if config.DECORD_AVAILABLE:
+    patch_decord()
diff --git a/src/datasets/formatting/jax_formatter.py b/src/datasets/formatting/jax_formatter.py
index 8035341c5cd..e247b7b5822 100644
--- a/src/datasets/formatting/jax_formatter.py
+++ b/src/datasets/formatting/jax_formatter.py
@@ -100,11 +100,23 @@ def _tensorize(self, value):
             default_dtype = {"dtype": jnp.int32}
         elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
             default_dtype = {"dtype": jnp.float32}
-        elif config.PIL_AVAILABLE and "PIL" in sys.modules:
+
+        if config.PIL_AVAILABLE and "PIL" in sys.modules:
             import PIL.Image
 
             if isinstance(value, PIL.Image.Image):
                 value = np.asarray(value)
+        if config.DECORD_AVAILABLE and "decord" in sys.modules:
+            # We need to import torch first, otherwise later it can cause issues
+            # e.g. "RuntimeError: random_device could not be read"
+            # when running `torch.tensor(value).share_memory_()`
+            if config.TORCH_AVAILABLE:
+                import torch  # noqa
+            from decord import VideoReader
+
+            if isinstance(value, VideoReader):
+                value._hf_bridge_out = lambda x: jnp.array(np.asarray(x))
+                return value
 
         # using global variable since `jaxlib.xla_extension.Device` is not serializable neither
         # with `pickle` nor with `dill`, so we need to use a global variable instead
"RuntimeError: random_device could not be read" + # when running `torch.tensor(value).share_memory_()` + if config.TORCH_AVAILABLE: + import torch # noqa + from decord import VideoReader + + if isinstance(value, VideoReader): + value._hf_bridge_out = lambda x: jnp.array(np.asarray(x)) + return value # using global variable since `jaxlib.xla_extension.Device` is not serializable neither # with `pickle` nor with `dill`, so we need to use a global variable instead diff --git a/src/datasets/formatting/np_formatter.py b/src/datasets/formatting/np_formatter.py index 95bcff2b517..032758bce21 100644 --- a/src/datasets/formatting/np_formatter.py +++ b/src/datasets/formatting/np_formatter.py @@ -57,11 +57,23 @@ def _tensorize(self, value): default_dtype = {"dtype": np.int64} elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.floating): default_dtype = {"dtype": np.float32} - elif config.PIL_AVAILABLE and "PIL" in sys.modules: + + if config.PIL_AVAILABLE and "PIL" in sys.modules: import PIL.Image if isinstance(value, PIL.Image.Image): return np.asarray(value, **self.np_array_kwargs) + if config.DECORD_AVAILABLE and "decord" in sys.modules: + # We need to import torch first, otherwise later it can cause issues + # e.g. "RuntimeError: random_device could not be read" + # when running `torch.tensor(value).share_memory_()` + if config.TORCH_AVAILABLE: + import torch # noqa + from decord import VideoReader + + if isinstance(value, VideoReader): + value._hf_bridge_out = np.asarray + return value return np.asarray(value, **{**default_dtype, **self.np_array_kwargs}) diff --git a/src/datasets/formatting/tf_formatter.py b/src/datasets/formatting/tf_formatter.py index adb15cda381..9f0c06ec82a 100644 --- a/src/datasets/formatting/tf_formatter.py +++ b/src/datasets/formatting/tf_formatter.py @@ -64,11 +64,24 @@ def _tensorize(self, value): default_dtype = {"dtype": tf.int64} elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating): default_dtype = {"dtype": tf.float32} - elif config.PIL_AVAILABLE and "PIL" in sys.modules: + + if config.PIL_AVAILABLE and "PIL" in sys.modules: import PIL.Image if isinstance(value, PIL.Image.Image): value = np.asarray(value) + if config.DECORD_AVAILABLE and "decord" in sys.modules: + # We need to import torch first, otherwise later it can cause issues + # e.g. 
"RuntimeError: random_device could not be read" + # when running `torch.tensor(value).share_memory_()` + if config.TORCH_AVAILABLE: + import torch # noqa + from decord import VideoReader + from decord.bridge import to_tensorflow + + if isinstance(value, VideoReader): + value._hf_bridge_out = to_tensorflow + return value return tf.convert_to_tensor(value, **{**default_dtype, **self.tf_tensor_kwargs}) diff --git a/src/datasets/formatting/torch_formatter.py b/src/datasets/formatting/torch_formatter.py index 8efe759a144..051badb0ac4 100644 --- a/src/datasets/formatting/torch_formatter.py +++ b/src/datasets/formatting/torch_formatter.py @@ -66,7 +66,8 @@ def _tensorize(self, value): elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating): default_dtype = {"dtype": torch.float32} - elif config.PIL_AVAILABLE and "PIL" in sys.modules: + + if config.PIL_AVAILABLE and "PIL" in sys.modules: import PIL.Image if isinstance(value, PIL.Image.Image): @@ -75,6 +76,14 @@ def _tensorize(self, value): value = value[:, :, np.newaxis] value = value.transpose((2, 0, 1)) + if config.DECORD_AVAILABLE and "decord" in sys.modules: + from decord import VideoReader + from decord.bridge import to_torch + + if isinstance(value, VideoReader): + value._hf_bridge_out = to_torch + return value + return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs}) def _recursive_tensorize(self, data_struct): diff --git a/src/datasets/io/parquet.py b/src/datasets/io/parquet.py index 51434106fb1..289bd1adfdc 100644 --- a/src/datasets/io/parquet.py +++ b/src/datasets/io/parquet.py @@ -5,7 +5,7 @@ import numpy as np import pyarrow.parquet as pq -from .. import Audio, Dataset, Features, Image, NamedSplit, Value, config +from .. import Audio, Dataset, Features, Image, NamedSplit, Value, Video, config from ..features.features import FeatureType, _visit from ..formatting import query_table from ..packaged_modules import _PACKAGED_DATASETS_MODULES @@ -42,6 +42,8 @@ def set_batch_size(feature: FeatureType) -> None: batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS) elif isinstance(feature, Audio): batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS) + elif isinstance(feature, Video): + batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS) elif isinstance(feature, Value) and feature.dtype == "binary": batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS) diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index 7598d0213ac..94d806edf45 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -14,6 +14,7 @@ from .parquet import parquet from .sql import sql from .text import text +from .videofolder import videofolder from .webdataset import webdataset from .xml import xml @@ -41,6 +42,7 @@ def _hash_python_lines(lines: List[str]) -> str: "text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())), "imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())), "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())), + "videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())), "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())), "xml": (xml.__name__, 
     "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
 }
 
@@ -77,7 +79,9 @@ def _hash_python_lines(lines: List[str]) -> str:
 _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext: ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext.upper(): ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
-_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder"}
+_EXTENSION_TO_MODULE.update({ext: ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
+_EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
+_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder", "videofolder"}
 
 # Used to filter data files based on extensions given a module name
 _MODULE_TO_EXTENSIONS: Dict[str, List[str]] = {}
diff --git a/src/datasets/packaged_modules/videofolder/__init__.py b/src/datasets/packaged_modules/videofolder/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/videofolder/videofolder.py b/src/datasets/packaged_modules/videofolder/videofolder.py
new file mode 100644
index 00000000000..7ce5bcf5655
--- /dev/null
+++ b/src/datasets/packaged_modules/videofolder/videofolder.py
@@ -0,0 +1,36 @@
+from typing import List
+
+import datasets
+
+from ..folder_based_builder import folder_based_builder
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+class VideoFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
+    """BuilderConfig for VideoFolder."""
+
+    drop_labels: bool = None
+    drop_metadata: bool = None
+
+    def __post_init__(self):
+        super().__post_init__()
+
+
+class VideoFolder(folder_based_builder.FolderBasedBuilder):
+    BASE_FEATURE = datasets.Video
+    BASE_COLUMN_NAME = "video"
+    BUILDER_CONFIG_CLASS = VideoFolderConfig
+    EXTENSIONS: List[str]  # definition at the bottom of the script
+
+
+# TODO: initial list, we should check the compatibility of other formats
+VIDEO_EXTENSIONS = [
+    ".mkv",
+    ".mp4",
+    ".avi",
+    ".mpeg",
+    ".mov",
+]
+VideoFolder.EXTENSIONS = VIDEO_EXTENSIONS
diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py
index c04f3ba4639..0768437b36a 100644
--- a/src/datasets/packaged_modules/webdataset/webdataset.py
+++ b/src/datasets/packaged_modules/webdataset/webdataset.py
@@ -20,6 +20,7 @@ class WebDataset(datasets.GeneratorBasedBuilder):
     DEFAULT_WRITER_BATCH_SIZE = 100
     IMAGE_EXTENSIONS: List[str]  # definition at the bottom of the script
     AUDIO_EXTENSIONS: List[str]  # definition at the bottom of the script
+    VIDEO_EXTENSIONS: List[str]  # definition at the bottom of the script
     DECODERS: Dict[str, Callable[[Any], Any]]  # definition at the bottom of the script
     NUM_EXAMPLES_FOR_FEATURES_INFERENCE = 5
 
@@ -97,6 +98,11 @@ def _split_generators(self, dl_manager):
                 extension = field_name.rsplit(".", 1)[-1]
                 if extension in self.AUDIO_EXTENSIONS:
                     features[field_name] = datasets.Audio()
+            # Set Video types
+            for field_name in first_examples[0]:
+                extension = field_name.rsplit(".", 1)[-1]
+                if extension in self.VIDEO_EXTENSIONS:
+                    features[field_name] = datasets.Video()
             self.info.features = features
 
         return splits
@@ -259,6 +265,17 @@ def base_plus_ext(path):
 WebDataset.AUDIO_EXTENSIONS = AUDIO_EXTENSIONS
 
 
+# TODO: initial list, we should check the compatibility of other formats
+VIDEO_EXTENSIONS = [
+    ".mkv",
+    ".mp4",
+    ".avi",
+    ".mpeg",
+    ".mov",
+]
+WebDataset.VIDEO_EXTENSIONS = VIDEO_EXTENSIONS
+
+
 def text_loads(data: bytes):
     return data.decode("utf-8")
diff --git a/tests/features/data/test_video_66x50.mov b/tests/features/data/test_video_66x50.mov
new file mode 100644
index 00000000000..a55dcaa8f7b
Binary files /dev/null and b/tests/features/data/test_video_66x50.mov differ
diff --git a/tests/features/test_video.py b/tests/features/test_video.py
new file mode 100644
index 00000000000..f4c9a8d830b
--- /dev/null
+++ b/tests/features/test_video.py
@@ -0,0 +1,92 @@
+import pytest
+
+from datasets import Dataset, Features, Video
+
+from ..utils import require_decord
+
+
+@require_decord
+@pytest.mark.parametrize(
+    "build_example",
+    [
+        lambda video_path: video_path,
+        lambda video_path: open(video_path, "rb").read(),
+        lambda video_path: {"path": video_path},
+        lambda video_path: {"path": video_path, "bytes": None},
+        lambda video_path: {"path": video_path, "bytes": open(video_path, "rb").read()},
+        lambda video_path: {"path": None, "bytes": open(video_path, "rb").read()},
+        lambda video_path: {"bytes": open(video_path, "rb").read()},
+    ],
+)
+def test_video_feature_encode_example(shared_datadir, build_example):
+    from decord import VideoReader
+
+    video_path = str(shared_datadir / "test_video_66x50.mov")
+    video = Video()
+    encoded_example = video.encode_example(build_example(video_path))
+    assert isinstance(encoded_example, dict)
+    assert encoded_example.keys() == {"bytes", "path"}
+    assert encoded_example["bytes"] is not None or encoded_example["path"] is not None
+    decoded_example = video.decode_example(encoded_example)
+    assert isinstance(decoded_example, VideoReader)
+
+
+@require_decord
+def test_dataset_with_video_feature(shared_datadir):
+    from decord import VideoReader
+    from decord.ndarray import NDArray
+
+    video_path = str(shared_datadir / "test_video_66x50.mov")
+    data = {"video": [video_path]}
+    features = Features({"video": Video()})
+    dset = Dataset.from_dict(data, features=features)
+    item = dset[0]
+    assert item.keys() == {"video"}
+    assert isinstance(item["video"], VideoReader)
+    assert item["video"][0].shape == (50, 66, 3)
+    assert isinstance(item["video"][0], NDArray)
+    batch = dset[:1]
+    assert len(batch) == 1
+    assert batch.keys() == {"video"}
+    assert isinstance(batch["video"], list) and all(isinstance(item, VideoReader) for item in batch["video"])
+    assert batch["video"][0][0].shape == (50, 66, 3)
+    assert isinstance(batch["video"][0][0], NDArray)
+    column = dset["video"]
+    assert len(column) == 1
+    assert isinstance(column, list) and all(isinstance(item, VideoReader) for item in column)
+    assert column[0][0].shape == (50, 66, 3)
+    assert isinstance(column[0][0], NDArray)
+
+    # from bytes
+    with open(video_path, "rb") as f:
+        data = {"video": [f.read()]}
+    dset = Dataset.from_dict(data, features=features)
+    item = dset[0]
+    assert item.keys() == {"video"}
+    assert isinstance(item["video"], VideoReader)
+    assert item["video"][0].shape == (50, 66, 3)
+    assert isinstance(item["video"][0], NDArray)
+
+
+@require_decord
+def test_dataset_with_video_map_and_formatted(shared_datadir):
+    import numpy as np
+    from decord import VideoReader
+
+    video_path = str(shared_datadir / "test_video_66x50.mov")
+    data = {"video": [video_path]}
+    features = Features({"video": Video()})
+    dset = Dataset.from_dict(data, features=features)
+    dset = dset.map(lambda x: x).with_format("numpy")
+    example = dset[0]
+    assert isinstance(example["video"], VideoReader)
+    assert isinstance(example["video"][0], np.ndarray)
+
+    # from bytes
+    with open(video_path, "rb") as f:
+        data = {"video": [f.read()]}
+    dset = Dataset.from_dict(data, features=features)
+    dset = dset.map(lambda x: x).with_format("numpy")
+    example = dset[0]
+    assert isinstance(example["video"], VideoReader)
+    assert isinstance(example["video"][0], np.ndarray)
diff --git a/tests/utils.py b/tests/utils.py
index e19740a2a12..08497e1eae7 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -178,6 +178,18 @@ def require_pil(test_case):
     return test_case
 
 
+def require_decord(test_case):
+    """
+    Decorator marking a test that requires decord.
+
+    These tests are skipped when decord isn't installed.
+
+    """
+    if not config.DECORD_AVAILABLE:
+        test_case = unittest.skip("test requires decord")(test_case)
+    return test_case
+
+
 def require_transformers(test_case):
     """
     Decorator marking a test that requires transformers.