Skip to content

Commit

Permalink
Video support (#7230)
Browse files Browse the repository at this point in the history
* initial video support

* support map and formatting

* ci test

* set row group size

* add to webdataset

* typos

* try ci without decord just in case

* import torch before decord to fix random_device could not be read

* fix CI

* minor

* better memory handling in push_to_hub

* better memory handling in load_dataset

* basic docs

* add to toc

* streaming tweaks

* keep hf:// URL in the video "path" field for the viewer
  • Loading branch information
lhoestq authored Oct 24, 2024
1 parent 80d6b48 commit 8235fdb
Show file tree
Hide file tree
Showing 25 changed files with 710 additions and 19 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@
title: Semantic segmentation
- local: object_detection
title: Object detection
- local: video_load
title: Load video data
title: "Vision"
- sections:
- local: nlp_load
Expand Down
4 changes: 4 additions & 0 deletions docs/source/package_reference/main_classes.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ Dictionary with split names as keys ('train', 'test' for example), and `Iterable

[[autodoc]] datasets.Image

### Video

[[autodoc]] datasets.Video

## Filesystems

[[autodoc]] datasets.filesystems.is_remote_filesystem
Expand Down
109 changes: 109 additions & 0 deletions docs/source/video_load.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Load video data

<Tip warning={true}>

Video support is experimental and is subject to change.

</Tip>

Video datasets have [`Video`] type columns, which contain `decord` objects.

<Tip>

To work with video datasets, you need to have the `vision` dependency installed. Check out the [installation](./installation#vision) guide to learn how to install it.

</Tip>

When you load an video dataset and call the video column, the videos are decoded as `decord` Videos:

```py
>>> from datasets import load_dataset, Video

>>> dataset = load_dataset("path/to/video/folder", split="train")
>>> dataset[0]["video"]
<decord.video_reader.VideoReader at 0x1652284c0>
```

<Tip warning={true}>

Index into an video dataset using the row index first and then the `video` column - `dataset[0]["video"]` - to avoid decoding and resampling all the video objects in the dataset. Otherwise, this can be a slow and time-consuming process if you have a large dataset.

</Tip>

For a guide on how to load any type of dataset, take a look at the <a class="underline decoration-sky-400 decoration-2 font-semibold" href="./loading">general loading guide</a>.

## Local files

You can load a dataset from the video path. Use the [`~Dataset.cast_column`] function to accept a column of video file paths, and decode it into a `decord` video with the [`Video`] feature:
```py
>>> from datasets import Dataset, Video

>>> dataset = Dataset.from_dict({"video": ["path/to/video_1", "path/to/video_2", ..., "path/to/video_n"]}).cast_column("video", Video())
>>> dataset[0]["video"]
<decord.video_reader.VideoReader at 0x1657d0280>
```

If you only want to load the underlying path to the video dataset without decoding the video object, set `decode=False` in the [`Video`] feature:

```py
>>> dataset = dataset.cast_column("video", Video(decode=False))
>>> dataset[0]["video"]
{'bytes': None,
'path': 'path/to/video/folder/video0.mp4'}
```

## VideoFolder

You can also load a dataset with an `VideoFolder` dataset builder which does not require writing a custom dataloader. This makes `VideoFolder` ideal for quickly creating and loading video datasets with several thousand videos for different vision tasks. Your video dataset structure should look like this:

```
folder/train/dog/golden_retriever.mp4
folder/train/dog/german_shepherd.mp4
folder/train/dog/chihuahua.mp4
folder/train/cat/maine_coon.mp4
folder/train/cat/bengal.mp4
folder/train/cat/birman.mp4
```

Load your dataset by specifying `videofolder` and the directory of your dataset in `data_dir`:

```py
>>> from datasets import load_dataset

>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder")
>>> dataset["train"][0]
{"video": <decord.video_reader.VideoReader at 0x161715e50>, "label": 0}

>>> dataset["train"][-1]
{"video": <decord.video_reader.VideoReader at 0x16170bd90>, "label": 1}
```

Load remote datasets from their URLs with the `data_files` parameter:

```py
>>> dataset = load_dataset("videofolder", data_files="https://foo.bar/videos.zip", split="train")
```

Some datasets have a metadata file (`metadata.csv`/`metadata.jsonl`) associated with it, containing other information about the data like bounding boxes, text captions, and labels. The metadata is automatically loaded when you call [`load_dataset`] and specify `videofolder`.

To ignore the information in the metadata file, set `drop_labels=False` in [`load_dataset`], and allow `VideoFolder` to automatically infer the label name from the directory name:

```py
>>> from datasets import load_dataset

>>> dataset = load_dataset("videofolder", data_dir="/path/to/folder", drop_labels=False)
```

## WebDataset

The [WebDataset](https://github.com/webdataset/webdataset) format is based on a folder of TAR archives and is suitable for big video datasets.
Because of their size, WebDatasets are generally loaded in streaming mode (using `streaming=True`).

You can load a WebDataset like this:

```python
>>> from datasets import load_dataset

>>> dataset = load_dataset("webdataset", data_dir="/path/to/folder", streaming=True)
```
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@
"transformers>=4.42.0", # Pins numpy < 2
"zstandard",
"polars[timezone]>=0.20.0",
"decord==0.6.0",
]


Expand Down
15 changes: 8 additions & 7 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
from .data_files import sanitize_patterns
from .download.streaming_download_manager import xgetsize
from .features import Audio, ClassLabel, Features, Image, Sequence, Value
from .features import Audio, ClassLabel, Features, Image, Sequence, Value, Video
from .features.features import (
FeatureType,
_align_features,
Expand Down Expand Up @@ -1416,9 +1416,9 @@ def save_to_disk(
"""
Saves a dataset to a dataset directory, or in a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.
For [`Image`] and [`Audio`] data:
For [`Image`], [`Audio`] and [`Video`] data:
All the Image() and Audio() data are stored in the arrow files.
All the Image(), Audio() and Video() data are stored in the arrow files.
If you want to store paths or urls, please use the Value("string") type.
Args:
Expand Down Expand Up @@ -5065,7 +5065,7 @@ def _estimate_nbytes(self) -> int:

def extra_nbytes_visitor(array, feature):
nonlocal extra_nbytes
if isinstance(feature, (Audio, Image)):
if isinstance(feature, (Audio, Image, Video)):
for x in array.to_pylist():
if x is not None and x["bytes"] is None and x["path"] is not None:
size = xgetsize(x["path"])
Expand Down Expand Up @@ -5249,15 +5249,16 @@ def _push_parquet_shards_to_hub(
shards = (self.shard(num_shards=num_shards, index=i, contiguous=True) for i in range(num_shards))

if decodable_columns:
from .io.parquet import get_writer_batch_size

def shards_with_embedded_external_files(shards):
def shards_with_embedded_external_files(shards: Iterator[Dataset]) -> Iterator[Dataset]:
for shard in shards:
format = shard.format
shard = shard.with_format("arrow")
shard = shard.map(
embed_table_storage,
batched=True,
batch_size=1000,
batch_size=get_writer_batch_size(shard.features),
keep_in_memory=True,
)
shard = shard.with_format(**format)
Expand Down Expand Up @@ -5310,7 +5311,7 @@ def push_to_hub(
"""Pushes the dataset to the hub as a Parquet dataset.
The dataset is pushed using HTTP requests and does not need to have neither git or git-lfs installed.
The resulting Parquet files are self-contained by default. If your dataset contains [`Image`] or [`Audio`]
The resulting Parquet files are self-contained by default. If your dataset contains [`Image`], [`Audio`] or [`Video`]
data, the Parquet files will store the bytes of your images or audio files.
You can disable this by setting `embed_external_files` to `False`.
Expand Down
11 changes: 10 additions & 1 deletion src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def __init__(

self.fingerprint = fingerprint
self.disable_nullable = disable_nullable
self.writer_batch_size = writer_batch_size or config.DEFAULT_MAX_BATCH_SIZE
self.writer_batch_size = writer_batch_size
self.update_features = update_features
self.with_metadata = with_metadata
self.unit = unit
Expand All @@ -353,6 +353,11 @@ def __init__(
self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None
self.hkey_record = []

if self.writer_batch_size is None and self._features is not None:
from .io.parquet import get_writer_batch_size

self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE

def __len__(self):
"""Return the number of writed and staged examples"""
return self._num_examples + len(self.current_examples) + len(self.current_rows)
Expand Down Expand Up @@ -397,6 +402,10 @@ def _build_writer(self, inferred_schema: pa.Schema):
schema = schema.with_metadata({})
self._schema = schema
self.pa_writer = self._WRITER_CLASS(self.stream, schema)
if self.writer_batch_size is None:
from .io.parquet import get_writer_batch_size

self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE

@property
def schema(self):
Expand Down
3 changes: 2 additions & 1 deletion src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,8 @@ def incomplete_dir(dirname):
# Sync info
self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values())
self.info.download_checksums = dl_manager.get_recorded_sizes_checksums()
self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
if self.info.download_size is not None:
self.info.size_in_bytes = self.info.dataset_size + self.info.download_size
# Save info
self._save_info()

Expand Down
2 changes: 2 additions & 0 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
IS_MP3_SUPPORTED = importlib.util.find_spec("soundfile") is not None and version.parse(
importlib.import_module("soundfile").__libsndfile_version__
) >= version.parse("1.1.0")
DECORD_AVAILABLE = importlib.util.find_spec("decord") is not None

# Optional compression tools
RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None
Expand Down Expand Up @@ -192,6 +193,7 @@
PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS = 100
PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS = 100
PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS = 100
PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS = 10

# Offline mode
_offline = os.environ.get("HF_DATASETS_OFFLINE")
Expand Down
4 changes: 2 additions & 2 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,9 +1230,9 @@ def save_to_disk(
"""
Saves a dataset dict to a filesystem using `fsspec.spec.AbstractFileSystem`.
For [`Image`] and [`Audio`] data:
For [`Image`], [`Audio`] and [`Video`] data:
All the Image() and Audio() data are stored in the arrow files.
All the Image(), Audio() and Video() data are stored in the arrow files.
If you want to store paths or urls, please use the Value("string") type.
Args:
Expand Down
8 changes: 8 additions & 0 deletions src/datasets/download/streaming_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(
self._data_dir = data_dir
self._base_path = base_path or os.path.abspath(".")
self.download_config = download_config or DownloadConfig()
self.downloaded_size = None
self.record_checksums = False

@property
def manual_dir(self):
Expand Down Expand Up @@ -208,3 +210,9 @@ def iter_files(self, urlpaths: Union[str, List[str]]) -> Iterable[str]:
```
"""
return FilesIterable.from_urlpaths(urlpaths, download_config=self.download_config)

def manage_extracted_files(self):
pass

def get_recorded_sizes_checksums(self):
pass
2 changes: 2 additions & 0 deletions src/datasets/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
"Image",
"Translation",
"TranslationVariableLanguages",
"Video",
]
from .audio import Audio
from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value
from .image import Image
from .translation import Translation, TranslationVariableLanguages
from .video import Video
7 changes: 5 additions & 2 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .audio import Audio
from .image import Image, encode_pil_image
from .translation import Translation, TranslationVariableLanguages
from .video import Video


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -1202,6 +1203,7 @@ class LargeList:
Array5D,
Audio,
Image,
Video,
]


Expand Down Expand Up @@ -1346,7 +1348,7 @@ def encode_nested_example(schema, obj, level=0):
return list(obj)
# Object with special encoding:
# ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD)):
elif isinstance(schema, (Audio, Image, ClassLabel, TranslationVariableLanguages, Value, _ArrayXD, Video)):
return schema.encode_example(obj) if obj is not None else None
# Other object should be directly convertible to a native Arrow type (like Translation and Translation)
return obj
Expand Down Expand Up @@ -1397,7 +1399,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
else:
return decode_nested_example([schema.feature], obj)
# Object with special decoding:
elif isinstance(schema, (Audio, Image)):
elif isinstance(schema, (Audio, Image, Video)):
# we pass the token to read and decode files from private repositories in streaming mode
if obj is not None and schema.decode:
return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
Expand All @@ -1417,6 +1419,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
Array5D.__name__: Array5D,
Audio.__name__: Audio,
Image.__name__: Image,
Video.__name__: Video,
}


Expand Down
Loading

0 comments on commit 8235fdb

Please sign in to comment.