diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 690c3a19601..295974c3d20 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -72,6 +72,8 @@
title: Semantic segmentation
- local: object_detection
title: Object detection
+ - local: video_load
+ title: Load video data
title: "Vision"
- sections:
- local: nlp_load
diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx
index c08d555cd29..402e291be38 100644
--- a/docs/source/package_reference/main_classes.mdx
+++ b/docs/source/package_reference/main_classes.mdx
@@ -245,6 +245,10 @@ Dictionary with split names as keys ('train', 'test' for example), and `Iterable
[[autodoc]] datasets.Image
+### Video
+
+[[autodoc]] datasets.Video
+
## Filesystems
[[autodoc]] datasets.filesystems.is_remote_filesystem
diff --git a/docs/source/video_load.mdx b/docs/source/video_load.mdx
new file mode 100644
index 00000000000..b30f0f4aeff
--- /dev/null
+++ b/docs/source/video_load.mdx
@@ -0,0 +1,109 @@
+# Load video data
+
+<Tip warning={true}>
+
+Video support is experimental and is subject to change.
+
+</Tip>
+
+Video datasets have [`Video`] type columns, which contain `decord.VideoReader` objects.
+
+<Tip>
+
+To work with video datasets, you need to have the `decord` package installed. Check out the [installation](./installation) guide to learn how to install it.
+
+</Tip>
+
+When you load a video dataset and call the video column, the videos are decoded as `decord.VideoReader` objects:
+
+```py
+>>> from datasets import load_dataset, Video
+
+>>> dataset = load_dataset("path/to/video/folder", split="train")
+>>> dataset[0]["video"]
+<decord.video_reader.VideoReader at 0x...>
+```
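+
+A decoded `decord.VideoReader` can be queried for frames directly. Here is a minimal sketch, assuming the dataset above (`len`, indexing and `get_batch` are standard `decord` calls):
+
+```py
+>>> video = dataset[0]["video"]
+>>> len(video)  # number of frames
+>>> first_frame = video[0]  # decode a single frame, an array of shape (height, width, 3)
+>>> clip = video.get_batch([0, 1, 2])  # decode several frames at once
+```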
+
+<Tip>
+
+Index into a video dataset using the row index first and then the `video` column - `dataset[0]["video"]` - to avoid decoding all the video objects in the dataset. Otherwise, it can be a slow process if you have a large dataset; the sketch below shows the difference.
+
+</Tip>
+
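+A quick sketch of the difference between the two access orders:
+
+```py
+>>> video = dataset[0]["video"]  # decodes only the first video
+>>> videos = dataset["video"]    # decodes every video in the dataset
+```
+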
+For a guide on how to load any type of dataset, take a look at the general [loading guide](./loading).
+
+## Local files
+
+You can load a dataset from the video path. Use the [`~Dataset.cast_column`] function to accept a column of video file paths, and decode them into `decord` videos with the [`Video`] feature:
+
+```py
+>>> from datasets import Dataset, Video
+
+>>> dataset = Dataset.from_dict({"video": ["path/to/video_1", "path/to/video_2", ..., "path/to/video_n"]}).cast_column("video", Video())
+>>> dataset[0]["video"]
+<decord.video_reader.VideoReader at 0x...>
+```
+
+If you only want to load the underlying path to the video dataset without decoding the video object, set `decode=False` in the [`Video`] feature:
+
+```py
+>>> dataset = dataset.cast_column("video", Video(decode=False))
+>>> dataset[0]["video"]
+{'bytes': None,
+ 'path': 'path/to/video/folder/video0.mp4'}
+```
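+
+You can still decode a specific example on demand from the returned path; a minimal sketch, assuming `decord` is installed:
+
+```py
+>>> from decord import VideoReader
+
+>>> video = VideoReader(dataset[0]["video"]["path"])
+```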
+
+## VideoFolder
+
+You can also load a dataset with a `VideoFolder` dataset builder which does not require writing a custom dataloader. This makes `VideoFolder` ideal for quickly creating and loading video datasets with several thousand videos for different vision tasks. Your video dataset structure should look like this:
+
+```
+folder/train/dog/golden_retriever.mp4
+folder/train/dog/german_shepherd.mp4
+folder/train/dog/chihuahua.mp4
+
+folder/train/cat/maine_coon.mp4
+folder/train/cat/bengal.mp4
+folder/train/cat/birman.mp4
+```
+
+Load your dataset by specifying `videofolder` and the directory of your dataset in `data_dir`:
+
+```py
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder")
+>>> dataset["train"][0]
+{"video": , "label": 0}
+
+>>> dataset["train"][-1]
+{"video": , "label": 1}
+```
+
+Load remote datasets from their URLs with the `data_files` parameter:
+
+```py
+>>> dataset = load_dataset("videofolder", data_files="https://foo.bar/videos.zip", split="train")
+```
+
+Some datasets have a metadata file (`metadata.csv`/`metadata.jsonl`) associated with them, containing other information about the data like bounding boxes, text captions, and labels. The metadata is automatically loaded when you call [`load_dataset`] and specify `videofolder`.
+
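+For example, a `metadata.csv` for the folder above might look like this, where `file_name` links each row to its video file and `text` is a hypothetical extra column:
+
+```
+file_name,text
+train/dog/golden_retriever.mp4,a dog fetching a ball
+train/cat/maine_coon.mp4,a cat sleeping on a sofa
+```
+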
+To ignore the information in the metadata file, set `drop_labels=False` in [`load_dataset`], and allow `VideoFolder` to automatically infer the label name from the directory name:
+
+```py
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder", drop_labels=False)
+```
+
+## WebDataset
+
+The [WebDataset](https://github.com/webdataset/webdataset) format is based on a folder of TAR archives and is suitable for big video datasets.
+Because of their size, WebDatasets are generally loaded in streaming mode (using `streaming=True`).
+
+You can load a WebDataset like this:
+
+```python
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", streaming=True)
+```
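+
+In streaming mode the examples are downloaded and decoded while you iterate, so read them with an iterator rather than by index; a minimal sketch, assuming a `train` split:
+
+```py
+>>> example = next(iter(dataset["train"]))
+```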
diff --git a/setup.py b/setup.py
index 9f926e4f7b6..2970ba00299 100644
--- a/setup.py
+++ b/setup.py
@@ -187,6 +187,7 @@
"transformers>=4.42.0", # Pins numpy < 2
"zstandard",
"polars[timezone]>=0.20.0",
+ "decord==0.6.0",
]
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index f65995f93d9..eb8fedce996 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -77,7 +77,7 @@
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
from .data_files import sanitize_patterns
from .download.streaming_download_manager import xgetsize
-from .features import Audio, ClassLabel, Features, Image, Sequence, Value
+from .features import Audio, ClassLabel, Features, Image, Sequence, Value, Video
from .features.features import (
FeatureType,
_align_features,
@@ -1416,9 +1416,9 @@ def save_to_disk(
"""
Saves a dataset to a dataset directory, or in a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.
- For [`Image`] and [`Audio`] data:
+ For [`Image`], [`Audio`] and [`Video`] data:
- All the Image() and Audio() data are stored in the arrow files.
+ All the Image(), Audio() and Video() data are stored in the arrow files.
If you want to store paths or urls, please use the Value("string") type.
Args:
@@ -5065,7 +5065,7 @@ def _estimate_nbytes(self) -> int:
def extra_nbytes_visitor(array, feature):
nonlocal extra_nbytes
- if isinstance(feature, (Audio, Image)):
+ if isinstance(feature, (Audio, Image, Video)):
for x in array.to_pylist():
if x is not None and x["bytes"] is None and x["path"] is not None:
size = xgetsize(x["path"])
@@ -5249,15 +5249,16 @@ def _push_parquet_shards_to_hub(
shards = (self.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards))
if decodable_columns:
+ from .io.parquet import get_writer_batch_size
- def shards_with_embedded_external_files(shards):
+ def shards_with_embedded_external_files(shards: Iterator[Dataset]) -> Iterator[Dataset]:
for shard in shards:
format = shard.format
shard = shard.with_format("arrow")
shard = shard.map(
embed_table_storage,
batched=True,
- batch_size=1000,
+ batch_size=get_writer_batch_size(shard.features),
keep_in_memory=True,
)
shard = shard.with_format(**format)
@@ -5310,7 +5311,7 @@ def push_to_hub(
"""Pushes the dataset to the hub as a Parquet dataset.
        The dataset is pushed using HTTP requests and does not require git or git-lfs to be installed.
- The resulting Parquet files are self-contained by default. If your dataset contains [`Image`] or [`Audio`]
+ The resulting Parquet files are self-contained by default. If your dataset contains [`Image`], [`Audio`] or [`Video`]
        data, the Parquet files will store the bytes of your image, audio or video files.
You can disable this by setting `embed_external_files` to `False`.
diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py
index 3b9993736e4..763ad8b7a41 100644
--- a/src/datasets/arrow_writer.py
+++ b/src/datasets/arrow_writer.py
@@ -340,7 +340,7 @@ def __init__(
self.fingerprint = fingerprint
self.disable_nullable = disable_nullable
- self.writer_batch_size = writer_batch_size or config.DEFAULT_MAX_BATCH_SIZE
+ self.writer_batch_size = writer_batch_size
self.update_features = update_features
self.with_metadata = with_metadata
self.unit = unit
@@ -353,6 +353,11 @@ def __init__(
self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None
self.hkey_record = []
+ if self.writer_batch_size is None and self._features is not None:
+ from .io.parquet import get_writer_batch_size
+
+ self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
+
def __len__(self):
"""Return the number of writed and staged examples"""
return self._num_examples + len(self.current_examples) + len(self.current_rows)
@@ -397,6 +402,10 @@ def _build_writer(self, inferred_schema: pa.Schema):
schema = schema.with_metadata({})
self._schema = schema
self.pa_writer = self._WRITER_CLASS(self.stream, schema)
+ if self.writer_batch_size is None:
+ from .io.parquet import get_writer_batch_size
+
+ self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
@property
def schema(self):
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 7328b90cbca..c3eee41c6e0 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -930,7 +930,8 @@ def incomplete_dir(dirname):
# Sync info
self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values())
self.info.download_checksums = dl_manager.get_recorded_sizes_checksums()
- self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
+ if self.info.download_size is not None:
+ self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
# Save info
self._save_info()
diff --git a/src/datasets/config.py b/src/datasets/config.py
index e2de170bcba..e2de31b337a 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -129,6 +129,7 @@
IS_MP3_SUPPORTED = importlib.util.find_spec("soundfile") is not None and version.parse(
importlib.import_module("soundfile").__libsndfile_version__
) >= version.parse("1.1.0")
+DECORD_AVAILABLE = importlib.util.find_spec("decord") is not None
# Optional compression tools
RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None
@@ -192,6 +193,7 @@
PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS = 100
PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS = 100
PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS = 100
+PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS = 10
# Offline mode
_offline = os.environ.get("HF_DATASETS_OFFLINE")
diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index f92a1a8afda..40ca0cd7312 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -1230,9 +1230,9 @@ def save_to_disk(
"""
Saves a dataset dict to a filesystem using `fsspec.spec.AbstractFileSystem`.
- For [`Image`] and [`Audio`] data:
+ For [`Image`], [`Audio`] and [`Video`] data:
- All the Image() and Audio() data are stored in the arrow files.
+ All the Image(), Audio() and Video() data are stored in the arrow files.
If you want to store paths or urls, please use the Value("string") type.
Args:
diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py
index 9a1d1e3a53c..800f6821443 100644
--- a/src/datasets/download/streaming_download_manager.py
+++ b/src/datasets/download/streaming_download_manager.py
@@ -64,6 +64,8 @@ def __init__(
self._data_dir = data_dir
self._base_path = base_path or os.path.abspath(".")
self.download_config = download_config or DownloadConfig()
+ self.downloaded_size = None
+ self.record_checksums = False
@property
def manual_dir(self):
@@ -208,3 +210,9 @@ def iter_files(self, urlpaths: Union[str, List[str]]) -> Iterable[str]:
```
"""
return FilesIterable.from_urlpaths(urlpaths, download_config=self.download_config)
+
+ def manage_extracted_files(self):
+ pass
+
+ def get_recorded_sizes_checksums(self):
+ pass
diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py
index 35ebfb4ac0c..bf38042eb81 100644
--- a/src/datasets/features/__init__.py
+++ b/src/datasets/features/__init__.py
@@ -12,8 +12,10 @@
"Image",
"Translation",
"TranslationVariableLanguages",
+ "Video",
]
from .audio import Audio
from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value
from .image import Image
from .translation import Translation, TranslationVariableLanguages
+from .video import Video
diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py
index 1d241e0b7b7..34622cd94d9 100644
--- a/src/datasets/features/features.py
+++ b/src/datasets/features/features.py
@@ -43,6 +43,7 @@
from .audio import Audio
from .image import Image, encode_pil_image
from .translation import Translation, TranslationVariableLanguages
+from .video import Video
logger = logging.get_logger(__name__)
@@ -1202,6 +1203,7 @@ class LargeList:
Array5D,
Audio,
Image,
+ Video,
]
@@ -1346,7 +1348,7 @@ def encode_nested_example(schema, obj, level=0):
return list(obj)
# Object with special encoding:
# ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
- elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD)):
+ elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)):
return schema.encode_example(obj) if obj is not None else None
    # Other object should be directly convertible to a native Arrow type (like Translation)
return obj
@@ -1397,7 +1399,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
else:
return decode_nested_example([schema.feature], obj)
# Object with special decoding:
- elif isinstance(schema, (Audio, Image)):
+ elif isinstance(schema, (Audio, Image, Video)):
# we pass the token to read and decode files from private repositories in streaming mode
if obj is not None and schema.decode:
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
@@ -1417,6 +1419,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
Array5D.__name__: Array5D,
Audio.__name__: Audio,
Image.__name__: Image,
+ Video.__name__: Video,
}
diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py
new file mode 100644
index 00000000000..a9967f3be45
--- /dev/null
+++ b/src/datasets/features/video.py
@@ -0,0 +1,340 @@
+import os
+from dataclasses import dataclass, field
+from io import BytesIO
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union
+
+import numpy as np
+import pyarrow as pa
+
+from .. import config
+from ..download.download_config import DownloadConfig
+from ..table import array_cast
+from ..utils.file_utils import is_local_path, xopen
+from ..utils.py_utils import no_op_if_value_is_null, string_to_dict
+
+
+if TYPE_CHECKING:
+ from decord import VideoReader
+
+ from .features import FeatureType
+
+
+@dataclass
+class Video:
+ """Video [`Feature`] to read video data from a video file.
+
+ Input: The Video feature accepts as input:
+ - A `str`: Absolute path to the video file (i.e. random access is allowed).
+ - A `dict` with the keys:
+
+ - `path`: String with relative path of the video file in a dataset repository.
+ - `bytes`: Bytes of the video file.
+
+ This is useful for archived files with sequential access.
+
+ - An `np.ndarray`: NumPy array representing a video.
+ - A `decord.VideoReader`: decord video reader object.
+
+ Args:
+ decode (`bool`, defaults to `True`):
+ Whether to decode the video data. If `False`,
+ returns the underlying dictionary in the format `{"path": video_path, "bytes": video_bytes}`.
+
+ Examples:
+
+ ```py
+ >>> from datasets import Dataset, Video
+ >>> ds = Dataset.from_dict({"video":["path/to/Screen Recording.mov"]}).cast_column("video", Video())
+ >>> ds.features["video"]
+ Video(decode=True, id=None)
+ >>> ds[0]["video"]
+    <decord.video_reader.VideoReader at 0x...>
+    >>> ds = ds.cast_column('video', Video(decode=False))
+    >>> ds[0]["video"]
+ {'bytes': None,
+ 'path': 'path/to/Screen Recording.mov'}
+ ```
+ """
+
+ decode: bool = True
+ id: Optional[str] = None
+ # Automatically constructed
+ dtype: ClassVar[str] = "decord.VideoReader"
+ pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()})
+ _type: str = field(default="Video", init=False, repr=False)
+
+ def __call__(self):
+ return self.pa_type
+
+ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader"]) -> dict:
+ """Encode example into a format for Arrow.
+
+ Args:
+ value (`str`, `np.ndarray`, `VideoReader` or `dict`):
+ Data passed as input to Video feature.
+
+ Returns:
+ `dict` with "path" and "bytes" fields
+ """
+ if config.DECORD_AVAILABLE:
+ # We need to import torch first, otherwise later it can cause issues
+ # e.g. "RuntimeError: random_device could not be read"
+ # when running `torch.tensor(value).share_memory_()`
+ if config.TORCH_AVAILABLE:
+ import torch # noqa
+ from decord import VideoReader
+
+ else:
+ raise ImportError("To support encoding videos, please install 'decord'.")
+
+ if isinstance(value, list):
+ value = np.array(value)
+
+ if isinstance(value, str):
+ return {"path": value, "bytes": None}
+ elif isinstance(value, bytes):
+ return {"path": None, "bytes": value}
+ elif isinstance(value, np.ndarray):
+ # convert the video array to bytes
+ return encode_np_array(value)
+ elif isinstance(value, VideoReader):
+ # convert the decord video reader to bytes
+ return encode_decord_video(value)
+ elif value.get("path") is not None and os.path.isfile(value["path"]):
+ # we set "bytes": None to not duplicate the data if they're already available locally
+ return {"bytes": None, "path": value.get("path")}
+ elif value.get("bytes") is not None or value.get("path") is not None:
+ # store the video bytes, and path is used to infer the video format using the file extension
+ return {"bytes": value.get("bytes"), "path": value.get("path")}
+ else:
+ raise ValueError(
+ f"A video sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
+ )
+
+ def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader":
+ """Decode example video file into video data.
+
+ Args:
+ value (`str` or `dict`):
+ A string with the absolute video file path, a dictionary with
+ keys:
+
+ - `path`: String with absolute or relative video file path.
+ - `bytes`: The bytes of the video file.
+ token_per_repo_id (`dict`, *optional*):
+ To access and decode
+ video files from private repositories on the Hub, you can pass
+ a dictionary repo_id (`str`) -> token (`bool` or `str`).
+
+ Returns:
+ `decord.VideoReader`
+ """
+ if not self.decode:
+ raise RuntimeError("Decoding is disabled for this feature. Please use Video(decode=True) instead.")
+
+ if config.DECORD_AVAILABLE:
+ # We need to import torch first, otherwise later it can cause issues
+ # e.g. "RuntimeError: random_device could not be read"
+ # when running `torch.tensor(value).share_memory_()`
+ if config.TORCH_AVAILABLE:
+ import torch # noqa
+ from decord import VideoReader
+ else:
+ raise ImportError("To support decoding videos, please install 'decord'.")
+
+ if token_per_repo_id is None:
+ token_per_repo_id = {}
+
+ path, bytes_ = value["path"], value["bytes"]
+ if bytes_ is None:
+ if path is None:
+ raise ValueError(f"A video should have one of 'path' or 'bytes' but both are None in {value}.")
+ else:
+ if is_local_path(path):
+ video = VideoReader(path)
+ else:
+ source_url = path.split("::")[-1]
+ pattern = (
+ config.HUB_DATASETS_URL
+ if source_url.startswith(config.HF_ENDPOINT)
+ else config.HUB_DATASETS_HFFS_URL
+ )
+ try:
+ repo_id = string_to_dict(source_url, pattern)["repo_id"]
+ token = token_per_repo_id.get(repo_id)
+ except ValueError:
+ token = None
+ download_config = DownloadConfig(token=token)
+ with xopen(path, "rb", download_config=download_config) as f:
+ bytes_ = BytesIO(f.read())
+ video = VideoReader(bytes_)
+ else:
+ video = VideoReader(BytesIO(bytes_))
+ return video
+
+ def flatten(self) -> Union["FeatureType", Dict[str, "FeatureType"]]:
+ """If in the decodable state, return the feature itself, otherwise flatten the feature into a dictionary."""
+ from .features import Value
+
+ return (
+ self
+ if self.decode
+ else {
+ "bytes": Value("binary"),
+ "path": Value("string"),
+ }
+ )
+
+ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArray]) -> pa.StructArray:
+ """Cast an Arrow array to the Video arrow storage type.
+ The Arrow types that can be converted to the Video pyarrow storage type are:
+
+ - `pa.string()` - it must contain the "path" data
+ - `pa.binary()` - it must contain the video bytes
+ - `pa.struct({"bytes": pa.binary()})`
+ - `pa.struct({"path": pa.string()})`
+ - `pa.struct({"bytes": pa.binary(), "path": pa.string()})` - order doesn't matter
+ - `pa.list(*)` - it must contain the video array data
+
+ Args:
+ storage (`Union[pa.StringArray, pa.StructArray, pa.ListArray]`):
+ PyArrow array to cast.
+
+ Returns:
+ `pa.StructArray`: Array in the Video arrow storage type, that is
+ `pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
+ """
+ if pa.types.is_string(storage.type):
+ bytes_array = pa.array([None] * len(storage), type=pa.binary())
+ storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())
+ elif pa.types.is_binary(storage.type):
+ path_array = pa.array([None] * len(storage), type=pa.string())
+ storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())
+ elif pa.types.is_struct(storage.type):
+ if storage.type.get_field_index("bytes") >= 0:
+ bytes_array = storage.field("bytes")
+ else:
+ bytes_array = pa.array([None] * len(storage), type=pa.binary())
+ if storage.type.get_field_index("path") >= 0:
+ path_array = storage.field("path")
+ else:
+ path_array = pa.array([None] * len(storage), type=pa.string())
+ storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
+ elif pa.types.is_list(storage.type):
+ bytes_array = pa.array(
+ [encode_np_array(np.array(arr))["bytes"] if arr is not None else None for arr in storage.to_pylist()],
+ type=pa.binary(),
+ )
+ path_array = pa.array([None] * len(storage), type=pa.string())
+ storage = pa.StructArray.from_arrays(
+ [bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()
+ )
+ return array_cast(storage, self.pa_type)
+
+ def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
+ """Embed video files into the Arrow array.
+
+ Args:
+ storage (`pa.StructArray`):
+ PyArrow array to embed.
+
+ Returns:
+ `pa.StructArray`: Array in the Video arrow storage type, that is
+ `pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
+ """
+
+ @no_op_if_value_is_null
+ def path_to_bytes(path):
+ with xopen(path, "rb") as f:
+ bytes_ = f.read()
+ return bytes_
+
+ bytes_array = pa.array(
+ [
+ (path_to_bytes(x["path"]) if x["bytes"] is None else x["bytes"]) if x is not None else None
+ for x in storage.to_pylist()
+ ],
+ type=pa.binary(),
+ )
+ path_array = pa.array(
+ [
+ (os.path.basename(path) if os.path.isfile(path) else path) if path is not None else None
+ for path in storage.field("path").to_pylist()
+ ],
+ type=pa.string(),
+ )
+ storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null())
+ return array_cast(storage, self.pa_type)
+
+
+def video_to_bytes(video: "VideoReader") -> bytes:
+ """Convert a decord Video object to bytes using native compression if possible"""
+ raise NotImplementedError()
+
+
+def encode_decord_video(video: "VideoReader") -> dict:
+ if hasattr(video, "_hf_encoded"):
+ return video._hf_encoded
+ else:
+ raise NotImplementedError(
+ "Encoding a decord video is not implemented. "
+ "Please call `datasets.features.video.patch_decord()` before loading videos to enable this."
+ )
+
+
+def encode_np_array(array: np.ndarray) -> dict:
+ raise NotImplementedError()
+
+
+# Patching decord a little bit to:
+# 1. store the encoded video data {"path": ..., "bytes": ...} in `video._hf_encoded`
+# 2. set the decord bridge to numpy/torch/tf/jax using `video._hf_bridge_out` (per video instance) instead of decord.bridge.bridge_out (global)
+# This doesn't affect the normal usage of decord.
+
+
+def _patched_init(self: "VideoReader", uri: Union[str, BytesIO], *args, **kwargs) -> None:
+ from decord.bridge import bridge_out
+
+ if hasattr(uri, "read"):
+ self._hf_encoded = {"bytes": uri.read(), "path": None}
+ uri.seek(0)
+ elif isinstance(uri, str):
+ self._hf_encoded = {"bytes": None, "path": uri}
+ self._hf_bridge_out = bridge_out
+ self._original_init(uri, *args, **kwargs)
+
+
+def _patched_next(self: "VideoReader", *args, **kwargs):
+ return self._hf_bridge_out(self._original_next(*args, **kwargs))
+
+
+def _patched_get_batch(self: "VideoReader", *args, **kwargs):
+ return self._hf_bridge_out(self._original_get_batch(*args, **kwargs))
+
+
+def patch_decord():
+ if config.DECORD_AVAILABLE:
+ # We need to import torch first, otherwise later it can cause issues
+ # e.g. "RuntimeError: random_device could not be read"
+ # when running `torch.tensor(value).share_memory_()`
+ if config.TORCH_AVAILABLE:
+ import torch # noqa
+ import decord.video_reader
+ from decord import VideoReader
+
+ if not hasattr(VideoReader, "_hf_patched"):
+ decord.video_reader.bridge_out = lambda x: x
+ VideoReader._original_init = VideoReader.__init__
+ VideoReader.__init__ = _patched_init
+ VideoReader._original_next = VideoReader.next
+ VideoReader.next = _patched_next
+ VideoReader._original_get_batch = VideoReader.get_batch
+ VideoReader.get_batch = _patched_get_batch
+ VideoReader._hf_patched = True
+ else:
+ raise ImportError("To support decoding videos, please install 'decord'.")
+
+
+if config.DECORD_AVAILABLE:
+ patch_decord()
diff --git a/src/datasets/formatting/jax_formatter.py b/src/datasets/formatting/jax_formatter.py
index 8035341c5cd..e247b7b5822 100644
--- a/src/datasets/formatting/jax_formatter.py
+++ b/src/datasets/formatting/jax_formatter.py
@@ -100,11 +100,23 @@ def _tensorize(self, value):
default_dtype = {"dtype": jnp.int32}
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": jnp.float32}
- elif config.PIL_AVAILABLE and "PIL" in sys.modules:
+
+ if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image
if isinstance(value, PIL.Image.Image):
value = np.asarray(value)
+ if config.DECORD_AVAILABLE and "decord" in sys.modules:
+ # We need to import torch first, otherwise later it can cause issues
+ # e.g. "RuntimeError: random_device could not be read"
+ # when running `torch.tensor(value).share_memory_()`
+ if config.TORCH_AVAILABLE:
+ import torch # noqa
+ from decord import VideoReader
+
+ if isinstance(value, VideoReader):
+ value._hf_bridge_out = lambda x: jnp.array(np.asarray(x))
+ return value
# using global variable since `jaxlib.xla_extension.Device` is not serializable neither
# with `pickle` nor with `dill`, so we need to use a global variable instead
diff --git a/src/datasets/formatting/np_formatter.py b/src/datasets/formatting/np_formatter.py
index 95bcff2b517..032758bce21 100644
--- a/src/datasets/formatting/np_formatter.py
+++ b/src/datasets/formatting/np_formatter.py
@@ -57,11 +57,23 @@ def _tensorize(self, value):
default_dtype = {"dtype": np.int64}
elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": np.float32}
- elif config.PIL_AVAILABLE and "PIL" in sys.modules:
+
+ if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image
if isinstance(value, PIL.Image.Image):
return np.asarray(value, **self.np_array_kwargs)
+ if config.DECORD_AVAILABLE and "decord" in sys.modules:
+ # We need to import torch first, otherwise later it can cause issues
+ # e.g. "RuntimeError: random_device could not be read"
+ # when running `torch.tensor(value).share_memory_()`
+ if config.TORCH_AVAILABLE:
+ import torch # noqa
+ from decord import VideoReader
+
+ if isinstance(value, VideoReader):
+ value._hf_bridge_out = np.asarray
+ return value
return np.asarray(value, **{**default_dtype, **self.np_array_kwargs})
diff --git a/src/datasets/formatting/tf_formatter.py b/src/datasets/formatting/tf_formatter.py
index adb15cda381..9f0c06ec82a 100644
--- a/src/datasets/formatting/tf_formatter.py
+++ b/src/datasets/formatting/tf_formatter.py
@@ -64,11 +64,24 @@ def _tensorize(self, value):
default_dtype = {"dtype": tf.int64}
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": tf.float32}
- elif config.PIL_AVAILABLE and "PIL" in sys.modules:
+
+ if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image
if isinstance(value, PIL.Image.Image):
value = np.asarray(value)
+ if config.DECORD_AVAILABLE and "decord" in sys.modules:
+ # We need to import torch first, otherwise later it can cause issues
+ # e.g. "RuntimeError: random_device could not be read"
+ # when running `torch.tensor(value).share_memory_()`
+ if config.TORCH_AVAILABLE:
+ import torch # noqa
+ from decord import VideoReader
+ from decord.bridge import to_tensorflow
+
+ if isinstance(value, VideoReader):
+ value._hf_bridge_out = to_tensorflow
+ return value
return tf.convert_to_tensor(value, **{**default_dtype, **self.tf_tensor_kwargs})
diff --git a/src/datasets/formatting/torch_formatter.py b/src/datasets/formatting/torch_formatter.py
index 8efe759a144..051badb0ac4 100644
--- a/src/datasets/formatting/torch_formatter.py
+++ b/src/datasets/formatting/torch_formatter.py
@@ -66,7 +66,8 @@ def _tensorize(self, value):
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": torch.float32}
- elif config.PIL_AVAILABLE and "PIL" in sys.modules:
+
+ if config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image
if isinstance(value, PIL.Image.Image):
@@ -75,6 +76,14 @@ def _tensorize(self, value):
value = value[:, :, np.newaxis]
value = value.transpose((2, 0, 1))
+ if config.DECORD_AVAILABLE and "decord" in sys.modules:
+ from decord import VideoReader
+ from decord.bridge import to_torch
+
+ if isinstance(value, VideoReader):
+ value._hf_bridge_out = to_torch
+ return value
+
return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
def _recursive_tensorize(self, data_struct):
diff --git a/src/datasets/io/parquet.py b/src/datasets/io/parquet.py
index 51434106fb1..289bd1adfdc 100644
--- a/src/datasets/io/parquet.py
+++ b/src/datasets/io/parquet.py
@@ -5,7 +5,7 @@
import numpy as np
import pyarrow.parquet as pq
-from .. import Audio, Dataset, Features, Image, NamedSplit, Value, config
+from .. import Audio, Dataset, Features, Image, NamedSplit, Value, Video, config
from ..features.features import FeatureType, _visit
from ..formatting import query_table
from ..packaged_modules import _PACKAGED_DATASETS_MODULES
@@ -42,6 +42,8 @@ def set_batch_size(feature: FeatureType) -> None:
batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)
elif isinstance(feature, Audio):
batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
+ elif isinstance(feature, Video):
+ batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS)
elif isinstance(feature, Value) and feature.dtype == "binary":
batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS)
diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py
index 7598d0213ac..94d806edf45 100644
--- a/src/datasets/packaged_modules/__init__.py
+++ b/src/datasets/packaged_modules/__init__.py
@@ -14,6 +14,7 @@
from .parquet import parquet
from .sql import sql
from .text import text
+from .videofolder import videofolder
from .webdataset import webdataset
from .xml import xml
@@ -41,6 +42,7 @@ def _hash_python_lines(lines: List[str]) -> str:
"text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())),
"imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())),
"audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())),
+ "videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())),
"webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
"xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
}
@@ -77,7 +79,9 @@ def _hash_python_lines(lines: List[str]) -> str:
_EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext: ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext.upper(): ("audiofolder", {}) for ext in audiofolder.AudioFolder.EXTENSIONS})
-_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder"}
+_EXTENSION_TO_MODULE.update({ext: ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
+_EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS})
+_MODULE_SUPPORTS_METADATA = {"imagefolder", "audiofolder", "videofolder"}
# Used to filter data files based on extensions given a module name
_MODULE_TO_EXTENSIONS: Dict[str, List[str]] = {}
diff --git a/src/datasets/packaged_modules/videofolder/__init__.py b/src/datasets/packaged_modules/videofolder/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/videofolder/videofolder.py b/src/datasets/packaged_modules/videofolder/videofolder.py
new file mode 100644
index 00000000000..7ce5bcf5655
--- /dev/null
+++ b/src/datasets/packaged_modules/videofolder/videofolder.py
@@ -0,0 +1,36 @@
+from typing import List
+
+import datasets
+
+from ..folder_based_builder import folder_based_builder
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+class VideoFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
+    """BuilderConfig for VideoFolder."""
+
+ drop_labels: bool = None
+ drop_metadata: bool = None
+
+ def __post_init__(self):
+ super().__post_init__()
+
+
+class VideoFolder(folder_based_builder.FolderBasedBuilder):
+ BASE_FEATURE = datasets.Video
+ BASE_COLUMN_NAME = "video"
+ BUILDER_CONFIG_CLASS = VideoFolderConfig
+ EXTENSIONS: List[str] # definition at the bottom of the script
+
+
+# TODO: initial list, we should check the compatibility of other formats
+VIDEO_EXTENSIONS = [
+ ".mkv",
+ ".mp4",
+ ".avi",
+ ".mpeg",
+ ".mov",
+]
+VideoFolder.EXTENSIONS = VIDEO_EXTENSIONS
diff --git a/src/datasets/packaged_modules/webdataset/webdataset.py b/src/datasets/packaged_modules/webdataset/webdataset.py
index c04f3ba4639..0768437b36a 100644
--- a/src/datasets/packaged_modules/webdataset/webdataset.py
+++ b/src/datasets/packaged_modules/webdataset/webdataset.py
@@ -20,6 +20,7 @@ class WebDataset(datasets.GeneratorBasedBuilder):
DEFAULT_WRITER_BATCH_SIZE = 100
IMAGE_EXTENSIONS: List[str] # definition at the bottom of the script
AUDIO_EXTENSIONS: List[str] # definition at the bottom of the script
+ VIDEO_EXTENSIONS: List[str] # definition at the bottom of the script
DECODERS: Dict[str, Callable[[Any], Any]] # definition at the bottom of the script
NUM_EXAMPLES_FOR_FEATURES_INFERENCE = 5
@@ -97,6 +98,11 @@ def _split_generators(self, dl_manager):
extension = field_name.rsplit(".", 1)[-1]
if extension in self.AUDIO_EXTENSIONS:
features[field_name] = datasets.Audio()
+ # Set Video types
+ for field_name in first_examples[0]:
+ extension = field_name.rsplit(".", 1)[-1]
+ if extension in self.VIDEO_EXTENSIONS:
+ features[field_name] = datasets.Video()
self.info.features = features
return splits
@@ -259,6 +265,17 @@ def base_plus_ext(path):
WebDataset.AUDIO_EXTENSIONS = AUDIO_EXTENSIONS
+# TODO: initial list, we should check the compatibility of other formats
+VIDEO_EXTENSIONS = [
+ ".mkv",
+ ".mp4",
+ ".avi",
+ ".mpeg",
+ ".mov",
+]
+WebDataset.VIDEO_EXTENSIONS = VIDEO_EXTENSIONS
+
+
def text_loads(data: bytes):
return data.decode("utf-8")
diff --git a/tests/features/data/test_video_66x50.mov b/tests/features/data/test_video_66x50.mov
new file mode 100644
index 00000000000..a55dcaa8f7b
Binary files /dev/null and b/tests/features/data/test_video_66x50.mov differ
diff --git a/tests/features/test_video.py b/tests/features/test_video.py
new file mode 100644
index 00000000000..f4c9a8d830b
--- /dev/null
+++ b/tests/features/test_video.py
@@ -0,0 +1,92 @@
+import pytest
+
+from datasets import Dataset, Features, Video
+
+from ..utils import require_decord
+
+
+@require_decord
+@pytest.mark.parametrize(
+ "build_example",
+ [
+ lambda video_path: video_path,
+ lambda video_path: open(video_path, "rb").read(),
+ lambda video_path: {"path": video_path},
+ lambda video_path: {"path": video_path, "bytes": None},
+ lambda video_path: {"path": video_path, "bytes": open(video_path, "rb").read()},
+ lambda video_path: {"path": None, "bytes": open(video_path, "rb").read()},
+ lambda video_path: {"bytes": open(video_path, "rb").read()},
+ ],
+)
+def test_video_feature_encode_example(shared_datadir, build_example):
+ from decord import VideoReader
+
+ video_path = str(shared_datadir / "test_video_66x50.mov")
+ video = Video()
+ encoded_example = video.encode_example(build_example(video_path))
+ assert isinstance(encoded_example, dict)
+ assert encoded_example.keys() == {"bytes", "path"}
+ assert encoded_example["bytes"] is not None or encoded_example["path"] is not None
+ decoded_example = video.decode_example(encoded_example)
+ assert isinstance(decoded_example, VideoReader)
+
+
+@require_decord
+def test_dataset_with_video_feature(shared_datadir):
+ from decord import VideoReader
+ from decord.ndarray import NDArray
+
+ video_path = str(shared_datadir / "test_video_66x50.mov")
+ data = {"video": [video_path]}
+ features = Features({"video": Video()})
+ dset = Dataset.from_dict(data, features=features)
+ item = dset[0]
+ assert item.keys() == {"video"}
+ assert isinstance(item["video"], VideoReader)
+ assert item["video"][0].shape == (50, 66, 3)
+ assert isinstance(item["video"][0], NDArray)
+ batch = dset[:1]
+ assert len(batch) == 1
+ assert batch.keys() == {"video"}
+ assert isinstance(batch["video"], list) and all(isinstance(item, VideoReader) for item in batch["video"])
+ assert batch["video"][0][0].shape == (50, 66, 3)
+ assert isinstance(batch["video"][0][0], NDArray)
+ column = dset["video"]
+ assert len(column) == 1
+ assert isinstance(column, list) and all(isinstance(item, VideoReader) for item in column)
+ assert column[0][0].shape == (50, 66, 3)
+ assert isinstance(column[0][0], NDArray)
+
+ # from bytes
+ with open(video_path, "rb") as f:
+ data = {"video": [f.read()]}
+ dset = Dataset.from_dict(data, features=features)
+ item = dset[0]
+ assert item.keys() == {"video"}
+ assert isinstance(item["video"], VideoReader)
+ assert item["video"][0].shape == (50, 66, 3)
+ assert isinstance(item["video"][0], NDArray)
+
+
+@require_decord
+def test_dataset_with_video_map_and_formatted(shared_datadir):
+ import numpy as np
+ from decord import VideoReader
+
+ video_path = str(shared_datadir / "test_video_66x50.mov")
+ data = {"video": [video_path]}
+ features = Features({"video": Video()})
+ dset = Dataset.from_dict(data, features=features)
+ dset = dset.map(lambda x: x).with_format("numpy")
+ example = dset[0]
+ assert isinstance(example["video"], VideoReader)
+ assert isinstance(example["video"][0], np.ndarray)
+
+ # from bytes
+ with open(video_path, "rb") as f:
+ data = {"video": [f.read()]}
+ dset = Dataset.from_dict(data, features=features)
+ dset = dset.map(lambda x: x).with_format("numpy")
+ example = dset[0]
+ assert isinstance(example["video"], VideoReader)
+ assert isinstance(example["video"][0], np.ndarray)
diff --git a/tests/utils.py b/tests/utils.py
index e19740a2a12..08497e1eae7 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -178,6 +178,18 @@ def require_pil(test_case):
return test_case
+def require_decord(test_case):
+ """
+ Decorator marking a test that requires decord.
+
+ These tests are skipped when decord isn't installed.
+
+ """
+ if not config.DECORD_AVAILABLE:
+ test_case = unittest.skip("test requires decord")(test_case)
+ return test_case
+
+
def require_transformers(test_case):
"""
Decorator marking a test that requires transformers.