Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mod: added option to recursively walk tree to find all files #1179

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 26 additions & 30 deletions torchdata/datapipes/iter/load/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,19 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]):
Args:
root: The root `fsspec` path directory or list of path directories to list files from
masks: Unix style filter string or string list for filtering file name(s)
recursive: If True, recursively traverse the directory tree. If False, list files only in the root directory.
kwargs: Extra options that make sense to a particular storage connection,
e.g. host, port, username, password, etc.

Example:

.. testsetup::

dir_path = "path"

.. testcode::

from torchdata.datapipes.iter import FSSpecFileLister

datapipe = FSSpecFileLister(root=dir_path)
>>> from torchdata.datapipes.iter import FSSpecFileLister
>>> datapipe = FSSpecFileLister(root=dir_path, recursive=True)
"""

def __init__(
self,
root: Union[str, Sequence[str], IterDataPipe],
recursive: bool = False,
masks: Union[str, List[str]] = "",
**kwargs,
) -> None:
Expand All @@ -77,6 +71,7 @@ def __init__(
self.datapipe = root
self.masks = masks
self.kwargs_for_connection = kwargs
self.recursive = recursive

def __iter__(self) -> Iterator[str]:
for root in self.datapipe:
Expand All @@ -92,33 +87,34 @@ def __iter__(self) -> Iterator[str]:
protocol_list.append("az")

is_local = fs.protocol == "file" or not any(root.startswith(protocol) for protocol in protocol_list)
if fs.isfile(path):
yield root

if self.recursive:
for current_path, dirs, files in fs.walk(path):
for file_name in files:

abs_path = os.path.join(current_path, file_name) if is_local else posixpath.join(current_path, file_name)
if not match_masks(abs_path, self.masks):
continue

if any(file_name.startswith(protocol) for protocol in protocol_list):
yield file_name
elif root.startswith(tuple(protocol_list)):
yield protocol_list[0] + "://" + abs_path
else:
yield abs_path
else:
for file_name in fs.ls(path, detail=False): # Ensure it returns List[str], not List[Dict]
if not match_masks(file_name, self.masks):
abs_path = os.path.join(path, file_name) if is_local else posixpath.join(path, file_name)

if not match_masks(abs_path, self.masks):
continue

# ensure the file name has the full fsspec protocol path
if any(file_name.startswith(protocol) for protocol in protocol_list):
yield file_name
elif root.startswith(tuple(protocol_list)):
yield protocol_list[0] + "://" + abs_path
else:
if is_local:
abs_path = os.path.join(path, file_name)
elif not file_name.startswith(path):
abs_path = posixpath.join(path, file_name)
else:
abs_path = file_name

starts_with = False
for protocol in protocol_list:
if root.startswith(protocol):
starts_with = True
yield protocol + "://" + abs_path
break

if not starts_with:
yield abs_path
yield abs_path


@functional_datapipe("open_files_by_fsspec")
Expand Down