diff --git a/torchdata/dataloader2/adapter.py b/torchdata/dataloader2/adapter.py
index f155feb2b..8f7cb2800 100644
--- a/torchdata/dataloader2/adapter.py
+++ b/torchdata/dataloader2/adapter.py
@@ -67,7 +67,6 @@ class Shuffle(Adapter):
             dp = IterableWrapper(range(size)).shuffle()
             dl = DataLoader2(dp, [Shuffle(False)])
             assert list(range(size)) == list(dl)
-
     """
 
     def __init__(self, enable=True):
@@ -86,7 +85,19 @@ class CacheTimeout(Adapter):
         timeout: int - amount of seconds parallel processes will wait for cached files to appear.
 
     Example:
-        >>> dl = DataLoader2(dp, [CacheTimeout(600)])
+
+        .. testsetup::
+
+            from torchdata.datapipes.iter import IterableWrapper
+            from torchdata.dataloader2 import DataLoader2
+            from torchdata.dataloader2.adapter import CacheTimeout
+
+            size = 12
+
+        .. testcode::
+
+            dp = IterableWrapper(range(size)).shuffle()
+            dl = DataLoader2(dp, [CacheTimeout(600)])
     """
 
     def __init__(self, timeout=None):
diff --git a/torchdata/datapipes/iter/load/fsspec.py b/torchdata/datapipes/iter/load/fsspec.py
index d012d8215..39a875a94 100644
--- a/torchdata/datapipes/iter/load/fsspec.py
+++ b/torchdata/datapipes/iter/load/fsspec.py
@@ -47,8 +47,16 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]):
             e.g. host, port, username, password, etc.
 
     Example:
-        >>> from torchdata.datapipes.iter import FSSpecFileLister
-        >>> datapipe = FSSpecFileLister(root=dir_path)
+
+        .. testsetup::
+
+            dir_path = "path"
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import FSSpecFileLister
+
+            datapipe = FSSpecFileLister(root=dir_path)
     """
 
     def __init__(
@@ -127,9 +135,17 @@ class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
             e.g. host, port, username, password, etc.
 
     Example:
-        >>> from torchdata.datapipes.iter import FSSpecFileLister
-        >>> datapipe = FSSpecFileLister(root=dir_path)
-        >>> file_dp = datapipe.open_files_by_fsspec()
+
+        .. testsetup::
+
+            dir_path = "path"
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import FSSpecFileLister
+
+            datapipe = FSSpecFileLister(root=dir_path)
+            file_dp = datapipe.open_files_by_fsspec()
     """
 
     def __init__(
@@ -169,13 +185,31 @@ class FSSpecSaverIterDataPipe(IterDataPipe[str]):
 
 
     Example:
-        >>> from torchdata.datapipes.iter import IterableWrapper
-        >>> def filepath_fn(name: str) -> str:
-        >>>     return dir_path + name
-        >>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
-        >>> source_dp = IterableWrapper(sorted(name_to_data.items()))
-        >>> fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
-        >>> res_file_paths = list(fsspec_saver_dp)
+
+        .. testsetup::
+
+            file_prefix = "file"
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import IterableWrapper
+
+
+            def filepath_fn(name: str) -> str:
+                return file_prefix + name
+
+
+            name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
+            source_dp = IterableWrapper(sorted(name_to_data.items()))
+            fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb")
+            res_file_paths = list(fsspec_saver_dp)
+
+        .. testcleanup::
+
+            import os
+
+            for name in name_to_data.keys():
+                os.remove(file_prefix + name)
     """
 
     def __init__(
diff --git a/torchdata/datapipes/iter/load/iopath.py b/torchdata/datapipes/iter/load/iopath.py
index 3a654cf83..445b019a7 100644
--- a/torchdata/datapipes/iter/load/iopath.py
+++ b/torchdata/datapipes/iter/load/iopath.py
@@ -55,8 +55,16 @@ class IoPathFileListerIterDataPipe(IterDataPipe[str]):
         S3 URL is supported only with ``iopath``>=0.1.9.
 
     Example:
-        >>> from torchdata.datapipes.iter import IoPathFileLister
-        >>> datapipe = IoPathFileLister(root=S3URL)
+
+        .. testsetup::
+
+            s3_url = "path"
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import IoPathFileLister
+
+            datapipe = IoPathFileLister(root=s3_url)
     """
 
     def __init__(
@@ -113,9 +121,17 @@ class IoPathFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
         S3 URL is supported only with `iopath`>=0.1.9.
 
     Example:
-        >>> from torchdata.datapipes.iter import IoPathFileLister
-        >>> datapipe = IoPathFileLister(root=S3URL)
-        >>> file_dp = datapipe.open_files_by_iopath()
+
+        .. testsetup::
+
+            s3_url = "path"
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import IoPathFileLister
+
+            datapipe = IoPathFileLister(root=s3_url)
+            file_dp = datapipe.open_files_by_iopath()
     """
 
     def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", pathmgr=None) -> None:
@@ -161,13 +177,31 @@ class IoPathSaverIterDataPipe(IterDataPipe[str]):
         S3 URL is supported only with `iopath`>=0.1.9.
 
     Example:
-        >>> from torchdata.datapipes.iter import IterableWrapper
-        >>> def filepath_fn(name: str) -> str:
-        >>>     return S3URL + name
-        >>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
-        >>> source_dp = IterableWrapper(sorted(name_to_data.items()))
-        >>> iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
-        >>> res_file_paths = list(iopath_saver_dp)
+
+        .. testsetup::
+
+            s3_url = "url"
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import IterableWrapper
+
+
+            def filepath_fn(name: str) -> str:
+                return s3_url + name
+
+
+            name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
+            source_dp = IterableWrapper(sorted(name_to_data.items()))
+            iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
+            res_file_paths = list(iopath_saver_dp)
+
+        .. testcleanup::
+
+            import os
+
+            for file in ["1.txt", "1.txt.lock", "2.txt", "2.txt.lock", "3.txt", "3.txt.lock"]:
+                os.remove(s3_url + file)
     """
 
     def __init__(
diff --git a/torchdata/datapipes/iter/load/online.py b/torchdata/datapipes/iter/load/online.py
index dbbe8c916..2b34fafeb 100644
--- a/torchdata/datapipes/iter/load/online.py
+++ b/torchdata/datapipes/iter/load/online.py
@@ -54,18 +54,25 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
         **kwargs: a Dictionary to pass optional arguments that requests takes. For the full list
             check out https://docs.python-requests.org/en/master/api/
     Example:
-        >>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
-        >>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
-        >>> query_params = {"auth" : ("fake_username", "fake_password"), "allow_redirects" : True}
-        >>> timeout = 120
-        >>> http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, query_params)
-        >>> reader_dp = http_reader_dp.readlines()
-        >>> it = iter(reader_dp)
-        >>> path, line = next(it)
-        >>> path
-        https://raw.githubusercontent.com/pytorch/data/main/LICENSE
-        >>> line
-        b'BSD 3-Clause License'
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import IterableWrapper, HttpReader
+
+            file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
+            query_params = {"auth" : ("fake_username", "fake_password"), "allow_redirects" : True}
+            timeout = 120
+            http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, **query_params)
+            reader_dp = http_reader_dp.readlines()
+            it = iter(reader_dp)
+            path, line = next(it)
+            print((path, line))
+
+        Output:
+
+        .. testoutput::
+
+            ('https://raw.githubusercontent.com/pytorch/data/main/LICENSE', b'BSD 3-Clause License')
     """
 
     def __init__(
@@ -154,16 +161,31 @@ class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
         **kwargs: a Dictionary to pass optional arguments that requests takes. For the full list
            check out https://docs.python-requests.org/en/master/api/
     Example:
-        >>> from torchdata.datapipes.iter import IterableWrapper, GDriveReader
-        >>> gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile"
-        >>> gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url]))
-        >>> reader_dp = gdrive_reader_dp.readlines()
-        >>> it = iter(reader_dp)
-        >>> path, line = next(it)
-        >>> path
-        https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile
-        >>> line
-
+
+        .. testsetup::
+
+            from torchdata.datapipes.iter import GDriveReader
+
+            GDriveReader.readlines = lambda self: [
+                ("https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile", b"")
+            ]
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import IterableWrapper, GDriveReader
+
+            gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile"
+            gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url]))
+            reader_dp = gdrive_reader_dp.readlines()
+            it = iter(reader_dp)
+            path, line = next(it)
+            print((path, line))
+
+        Output:
+
+        .. testoutput::
+
+            ('https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile', b'')
     """
 
     source_datapipe: IterDataPipe[str]
@@ -207,16 +229,23 @@ class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
         **kwargs: a Dictionary to pass optional arguments that requests takes. For the full list
            check out https://docs.python-requests.org/en/master/api/
     Example:
-        >>> from torchdata.datapipes.iter import IterableWrapper, OnlineReader
-        >>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
-        >>> online_reader_dp = OnlineReader(IterableWrapper([file_url]))
-        >>> reader_dp = online_reader_dp.readlines()
-        >>> it = iter(reader_dp)
-        >>> path, line = next(it)
-        >>> path
-        https://raw.githubusercontent.com/pytorch/data/main/LICENSE
-        >>> line
-        b'BSD 3-Clause License'
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import IterableWrapper, OnlineReader
+
+            file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
+            online_reader_dp = OnlineReader(IterableWrapper([file_url]))
+            reader_dp = online_reader_dp.readlines()
+            it = iter(reader_dp)
+            path, line = next(it)
+            print((path, line))
+
+        Output:
+
+        .. testoutput::
+
+            ('https://raw.githubusercontent.com/pytorch/data/main/LICENSE', b'BSD 3-Clause License')
     """
 
     source_datapipe: IterDataPipe[str]
diff --git a/torchdata/datapipes/iter/load/s3io.py b/torchdata/datapipes/iter/load/s3io.py
index 3a17d43ed..24483dbfe 100644
--- a/torchdata/datapipes/iter/load/s3io.py
+++ b/torchdata/datapipes/iter/load/s3io.py
@@ -40,15 +40,33 @@ class S3FileListerIterDataPipe(IterDataPipe[str]):
         region: region for access files (inferred from credentials by default)
 
     Example:
-        >>> from torchdata.datapipes.iter import IterableWrapper, S3FileLister
-        >>> s3_prefixes = IterableWrapper(['s3://bucket-name/folder/', ...])
-        >>> dp_s3_urls = S3FileLister(s3_prefixes)
-        >>> for d in dp_s3_urls:
-        ...     pass
+
+        .. testsetup::
+
+            from unittest import mock
+            from torchdata.datapipes.iter import IterableWrapper, S3FileLister
+
+            file_lister_patch = mock.patch.object(S3FileLister, "__iter__", return_value=iter([]))
+            file_lister_patch.start()
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import IterableWrapper, S3FileLister
+
+            s3_prefixes = IterableWrapper(['s3://bucket-name/folder/', ...])
+
+            dp_s3_urls = S3FileLister(s3_prefixes)
+            for d in dp_s3_urls:
+                pass
+
             # Functional API
-        >>> dp_s3_urls = s3_prefixes.list_files_by_s3(request_timeout_ms=100)
-        >>> for d in dp_s3_urls:
-        ...     pass
+            dp_s3_urls = s3_prefixes.list_files_by_s3(request_timeout_ms=100)
+            for d in dp_s3_urls:
+                pass
+
+        .. testcleanup::
+
+            file_lister_patch.stop()
     """
 
     def __init__(
@@ -108,20 +126,38 @@ class S3FileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
         multi_part_download: flag to split each chunk into small packets and download
            those packets in parallel (enabled by default)
     Example:
-        >>> from torchdata.datapipes.iter import IterableWrapper, S3FileLoader
-        >>> dp_s3_urls = IterableWrapper(['s3://bucket-name/folder/', ...]).list_files_by_s3()
+
+        .. testsetup::
+
+            from unittest import mock
+            from torchdata.datapipes.iter import S3FileLister
+
+            file_lister_patch = mock.patch.object(S3FileLister, "__iter__", return_value=iter([]))
+            file_lister_patch.start()
+
+        .. testcode::
+
+            from torchdata.datapipes.iter import IterableWrapper, S3FileLoader
+
+            dp_s3_urls = IterableWrapper(['s3://bucket-name/folder/', ...]).list_files_by_s3()
             # In order to make sure data are shuffled and sharded in the
             # distributed environment, `shuffle` and `sharding_filter`
             # are required. For detail, please check our tutorial in:
             # https://pytorch.org/data/main/tutorial.html#working-with-dataloader
-        >>> sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter()
-        >>> dp_s3_files = S3FileLoader(sharded_s3_urls)
-        >>> for url, fd in dp_s3_files: # Start loading data
-        ...     data = fd.read()
+            sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter()
+
+            dp_s3_files = S3FileLoader(sharded_s3_urls)
+            for url, fd in dp_s3_files:  # Start loading data
+                data = fd.read()
+
             # Functional API
-        >>> dp_s3_files = sharded_s3_urls.load_files_by_s3(buffer_size=256)
-        >>> for url, fd in dp_s3_files:
-        ...     data = fd.read()
+            dp_s3_files = sharded_s3_urls.load_files_by_s3(buffer_size=256)
+            for url, fd in dp_s3_files:
+                data = fd.read()
+
+        .. testcleanup::
+
+            file_lister_patch.stop()
     """
 
     def __init__(
diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py
index 73a959b37..aa8e2df56 100644
--- a/torchdata/datapipes/iter/util/combining.py
+++ b/torchdata/datapipes/iter/util/combining.py
@@ -191,13 +191,19 @@ class MapKeyZipperIterDataPipe(IterDataPipe[T_co]):
 
             from torchdata.datapipes.iter import IterableWrapper
             from torchdata.datapipes.map import SequenceWrapper
 
+
             def merge_fn(tuple_from_iter, value_from_map):
                 return tuple_from_iter[0], tuple_from_iter[1] + value_from_map
+
+
             dp1 = IterableWrapper([('a', 1), ('b', 2), ('c', 3)])
             mapdp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
             res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), merge_fn=merge_fn)
+            print(list(res_dp))
 
+        Output:
+
         .. testoutput::
 
             [('a', 101), ('b', 202), ('c', 303)]
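
A note on the `testsetup`/`testcleanup` pairs in `s3io.py` above: they patch `S3FileLister.__iter__` with `unittest.mock` so the rendered examples can run without AWS credentials, then undo the patch afterwards. A minimal standalone sketch of the same pattern (the bucket URL is a placeholder, and this assumes a torchdata build that ships the S3 datapipes):

    from unittest import mock

    from torchdata.datapipes.iter import IterableWrapper, S3FileLister

    # Stub out the S3 round-trip: iterating the lister yields nothing
    # instead of contacting AWS. start()/stop() mirror the split between
    # the testsetup and testcleanup directives.
    file_lister_patch = mock.patch.object(S3FileLister, "__iter__", return_value=iter([]))
    file_lister_patch.start()

    dp = S3FileLister(IterableWrapper(["s3://bucket-name/folder/"]))
    print(list(dp))  # prints [] because the patched iterator is empty

    file_lister_patch.stop()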