Docs and README update (#396)
NicolasHug authored Nov 20, 2024
1 parent c8f5deb commit 339cb27
Showing 3 changed files with 78 additions and 69 deletions.
94 changes: 53 additions & 41 deletions README.md
@@ -2,10 +2,10 @@

# TorchCodec

TorchCodec is a Python library for decoding videos into PyTorch tensors. It aims
to be fast, easy to use, and well integrated into the PyTorch ecosystem. If you
want to use PyTorch to train ML models on videos, TorchCodec is how you turn
those videos into data.
TorchCodec is a Python library for decoding videos into PyTorch tensors, on CPU
and CUDA GPU. It aims to be fast, easy to use, and well integrated into the
PyTorch ecosystem. If you want to use PyTorch to train ML models on videos,
TorchCodec is how you turn those videos into data.

We achieve these capabilities through:

@@ -19,21 +19,24 @@ We achieve these capabilities through:
or used directly to train models.

> [!NOTE]
> ⚠️ TorchCodec is still in early development stage and some APIs may be updated
> in future versions without a deprecation cycle, depending on user feedback.
> ⚠️ TorchCodec is still in development stage and some APIs may be updated
> in future versions, depending on user feedback.
> If you have any suggestions or issues, please let us know by
> [opening an issue](https://github.com/pytorch/torchcodec/issues/new/choose)!
## Using TorchCodec

Here's a condensed summary of what you can do with TorchCodec. For a more
detailed example, [check out our
Here's a condensed summary of what you can do with TorchCodec. For more detailed
examples, [check out our
documentation](https://pytorch.org/torchcodec/stable/generated_examples/)!

#### Decoding

```python
from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder("path/to/video.mp4")
device = "cpu" # or e.g. "cuda" !
decoder = VideoDecoder("path/to/video.mp4", device=device)

decoder.metadata
# VideoStreamMetadata:
@@ -44,39 +47,47 @@ decoder.metadata
# average_fps: 25.0
# ... (truncated output)

len(decoder) # == decoder.metadata.num_frames!
# 250
decoder.metadata.average_fps # Note: instantaneous fps can be higher or lower
# 25.0

# Simple Indexing API
decoder[0] # uint8 tensor of shape [C, H, W]
decoder[0 : -1 : 20] # uint8 stacked tensor of shape [N, C, H, W]

# Indexing, with PTS and duration info:
decoder.get_frames_at(indices=[2, 100])
# FrameBatch:
# data (shape): torch.Size([2, 3, 270, 480])
# pts_seconds: tensor([0.0667, 3.3367], dtype=torch.float64)
# duration_seconds: tensor([0.0334, 0.0334], dtype=torch.float64)

# Iterate over frames:
for frame in decoder:
pass
# Time-based indexing with PTS and duration info
decoder.get_frames_played_at(seconds=[0.5, 10.4])
# FrameBatch:
# data (shape): torch.Size([2, 3, 270, 480])
# pts_seconds: tensor([ 0.4671, 10.3770], dtype=torch.float64)
# duration_seconds: tensor([0.0334, 0.0334], dtype=torch.float64)
```
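As a side note on the timestamps printed above: each frame's `pts_seconds` comes from the stream itself, but for a constant-frame-rate video you can sanity-check the numbers with simple arithmetic. A minimal pure-Python sketch, assuming a constant ~29.97 fps rate (an assumption for illustration, consistent with the 0.0334 s durations shown; this is not what TorchCodec does internally):

```python
# Rough sanity check of the pts values printed above, assuming a constant
# frame rate. Real streams can have variable frame timing, so treat this
# as an approximation rather than TorchCodec's actual mechanism.
FPS = 30000 / 1001  # ~29.97 fps, consistent with the 0.0334 s durations

def expected_pts(index: int, fps: float) -> float:
    """Presentation time of frame `index` under a constant frame rate."""
    return index / fps

print(round(expected_pts(2, FPS), 4))    # 0.0667, matching the output above
print(round(expected_pts(100, FPS), 4))  # 3.3367
```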

# Indexing, with PTS and duration info
decoder.get_frame_at(len(decoder) - 1)
# Frame:
# data (shape): torch.Size([3, 400, 640])
# pts_seconds: 9.960000038146973
# duration_seconds: 0.03999999910593033
#### Clip sampling

decoder.get_frames_in_range(start=10, stop=30, step=5)
# FrameBatch:
# data (shape): torch.Size([4, 3, 400, 640])
# pts_seconds: tensor([0.4000, 0.6000, 0.8000, 1.0000])
# duration_seconds: tensor([0.0400, 0.0400, 0.0400, 0.0400])
```python

# Time-based indexing with PTS and duration info
decoder.get_frame_played_at(pts_seconds=2)
# Frame:
# data (shape): torch.Size([3, 400, 640])
# pts_seconds: 2.0
# duration_seconds: 0.03999999910593033
from torchcodec.samplers import clips_at_regular_timestamps

clips_at_regular_timestamps(
decoder,
seconds_between_clip_starts=1.5,
num_frames_per_clip=4,
seconds_between_frames=0.1
)
# FrameBatch:
# data (shape): torch.Size([9, 4, 3, 270, 480])
# pts_seconds: tensor([[ 0.0000, 0.0667, 0.1668, 0.2669],
# [ 1.4681, 1.5682, 1.6683, 1.7684],
# [ 2.9696, 3.0697, 3.1698, 3.2699],
# ... (truncated), dtype=torch.float64)
# duration_seconds: tensor([[0.0334, 0.0334, 0.0334, 0.0334],
# [0.0334, 0.0334, 0.0334, 0.0334],
# [0.0334, 0.0334, 0.0334, 0.0334],
# ... (truncated), dtype=torch.float64)
```
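To see where the 9 clips in the batch above come from: with `seconds_between_clip_starts=1.5`, clip starts are laid out every 1.5 s until the end of the video, and each requested start is then snapped to an actual frame (hence 1.4681 rather than exactly 1.5). A hedged sketch of the start-time layout only, assuming a roughly 13 s video (an assumed duration, for illustration):

```python
# Hedged sketch of the regular clip-start layout only: one start every
# `seconds_between_clip_starts` until the video ends. The real sampler
# then snaps each start to an actual frame pts (hence 1.4681, not 1.5).
def regular_clip_starts(video_duration_seconds: float,
                        seconds_between_clip_starts: float) -> list:
    starts = []
    t = 0.0
    while t < video_duration_seconds:
        starts.append(t)
        t += seconds_between_clip_starts
    return starts

starts = regular_clip_starts(13.0, 1.5)  # assumed ~13 s video
print(len(starts))  # 9 clips, matching the [9, 4, 3, 270, 480] batch shape
print(starts[:3])   # [0.0, 1.5, 3.0]
```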

You can use the following snippet to generate a video with FFmpeg and try out
@@ -142,7 +153,7 @@ format you want. Refer to Nvidia's GPU support matrix for more details
[official instructions](https://pytorch.org/get-started/locally/).

3. Install or compile FFmpeg with NVDEC support.
TorchCodec with CUDA should work with FFmpeg versions in [4, 7].
TorchCodec with CUDA should work with FFmpeg versions in [5, 7].

If FFmpeg is not already installed, or you need a more recent version, an
easy way to install it is to use `conda`:
@@ -172,16 +183,17 @@ format you want. Refer to Nvidia's GPU support matrix for more details
ffmpeg -hwaccel cuda -hwaccel_output_format cuda -i test/resources/nasa_13013.mp4 -f null -
```

4. Install TorchCodec by passing in an `--index-url` parameter that corresponds to your CUDA
Toolkit version, example:
4. Install TorchCodec by passing in an `--index-url` parameter that corresponds
to your CUDA Toolkit version, example:

```bash
# This corresponds to CUDA Toolkit version 12.4 and nightly Pytorch.
pip install torchcodec --index-url=https://download.pytorch.org/whl/nightly/cu124
# This corresponds to CUDA Toolkit version 12.4. It should be the same one
# you used when you installed PyTorch (If you installed PyTorch with pip).
pip install torchcodec --index-url=https://download.pytorch.org/whl/cu124
```

Note that without passing in the `--index-url` parameter, `pip` installs TorchCodec
binaries from PyPi which are CPU-only and do not have CUDA support.
Note that without passing in the `--index-url` parameter, `pip` installs
the CPU-only version of TorchCodec.

## Benchmark Results

26 changes: 13 additions & 13 deletions docs/source/index.rst
@@ -1,10 +1,10 @@
Welcome to the TorchCodec documentation!
========================================

TorchCodec is a Python library for decoding videos into PyTorch tensors. It aims
to be fast, easy to use, and well integrated into the PyTorch ecosystem. If you
want to use PyTorch to train ML models on videos, TorchCodec is how you turn
those videos into data.
TorchCodec is a Python library for decoding videos into PyTorch tensors, on CPU
and CUDA GPU. It aims to be fast, easy to use, and well integrated into the
PyTorch ecosystem. If you want to use PyTorch to train ML models on videos,
TorchCodec is how you turn those videos into data.

We achieve these capabilities through:

@@ -17,13 +17,6 @@ We achieve these capabilities through:
* Returning data as PyTorch tensors, ready to be fed into PyTorch transforms
or used directly to train models.

.. note::

TorchCodec is still in early development stage and we are actively seeking
feedback. If you have any suggestions or issues, please let us know by
`opening an issue <https://github.com/pytorch/torchcodec/issues/new/choose>`_
on our `GitHub repository <https://github.com/pytorch/torchcodec/>`_.

.. grid:: 3

.. grid-item-card:: :octicon:`file-code;1em`
@@ -48,16 +41,23 @@ We achieve these capabilities through:
:link: generated_examples/sampling.html
:link-type: url

How to sample video clips
How to sample regular and random clips from a video

.. grid-item-card:: :octicon:`file-code;1em`
GPU decoding using TorchCodec
GPU decoding
:img-top: _static/img/card-background.svg
:link: generated_examples/basic_cuda_example.html
:link-type: url

A simple example demonstrating CUDA GPU decoding

.. note::

TorchCodec is still in development stage and we are actively seeking
feedback. If you have any suggestions or issues, please let us know by
`opening an issue <https://github.com/pytorch/torchcodec/issues/new/choose>`_
on our `GitHub repository <https://github.com/pytorch/torchcodec/>`_.

.. toctree::
:maxdepth: 1
:caption: TorchCodec documentation
27 changes: 12 additions & 15 deletions examples/basic_example.py
@@ -120,22 +120,22 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
# their :term:`pts` (Presentation Time Stamp), and their duration.
# This can be achieved using the
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` and
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range` methods, which
# will return a :class:`~torchcodec.Frame` and
# :class:`~torchcodec.FrameBatch` objects respectively.
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` methods, which will
# return :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch`
# objects, respectively.

last_frame = decoder.get_frame_at(len(decoder) - 1)
print(f"{type(last_frame) = }")
print(last_frame)

# %%
middle_frames = decoder.get_frames_in_range(start=10, stop=20, step=2)
print(f"{type(middle_frames) = }")
print(middle_frames)
other_frames = decoder.get_frames_at([10, 0, 50])
print(f"{type(other_frames) = }")
print(other_frames)

# %%
plot(last_frame.data, "Last frame")
plot(middle_frames.data, "Middle frames")
plot(other_frames.data, "Other frames")

# %%
# Both :class:`~torchcodec.Frame` and
@@ -152,7 +152,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
# So far, we have retrieved frames based on their index. We can also retrieve
# frames based on *when* they are played with
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_played_at` and
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_played_in_range`, which
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_played_at`, which
# also returns :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch`
# respectively.

@@ -161,13 +161,10 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
print(frame_at_2_seconds)

# %%
first_two_seconds = decoder.get_frames_played_in_range(
start_seconds=0,
stop_seconds=2,
)
print(f"{type(first_two_seconds) = }")
print(first_two_seconds)
other_frames = decoder.get_frames_played_at(seconds=[10.1, 0.3, 5])
print(f"{type(other_frames) = }")
print(other_frames)

# %%
plot(frame_at_2_seconds.data, "Frame played at 2 seconds")
plot(first_two_seconds.data, "Frames played during [0, 2) seconds")
plot(other_frames.data, "Other frames")
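The index-based and time-based APIs are two views of the same timeline. As a rough mental model only (a simplification assuming a constant frame rate; the real decoder uses each frame's actual pts and duration from the stream), a played-at time maps to a frame index like this:

```python
# Simplified mental model: under a constant frame rate, the frame playing
# at time t is the one whose interval [i/fps, (i+1)/fps) contains t.
# The real VideoDecoder uses per-frame pts/duration from the stream instead.
def index_played_at(seconds: float, fps: float) -> int:
    return int(seconds * fps)

fps = 25.0  # assumed constant rate, for illustration
print(index_played_at(2.0, fps))   # 50
print(index_played_at(10.1, fps))  # 252
```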
