Migrate docstrings to doctests #901

Closed · wants to merge 1 commit
15 changes: 13 additions & 2 deletions torchdata/dataloader2/adapter.py
@@ -67,7 +67,6 @@ class Shuffle(Adapter):
dp = IterableWrapper(range(size)).shuffle()
dl = DataLoader2(dp, [Shuffle(False)])
assert list(range(size)) == list(dl)

"""

def __init__(self, enable=True):
@@ -86,7 +85,19 @@ class CacheTimeout(Adapter):
timeout: int - number of seconds parallel processes will wait for cached files to appear.

Example:
>>> dl = DataLoader2(dp, [CacheTimeout(600)])

.. testsetup::

from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.adapter import CacheTimeout

size = 12

.. testcode::

dp = IterableWrapper(range(size)).shuffle()
dl = DataLoader2(dp, [CacheTimeout(600)])
"""

def __init__(self, timeout=None):
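Note: the ``.. testsetup::`` and ``.. testcode::`` directives used throughout this PR come from Sphinx's ``sphinx.ext.doctest`` extension. The doctest builder executes testsetup blocks invisibly, renders and executes testcode blocks, compares any testoutput block against captured stdout, and fails the docs build on a mismatch or an exception. A minimal sketch of the wiring, assuming a conventional docs/source/conf.py layout (torchdata's actual docs config may differ):

    # docs/source/conf.py -- minimal sketch, not torchdata's actual config.
    # sphinx.ext.doctest provides testsetup/testcode/testoutput/testcleanup.
    extensions = [
        "sphinx.ext.doctest",
    ]

The blocks are then executed with the doctest builder, e.g. ``sphinx-build -b doctest docs/source build/doctest`` or ``make doctest``.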
58 changes: 46 additions & 12 deletions torchdata/datapipes/iter/load/fsspec.py
@@ -47,8 +47,16 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]):
e.g. host, port, username, password, etc.

Example:
>>> from torchdata.datapipes.iter import FSSpecFileLister
>>> datapipe = FSSpecFileLister(root=dir_path)

.. testsetup::

dir_path = "path"

.. testcode::

from torchdata.datapipes.iter import FSSpecFileLister

datapipe = FSSpecFileLister(root=dir_path)
"""

def __init__(
@@ -127,9 +135,17 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
e.g. host, port, username, password, etc.

Example:
>>> from torchdata.datapipes.iter import FSSpecFileLister
>>> datapipe = FSSpecFileLister(root=dir_path)
>>> file_dp = datapipe.open_files_by_fsspec()

.. testsetup::

dir_path = "path"

.. testcode::

from torchdata.datapipes.iter import FSSpecFileLister

datapipe = FSSpecFileLister(root=dir_path)
file_dp = datapipe.open_files_by_fsspec()
"""

def __init__(
@@ -169,13 +185,31 @@ class FSSpecSaverIterDataPipe(IterDataPipe[str]):


Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def filepath_fn(name: str) -> str:
>>> return dir_path + name
>>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
>>> source_dp = IterableWrapper(sorted(name_to_data.items()))
>>> fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
>>> res_file_paths = list(fsspec_saver_dp)

.. testsetup::

file_prefix = "file"

.. testcode::

from torchdata.datapipes.iter import IterableWrapper


def filepath_fn(name: str) -> str:
return file_prefix + name


name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
source_dp = IterableWrapper(sorted(name_to_data.items()))
fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
res_file_paths = list(fsspec_saver_dp)

.. testcleanup::

import os

for name in name_to_data.keys():
os.remove(file_prefix + name)
"""

def __init__(
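In the FSSpecSaver example above, the three directives run in a single namespace: ``.. testsetup::`` defines file_prefix, ``.. testcode::`` writes the files, and ``.. testcleanup::`` deletes them so repeated docs builds stay idempotent. A standalone sketch of the same flow, assuming torchdata with fsspec support is installed; the tempfile directory is an addition for safety, not part of this PR:

    # Standalone sketch of the setup / code / cleanup flow above.
    import os
    import tempfile

    from torchdata.datapipes.iter import IterableWrapper

    # Write under a throwaway directory instead of the bare "file" prefix.
    file_prefix = os.path.join(tempfile.mkdtemp(), "file")

    def filepath_fn(name: str) -> str:
        return file_prefix + name

    name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
    source_dp = IterableWrapper(sorted(name_to_data.items()))
    fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
    res_file_paths = list(fsspec_saver_dp)  # iterating triggers the writes

    for name in name_to_data:  # mirrors the .. testcleanup:: block
        os.remove(file_prefix + name)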
58 changes: 46 additions & 12 deletions torchdata/datapipes/iter/load/iopath.py
@@ -55,8 +55,16 @@ class IoPathFileListerIterDataPipe(IterDataPipe[str]):
S3 URL is supported only with ``iopath``>=0.1.9.

Example:
>>> from torchdata.datapipes.iter import IoPathFileLister
>>> datapipe = IoPathFileLister(root=S3URL)

.. testsetup::

s3_url = "path"

.. testcode::

from torchdata.datapipes.iter import IoPathFileLister

datapipe = IoPathFileLister(root=s3_url)
"""

def __init__(
@@ -113,9 +121,17 @@ class IoPathFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
S3 URL is supported only with `iopath`>=0.1.9.

Example:
>>> from torchdata.datapipes.iter import IoPathFileLister
>>> datapipe = IoPathFileLister(root=S3URL)
>>> file_dp = datapipe.open_files_by_iopath()

.. testsetup::

s3_url = "path"

.. testcode::

from torchdata.datapipes.iter import IoPathFileLister

datapipe = IoPathFileLister(root=s3_url)
file_dp = datapipe.open_files_by_iopath()
"""

def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", pathmgr=None) -> None:
@@ -161,13 +177,31 @@ class IoPathSaverIterDataPipe(IterDataPipe[str]):
S3 URL is supported only with `iopath`>=0.1.9.

Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def filepath_fn(name: str) -> str:
>>> return S3URL + name
>>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
>>> source_dp = IterableWrapper(sorted(name_to_data.items()))
>>> iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
>>> res_file_paths = list(iopath_saver_dp)

.. testsetup::

s3_url = "url"

.. testcode::

from torchdata.datapipes.iter import IterableWrapper


def filepath_fn(name: str) -> str:
return s3_url + name


name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
source_dp = IterableWrapper(sorted(name_to_data.items()))
iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
res_file_paths = list(iopath_saver_dp)

.. testcleanup::

import os

for file in ["1.txt", "1.txt.lock", "2.txt", "2.txt.lock", "3.txt", "3.txt.lock"]:
os.remove(s3_url + file)
"""

def __init__(
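The IoPathSaver cleanup above removes ``N.txt.lock`` files alongside the saved files, presumably because iopath can leave lock files behind when writing locally. A hypothetical, more defensive variant that globs instead of hard-coding the six names, so the build would not fail if a lock file is absent (not part of this PR):

    # Hypothetical defensive cleanup: remove whatever the example wrote,
    # including any leftover iopath lock files, without hard-coded names.
    import glob
    import os

    s3_url = "url"  # the same local prefix the testsetup assigns
    for path in glob.glob(s3_url + "*.txt") + glob.glob(s3_url + "*.txt.lock"):
        os.remove(path)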
93 changes: 61 additions & 32 deletions torchdata/datapipes/iter/load/online.py
@@ -54,18 +54,25 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
**kwargs: optional keyword arguments to pass to ``requests``. For the full list, check out https://docs.python-requests.org/en/master/api/

Example:
>>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
>>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
>>> query_params = {"auth" : ("fake_username", "fake_password"), "allow_redirects" : True}
>>> timeout = 120
>>> http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, query_params)
>>> reader_dp = http_reader_dp.readlines()
>>> it = iter(reader_dp)
>>> path, line = next(it)
>>> path
https://raw.githubusercontent.com/pytorch/data/main/LICENSE
>>> line
b'BSD 3-Clause License'

.. testcode::

from torchdata.datapipes.iter import IterableWrapper, HttpReader

file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
query_params = {"auth": ("fake_username", "fake_password"), "allow_redirects": True}
timeout = 120
http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, **query_params)
reader_dp = http_reader_dp.readlines()
it = iter(reader_dp)
path, line = next(it)
print((path, line))

Output:

.. testoutput::

('https://raw.githubusercontent.com/pytorch/data/main/LICENSE', b'BSD 3-Clause License')
"""

def __init__(
@@ -154,16 +161,31 @@ class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
**kwargs: optional keyword arguments to pass to ``requests``. For the full list, check out https://docs.python-requests.org/en/master/api/

Example:
>>> from torchdata.datapipes.iter import IterableWrapper, GDriveReader
>>> gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile"
>>> gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url]))
>>> reader_dp = gdrive_reader_dp.readlines()
>>> it = iter(reader_dp)
>>> path, line = next(it)
>>> path
https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile
>>> line
<First line from the GDrive File>

.. testsetup::

from torchdata.datapipes.iter import GDriveReader

GDriveReader.readlines = lambda self: [
("https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile", b"<First line from the GDrive File>")
]

.. testcode::

from torchdata.datapipes.iter import IterableWrapper, GDriveReader

gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile"
gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url]))
reader_dp = gdrive_reader_dp.readlines()
it = iter(reader_dp)
path, line = next(it)
print((path, line))

Output:

.. testoutput::

('https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile', b'<First line from the GDrive File>')
"""
source_datapipe: IterDataPipe[str]

@@ -207,16 +229,23 @@ class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
**kwargs: optional keyword arguments to pass to ``requests``. For the full list, check out https://docs.python-requests.org/en/master/api/

Example:
>>> from torchdata.datapipes.iter import IterableWrapper, OnlineReader
>>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
>>> online_reader_dp = OnlineReader(IterableWrapper([file_url]))
>>> reader_dp = online_reader_dp.readlines()
>>> it = iter(reader_dp)
>>> path, line = next(it)
>>> path
https://raw.githubusercontent.com/pytorch/data/main/LICENSE
>>> line
b'BSD 3-Clause License'

.. testcode::

from torchdata.datapipes.iter import IterableWrapper, OnlineReader

file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
online_reader_dp = OnlineReader(IterableWrapper([file_url]))
reader_dp = online_reader_dp.readlines()
it = iter(reader_dp)
path, line = next(it)
print((path, line))

Output:

.. testoutput::

('https://raw.githubusercontent.com/pytorch/data/main/LICENSE', b'BSD 3-Clause License')
"""
source_datapipe: IterDataPipe[str]

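Two things stand out in online.py: the HttpReader and OnlineReader examples fetch the LICENSE file over the network at docs-build time, while the GDriveReader example avoids the network by reassigning GDriveReader.readlines in its testsetup, with nothing restoring the original method afterwards. A reversible sketch of the same stub using unittest.mock, mirroring the pattern the s3io.py changes below use (an alternative, not what this PR does):

    # Sketch: stub GDriveReader.readlines reversibly with unittest.mock
    # so a .. testcleanup:: block could undo the patch.
    from unittest import mock

    from torchdata.datapipes.iter import GDriveReader

    readlines_patch = mock.patch.object(
        GDriveReader,
        "readlines",
        return_value=[
            ("https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile",
             b"<First line from the GDrive File>"),
        ],
    )
    readlines_patch.start()
    # ... the testcode block runs here ...
    readlines_patch.stop()  # restores the real readlines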
70 changes: 53 additions & 17 deletions torchdata/datapipes/iter/load/s3io.py
@@ -40,15 +40,33 @@ class S3FileListerIterDataPipe(IterDataPipe[str]):
region: region for accessing files (inferred from credentials by default)

Example:
>>> from torchdata.datapipes.iter import IterableWrapper, S3FileLister
>>> s3_prefixes = IterableWrapper(['s3://bucket-name/folder/', ...])
>>> dp_s3_urls = S3FileLister(s3_prefixes)
>>> for d in dp_s3_urls:
... pass

.. testsetup::

from unittest import mock
from torchdata.datapipes.iter import IterableWrapper, S3FileLister

file_lister_patch = mock.patch.object(S3FileLister, "__iter__", return_value=iter([]))
file_lister_patch.start()

.. testcode::

from torchdata.datapipes.iter import IterableWrapper, S3FileLister

s3_prefixes = IterableWrapper(['s3://bucket-name/folder/', ...])

dp_s3_urls = S3FileLister(s3_prefixes)
for d in dp_s3_urls:
pass

# Functional API
>>> dp_s3_urls = s3_prefixes.list_files_by_s3(request_timeout_ms=100)
>>> for d in dp_s3_urls:
... pass
dp_s3_urls = s3_prefixes.list_files_by_s3(request_timeout_ms=100)
for d in dp_s3_urls:
pass

.. testcleanup::

file_lister_patch.stop()
"""

def __init__(
@@ -108,20 +126,38 @@ class S3FileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
multi_part_download: flag to split each chunk into small packets and download those packets in parallel (enabled by default)

Example:
>>> from torchdata.datapipes.iter import IterableWrapper, S3FileLoader
>>> dp_s3_urls = IterableWrapper(['s3://bucket-name/folder/', ...]).list_files_by_s3()

.. testsetup::

from unittest import mock
from torchdata.datapipes.iter import S3FileLister

file_lister_patch = mock.patch.object(S3FileLister, "__iter__", return_value=iter([]))
file_lister_patch.start()

.. testcode::

from torchdata.datapipes.iter import IterableWrapper, S3FileLoader

dp_s3_urls = IterableWrapper(['s3://bucket-name/folder/', ...]).list_files_by_s3()
# In order to make sure data are shuffled and sharded in the
# distributed environment, `shuffle` and `sharding_filter`
# are required. For detail, please check our tutorial in:
# https://pytorch.org/data/main/tutorial.html#working-with-dataloader
>>> sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter()
>>> dp_s3_files = S3FileLoader(sharded_s3_urls)
>>> for url, fd in dp_s3_files: # Start loading data
... data = fd.read()
sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter()

dp_s3_files = S3FileLoader(sharded_s3_urls)
for url, fd in dp_s3_files: # Start loading data
data = fd.read()

# Functional API
>>> dp_s3_files = sharded_s3_urls.load_files_by_s3(buffer_size=256)
>>> for url, fd in dp_s3_files:
... data = fd.read()
dp_s3_files = sharded_s3_urls.load_files_by_s3(buffer_size=256)
for url, fd in dp_s3_files:
data = fd.read()

.. testcleanup::

file_lister_patch.stop()
"""

def __init__(
Expand Down