Migrate docstrings to doctests
Mathias Burger committed Mar 21, 2023
1 parent 1472157 commit 7ef8a04
Showing 6 changed files with 215 additions and 75 deletions.
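The pattern across all six files is the same: the old interactive ``>>>`` examples, which evidently were not being executed (note the ``query_params`` syntax error in the old HTTP example below), become ``sphinx.ext.doctest`` directives that run during the documentation build via ``sphinx-build -b doctest``. A minimal sketch of the shape, with invented values for illustration:

    .. testsetup::

        x = 41  # hidden setup; the rendered docs show only testcode/testoutput

    .. testcode::

        print(x + 1)

    .. testoutput::

        42

``testsetup`` and ``testcleanup`` are invisible in the rendered page, while ``testcode`` and ``testoutput`` are shown and checked against each other at build time.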
16 changes: 14 additions & 2 deletions torchdata/dataloader2/adapter.py
@@ -67,7 +67,6 @@ class Shuffle(Adapter):
dp = IterableWrapper(range(size)).shuffle()
dl = DataLoader2(dp, [Shuffle(False)])
assert list(range(size)) == list(dl)
"""

def __init__(self, enable=True):
@@ -86,7 +85,20 @@ class CacheTimeout(Adapter):
timeout: int - number of seconds parallel processes will wait for cached files to appear.
Example:
>>> dl = DataLoader2(dp, [CacheTimeout(600)])
.. testsetup::

    from torchdata.datapipes.iter import IterableWrapper
    from torchdata.dataloader2 import DataLoader2
    from torchdata.dataloader2.adapter import CacheTimeout

    size = 12

.. testcode::

    dp = IterableWrapper(range(size)).shuffle()
    dl = DataLoader2(dp, [CacheTimeout(600)])
"""

def __init__(self, timeout=None):
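Both adapters in this file plug into the same ``DataLoader2`` argument; a small combined sketch, grounded in the two docstring examples above (the datapipe itself is illustrative):

    from torchdata.dataloader2 import DataLoader2
    from torchdata.dataloader2.adapter import CacheTimeout, Shuffle
    from torchdata.datapipes.iter import IterableWrapper

    dp = IterableWrapper(range(12)).shuffle()
    # Shuffle(False) disables the shuffling; CacheTimeout(600) lets parallel
    # processes wait up to 600 seconds for cached files to appear.
    dl = DataLoader2(dp, [Shuffle(False), CacheTimeout(600)])
    assert list(dl) == list(range(12))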
59 changes: 47 additions & 12 deletions torchdata/datapipes/iter/load/fsspec.py
@@ -47,8 +47,16 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]):
e.g. host, port, username, password, etc.
Example:
>>> from torchdata.datapipes.iter import FSSpecFileLister
>>> datapipe = FSSpecFileLister(root=dir_path)
.. testsetup::

    dir_path = "path"

.. testcode::

    from torchdata.datapipes.iter import FSSpecFileLister

    datapipe = FSSpecFileLister(root=dir_path)
"""

def __init__(
@@ -127,9 +135,17 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
e.g. host, port, username, password, etc.
Example:
>>> from torchdata.datapipes.iter import FSSpecFileLister
>>> datapipe = FSSpecFileLister(root=dir_path)
>>> file_dp = datapipe.open_files_by_fsspec()
.. testsetup::

    dir_path = "path"

.. testcode::

    from torchdata.datapipes.iter import FSSpecFileLister

    datapipe = FSSpecFileLister(root=dir_path)
    file_dp = datapipe.open_files_by_fsspec()
"""

def __init__(
@@ -169,13 +185,32 @@ class FSSpecSaverIterDataPipe(IterDataPipe[str]):
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def filepath_fn(name: str) -> str:
>>> return dir_path + name
>>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
>>> source_dp = IterableWrapper(sorted(name_to_data.items()))
>>> fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
>>> res_file_paths = list(fsspec_saver_dp)
.. testsetup::

    file_prefix = "file"

.. testcode::

    from torchdata.datapipes.iter import IterableWrapper

    def filepath_fn(name: str) -> str:
        return file_prefix + name

    name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
    source_dp = IterableWrapper(sorted(name_to_data.items()))
    fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
    res_file_paths = list(fsspec_saver_dp)

.. testcleanup::

    import os

    os.remove(file_prefix + "1.txt")
    os.remove(file_prefix + "2.txt")
    os.remove(file_prefix + "3.txt")
"""

def __init__(
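For context, the saver example above round-trips with the opener from the same file; a minimal sketch, assuming a writable local prefix (``/tmp/demo_`` is invented for illustration):

    from torchdata.datapipes.iter import IterableWrapper

    file_prefix = "/tmp/demo_"  # assumed writable location
    source_dp = IterableWrapper(sorted({"1.txt": b"DATA1"}.items()))
    saver_dp = source_dp.save_by_fsspec(filepath_fn=lambda name: file_prefix + name, mode="wb")
    paths = list(saver_dp)  # materializing the pipe writes the files

    # Read the bytes back through fsspec.
    for path, stream in IterableWrapper(paths).open_files_by_fsspec(mode="rb"):
        assert stream.read() == b"DATA1"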
58 changes: 46 additions & 12 deletions torchdata/datapipes/iter/load/iopath.py
@@ -55,8 +55,16 @@ class IoPathFileListerIterDataPipe(IterDataPipe[str]):
S3 URL is supported only with ``iopath``>=0.1.9.
Example:
>>> from torchdata.datapipes.iter import IoPathFileLister
>>> datapipe = IoPathFileLister(root=S3URL)
.. testsetup::

    s3_url = "path"

.. testcode::

    from torchdata.datapipes.iter import IoPathFileLister

    datapipe = IoPathFileLister(root=s3_url)
"""

def __init__(
@@ -113,9 +121,17 @@ class IoPathFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
S3 URL is supported only with `iopath`>=0.1.9.
Example:
>>> from torchdata.datapipes.iter import IoPathFileLister
>>> datapipe = IoPathFileLister(root=S3URL)
>>> file_dp = datapipe.open_files_by_iopath()
.. testsetup::

    s3_url = "path"

.. testcode::

    from torchdata.datapipes.iter import IoPathFileLister

    datapipe = IoPathFileLister(root=s3_url)
    file_dp = datapipe.open_files_by_iopath()
"""

def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", pathmgr=None) -> None:
@@ -161,13 +177,31 @@ class IoPathSaverIterDataPipe(IterDataPipe[str]):
S3 URL is supported only with `iopath`>=0.1.9.
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> def filepath_fn(name: str) -> str:
>>> return S3URL + name
>>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
>>> source_dp = IterableWrapper(sorted(name_to_data.items()))
>>> iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
>>> res_file_paths = list(iopath_saver_dp)
.. testsetup::

    s3_url = "url"

.. testcode::

    from torchdata.datapipes.iter import IterableWrapper

    def filepath_fn(name: str) -> str:
        return s3_url + name

    name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
    source_dp = IterableWrapper(sorted(name_to_data.items()))
    iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
    res_file_paths = list(iopath_saver_dp)

.. testcleanup::

    import os

    for file in ["1.txt", "1.txt.lock", "2.txt", "2.txt.lock", "3.txt", "3.txt.lock"]:
        os.remove(s3_url + file)
"""

def __init__(
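The ``.lock`` entries in the cleanup above are left behind by ``iopath``'s file locking during writes, so any cleanup has to remove them alongside the data files. The same round trip works through the iopath pipes; a sketch, with a local prefix standing in for the S3 URL (invented for illustration):

    from torchdata.datapipes.iter import IterableWrapper

    prefix = "/tmp/iopath_demo_"  # assumed writable location
    saver_dp = IterableWrapper([("1.txt", b"DATA1")]).save_by_iopath(
        filepath_fn=lambda name: prefix + name, mode="wb"
    )
    paths = list(saver_dp)

    # open_files_by_iopath yields (path, StreamWrapper) tuples, like fsspec.
    for path, stream in IterableWrapper(paths).open_files_by_iopath(mode="rb"):
        assert stream.read() == b"DATA1"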
93 changes: 61 additions & 32 deletions torchdata/datapipes/iter/load/online.py
@@ -54,18 +54,25 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
**kwargs: a dictionary of optional arguments to pass to ``requests``. For the full list, see https://docs.python-requests.org/en/master/api/
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
>>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
>>> query_params = {"auth" : ("fake_username", "fake_password"), "allow_redirects" : True}
>>> timeout = 120
>>> http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, query_params)
>>> reader_dp = http_reader_dp.readlines()
>>> it = iter(reader_dp)
>>> path, line = next(it)
>>> path
https://raw.githubusercontent.com/pytorch/data/main/LICENSE
>>> line
b'BSD 3-Clause License'
.. testcode::

    from torchdata.datapipes.iter import IterableWrapper, HttpReader

    file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
    query_params = {"auth": ("fake_username", "fake_password"), "allow_redirects": True}
    timeout = 120
    http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, **query_params)
    reader_dp = http_reader_dp.readlines()
    it = iter(reader_dp)
    path, line = next(it)
    print((path, line))

Output:

.. testoutput::

    ('https://raw.githubusercontent.com/pytorch/data/main/LICENSE', b'BSD 3-Clause License')
"""

def __init__(
@@ -154,16 +161,31 @@ class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
**kwargs: a dictionary of optional arguments to pass to ``requests``. For the full list, see https://docs.python-requests.org/en/master/api/
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, GDriveReader
>>> gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile"
>>> gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url]))
>>> reader_dp = gdrive_reader_dp.readlines()
>>> it = iter(reader_dp)
>>> path, line = next(it)
>>> path
https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile
>>> line
<First line from the GDrive File>
.. testsetup::

    from torchdata.datapipes.iter import GDriveReader

    GDriveReader.readlines = lambda self: [
        ("https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile",
         b"<First line from the GDrive File>")
    ]

.. testcode::

    from torchdata.datapipes.iter import IterableWrapper, GDriveReader

    gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile"
    gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url]))
    reader_dp = gdrive_reader_dp.readlines()
    it = iter(reader_dp)
    path, line = next(it)
    print((path, line))

Output:

.. testoutput::

    ('https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile', b'<First line from the GDrive File>')
"""
source_datapipe: IterDataPipe[str]

@@ -207,16 +229,23 @@ class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
**kwargs: a dictionary of optional arguments to pass to ``requests``. For the full list, see https://docs.python-requests.org/en/master/api/
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, OnlineReader
>>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
>>> online_reader_dp = OnlineReader(IterableWrapper([file_url]))
>>> reader_dp = online_reader_dp.readlines()
>>> it = iter(reader_dp)
>>> path, line = next(it)
>>> path
https://raw.githubusercontent.com/pytorch/data/main/LICENSE
>>> line
b'BSD 3-Clause License'
.. testcode::

    from torchdata.datapipes.iter import IterableWrapper, OnlineReader

    file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
    online_reader_dp = OnlineReader(IterableWrapper([file_url]))
    reader_dp = online_reader_dp.readlines()
    it = iter(reader_dp)
    path, line = next(it)
    print((path, line))

Output:

.. testoutput::

    ('https://raw.githubusercontent.com/pytorch/data/main/LICENSE', b'BSD 3-Clause License')
"""
source_datapipe: IterDataPipe[str]

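One detail worth noting: the GDrive example stubs ``readlines`` in ``testsetup`` so the doctest never touches the network, while the HTTP examples fetch a stable public URL for real. The stubbing trick generalizes to any of these reader pipes; a sketch with an invented URL and payload:

    from torchdata.datapipes.iter import HttpReader, IterableWrapper

    # Assumed stub, mirroring the GDrive testsetup; restore it in testcleanup,
    # since the patch is global to the class.
    fake = [("https://example.com/LICENSE", b"BSD 3-Clause License")]
    HttpReader.readlines = lambda self: fake

    reader_dp = HttpReader(IterableWrapper(["https://example.com/LICENSE"])).readlines()
    print(next(iter(reader_dp)))  # ('https://example.com/LICENSE', b'BSD 3-Clause License')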
58 changes: 41 additions & 17 deletions torchdata/datapipes/iter/load/s3io.py
@@ -40,15 +40,27 @@ class S3FileListerIterDataPipe(IterDataPipe[str]):
region: region for access files (inferred from credentials by default)
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, S3FileLister
>>> s3_prefixes = IterableWrapper(['s3://bucket-name/folder/', ...])
>>> dp_s3_urls = S3FileLister(s3_prefixes)
>>> for d in dp_s3_urls:
...     pass
>>> dp_s3_urls = s3_prefixes.list_files_by_s3(request_timeout_ms=100)
>>> for d in dp_s3_urls:
...     pass

.. testsetup::

    from torchdata.datapipes.iter import IterableWrapper, S3FileLister

    S3FileLister.__iter__ = lambda self: iter([])

.. testcode::

    from torchdata.datapipes.iter import IterableWrapper, S3FileLister

    s3_prefixes = IterableWrapper(['s3://bucket-name/folder/', ...])
    dp_s3_urls = S3FileLister(s3_prefixes)
    for d in dp_s3_urls:
        pass

    # Functional API
    dp_s3_urls = s3_prefixes.list_files_by_s3(request_timeout_ms=100)
    for d in dp_s3_urls:
        pass
"""

def __init__(
@@ -108,20 +120,32 @@ class S3FileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
multi_part_download: flag to split each chunk into small packets and download those packets in parallel (enabled by default)
Example:
>>> from torchdata.datapipes.iter import IterableWrapper, S3FileLoader
>>> dp_s3_urls = IterableWrapper(['s3://bucket-name/folder/', ...]).list_files_by_s3()
>>> sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter()
>>> dp_s3_files = S3FileLoader(sharded_s3_urls)
>>> for url, fd in dp_s3_files:  # Start loading data
...     data = fd.read()
>>> dp_s3_files = sharded_s3_urls.load_files_by_s3(buffer_size=256)
>>> for url, fd in dp_s3_files:
...     data = fd.read()

.. testsetup::

    from torchdata.datapipes.iter import S3FileLoader

    S3FileLoader.__iter__ = lambda self: iter([])

.. testcode::

    from torchdata.datapipes.iter import IterableWrapper, S3FileLoader

    dp_s3_urls = IterableWrapper(['s3://bucket-name/folder/', ...]).list_files_by_s3()
    # In order to make sure data are shuffled and sharded in the
    # distributed environment, `shuffle` and `sharding_filter`
    # are required. For detail, please check our tutorial in:
    # https://pytorch.org/data/main/tutorial.html#working-with-dataloader
    sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter()
    dp_s3_files = S3FileLoader(sharded_s3_urls)
    for url, fd in dp_s3_files:  # Start loading data
        data = fd.read()

    # Functional API
    dp_s3_files = sharded_s3_urls.load_files_by_s3(buffer_size=256)
    for url, fd in dp_s3_files:
        data = fd.read()
"""

def __init__(
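The ``shuffle``/``sharding_filter`` comment in the loader docstring matters once the pipe is driven by a reading service; a rough sketch of the distributed setup it alludes to, assuming the placeholder bucket URL and that ``DistributedReadingService`` is available in this torchdata version:

    from torchdata.dataloader2 import DataLoader2, DistributedReadingService
    from torchdata.datapipes.iter import IterableWrapper, S3FileLoader

    dp = IterableWrapper(["s3://bucket-name/folder/"]).list_files_by_s3()
    # shuffle + sharding_filter give every rank a distinct, shuffled shard
    dp = S3FileLoader(dp.shuffle().sharding_filter())

    dl = DataLoader2(dp, reading_service=DistributedReadingService())
    for url, fd in dl:  # requires an initialized torch.distributed process group
        data = fd.read()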
