Skip to content

Commit

Permalink
pull _is_file checks to get_listing (#846)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Jan 24, 2025
1 parent 74ea452 commit 669f359
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/datachain/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def dataset_name(self) -> str:
if self.type == DatasetDependencyType.DATASET:
return self.name

list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"), None, {})
list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"), {})
assert list_dataset_name
return list_dataset_name

Expand Down
19 changes: 8 additions & 11 deletions src/datachain/lib/listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,12 @@ def _isfile(client: "Client", path: str) -> bool:
return False


def parse_listing_uri(uri: str, cache, client_config) -> tuple[Optional[str], str, str]:
def parse_listing_uri(uri: str, client_config) -> tuple[str, str, str]:
"""
Parsing uri and returns listing dataset name, listing uri and listing path
"""
client_config = client_config or {}
client = Client.get_client(uri, cache, **client_config)
storage_uri, path = Client.parse_url(uri)
telemetry.log_param("client", client.PREFIX)

if not uri.endswith("/") and _isfile(client, uri):
return None, f"{storage_uri}/{path.lstrip('/')}", path
if uses_glob(path):
lst_uri_path = posixpath.dirname(path)
else:
Expand Down Expand Up @@ -157,13 +152,15 @@ def get_listing(
client_config = catalog.client_config

client = Client.get_client(uri, cache, **client_config)
ds_name, list_uri, list_path = parse_listing_uri(uri, cache, client_config)
listing = None
telemetry.log_param("client", client.PREFIX)

# if we don't want to use cached dataset (e.g. for a single file listing)
if not ds_name:
return None, list_uri, list_path, False
# we don't want to use cached dataset (e.g. for a single file listing)
if not uri.endswith("/") and _isfile(client, uri):
storage_uri, path = Client.parse_url(uri)
return None, f"{storage_uri}/{path.lstrip('/')}", path, False

ds_name, list_uri, list_path = parse_listing_uri(uri, client_config)
listing = None
listings = [
ls for ls in catalog.listings() if not ls.is_expired and ls.contains(ds_name)
]
Expand Down
4 changes: 1 addition & 3 deletions tests/func/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@


def listing_stats(uri, catalog):
list_dataset_name, _, _ = parse_listing_uri(
uri, catalog.cache, catalog.client_config
)
list_dataset_name, _, _ = parse_listing_uri(uri, catalog.client_config)
dataset = catalog.get_dataset(list_dataset_name)
return catalog.dataset_stats(dataset.name, dataset.latest_version)

Expand Down
6 changes: 3 additions & 3 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_from_storage_reindex_expired(tmp_dir, test_session):
os.mkdir(tmp_dir)
uri = tmp_dir.as_uri()

lst_ds_name = parse_listing_uri(uri, catalog.cache, catalog.client_config)[0]
lst_ds_name = parse_listing_uri(uri, catalog.client_config)[0]

pd.DataFrame({"name": ["Alice", "Bob"]}).to_parquet(tmp_dir / "test1.parquet")
assert DataChain.from_storage(uri, session=test_session).count() == 1
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_from_storage_partials(cloud_test_catalog):
catalog = session.catalog

def _list_dataset_name(uri: str) -> str:
name = parse_listing_uri(uri, catalog.cache, catalog.client_config)[0]
name = parse_listing_uri(uri, catalog.client_config)[0]
assert name
return name

Expand Down Expand Up @@ -182,7 +182,7 @@ def test_from_storage_partials_with_update(cloud_test_catalog):
catalog = session.catalog

def _list_dataset_name(uri: str) -> str:
name = parse_listing_uri(uri, catalog.cache, catalog.client_config)[0]
name = parse_listing_uri(uri, catalog.client_config)[0]
assert name
return name

Expand Down
2 changes: 1 addition & 1 deletion tests/func/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def test_dataset_storage_dependencies(cloud_test_catalog, cloud_type, indirect):
ds_name = "some_ds"
DataChain.from_storage(uri, session=session).save(ds_name)

lst_ds_name, _, _ = parse_listing_uri(uri, catalog.cache, catalog.client_config)
lst_ds_name, _, _ = parse_listing_uri(uri, catalog.client_config)
lst_dataset = catalog.metastore.get_dataset(lst_ds_name)

assert [
Expand Down
4 changes: 2 additions & 2 deletions tests/func/test_listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_parse_listing_uri(cloud_test_catalog, cloud_type):
ctc = cloud_test_catalog
catalog = ctc.catalog
dataset_name, listing_uri, listing_path = parse_listing_uri(
f"{ctc.src_uri}/dogs", catalog.cache, catalog.client_config
f"{ctc.src_uri}/dogs", catalog.client_config
)
assert dataset_name == f"lst__{ctc.src_uri}/dogs/"
assert listing_uri == f"{ctc.src_uri}/dogs/"
Expand All @@ -57,7 +57,7 @@ def test_parse_listing_uri_with_glob(cloud_test_catalog):
ctc = cloud_test_catalog
catalog = ctc.catalog
dataset_name, listing_uri, listing_path = parse_listing_uri(
f"{ctc.src_uri}/dogs/*", catalog.cache, catalog.client_config
f"{ctc.src_uri}/dogs/*", catalog.client_config
)
assert dataset_name == f"lst__{ctc.src_uri}/dogs/"
assert listing_uri == f"{ctc.src_uri}/dogs"
Expand Down

0 comments on commit 669f359

Please sign in to comment.