diff --git a/torchdata/datapipes/iter/load/fsspec.py b/torchdata/datapipes/iter/load/fsspec.py index 39a875a94..8747b519b 100644 --- a/torchdata/datapipes/iter/load/fsspec.py +++ b/torchdata/datapipes/iter/load/fsspec.py @@ -43,25 +43,19 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]): Args: root: The root `fsspec` path directory or list of path directories to list files from masks: Unix style filter string or string list for filtering file name(s) + recursive: If True, recursively traverse the directory tree. If False, list files only in the root directory. kwargs: Extra options that make sense to a particular storage connection, e.g. host, port, username, password, etc. Example: - - .. testsetup:: - - dir_path = "path" - - .. testcode:: - - from torchdata.datapipes.iter import FSSpecFileLister - - datapipe = FSSpecFileLister(root=dir_path) + >>> from torchdata.datapipes.iter import FSSpecFileLister + >>> datapipe = FSSpecFileLister(root=dir_path, recursive=True) """ def __init__( self, root: Union[str, Sequence[str], IterDataPipe], + recursive: bool = False, masks: Union[str, List[str]] = "", **kwargs, ) -> None: @@ -77,6 +71,7 @@ def __init__( self.datapipe = root self.masks = masks self.kwargs_for_connection = kwargs + self.recursive = recursive def __iter__(self) -> Iterator[str]: for root in self.datapipe: @@ -92,33 +87,34 @@ def __iter__(self) -> Iterator[str]: protocol_list.append("az") is_local = fs.protocol == "file" or not any(root.startswith(protocol) for protocol in protocol_list) - if fs.isfile(path): - yield root + + if self.recursive: + for current_path, dirs, files in fs.walk(path): + for file_name in files: + + abs_path = os.path.join(current_path, file_name) if is_local else posixpath.join(current_path, file_name) + if not match_masks(abs_path, self.masks): + continue + + if any(file_name.startswith(protocol) for protocol in protocol_list): + yield file_name + elif root.startswith(tuple(protocol_list)): + yield protocol_list[0] + "://" + abs_path + else: + yield abs_path else: for file_name in fs.ls(path, detail=False): # Ensure it returns List[str], not List[Dict] - if not match_masks(file_name, self.masks): + abs_path = os.path.join(path, file_name) if is_local else posixpath.join(path, file_name) + + if not match_masks(abs_path, self.masks): continue - # ensure the file name has the full fsspec protocol path if any(file_name.startswith(protocol) for protocol in protocol_list): yield file_name + elif root.startswith(tuple(protocol_list)): + yield protocol_list[0] + "://" + abs_path else: - if is_local: - abs_path = os.path.join(path, file_name) - elif not file_name.startswith(path): - abs_path = posixpath.join(path, file_name) - else: - abs_path = file_name - - starts_with = False - for protocol in protocol_list: - if root.startswith(protocol): - starts_with = True - yield protocol + "://" + abs_path - break - - if not starts_with: - yield abs_path + yield abs_path @functional_datapipe("open_files_by_fsspec")