From b4dd749bdc6bc37d95b92c6cf0036349bdb28fb4 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 11 Sep 2024 19:14:23 +1000 Subject: [PATCH] use the correct fixtures in tests --- tests/examples/test_wds_e2e.py | 18 ++++++++++-------- tests/func/test_datachain.py | 4 ++-- tests/test_cli_e2e.py | 2 +- tests/test_query_e2e.py | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/examples/test_wds_e2e.py b/tests/examples/test_wds_e2e.py index c4178ad82..84d0fe8a1 100644 --- a/tests/examples/test_wds_e2e.py +++ b/tests/examples/test_wds_e2e.py @@ -69,10 +69,10 @@ def webdataset_metadata(tmp_path): return metadata_path -def test_wds(catalog, webdataset_tars): - res = DataChain.from_storage(Path(webdataset_tars).as_uri()).gen( - laion=process_webdataset(spec=WDSLaion), params="file" - ) +def test_wds(test_session, webdataset_tars): + res = DataChain.from_storage( + Path(webdataset_tars).as_uri(), session=test_session + ).gen(laion=process_webdataset(spec=WDSLaion), params="file") num_rows = 0 for laion_wds in res.collect("laion"): @@ -95,10 +95,12 @@ def test_wds(catalog, webdataset_tars): assert num_rows == len(WDS_TAR_SHARDS) -def test_wds_merge_with_parquet_meta(catalog, webdataset_tars, webdataset_metadata): - wds = DataChain.from_storage(Path(webdataset_tars).as_uri()).gen( - laion=process_webdataset(spec=WDSLaion), params="file" - ) +def test_wds_merge_with_parquet_meta( + test_session, webdataset_tars, webdataset_metadata +): + wds = DataChain.from_storage( + Path(webdataset_tars).as_uri(), session=test_session + ).gen(laion=process_webdataset(spec=WDSLaion), params="file") meta = DataChain.from_parquet(Path(webdataset_metadata).as_uri()) diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index 87176458b..cc1c2a5a5 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -453,8 +453,8 @@ def test_from_storage_check_rows(tmp_dir, test_session): ) -def test_mutate_existing_column(catalog): - ds = DataChain.from_values(ids=[1, 2, 3]) +def test_mutate_existing_column(test_session): + ds = DataChain.from_values(ids=[1, 2, 3], session=test_session) with pytest.raises(DataChainColumnError) as excinfo: ds.mutate(ids=Column("ids") + 1) diff --git a/tests/test_cli_e2e.py b/tests/test_cli_e2e.py index 35f4845e1..442b6458c 100644 --- a/tests/test_cli_e2e.py +++ b/tests/test_cli_e2e.py @@ -213,7 +213,7 @@ def run_step(step): @pytest.mark.e2e -def test_cli_e2e(tmp_dir, catalog): +def test_cli_e2e(tmp_dir, catalog_tmpfile): """End-to-end CLI Test""" for step in E2E_STEPS: run_step(step) diff --git a/tests/test_query_e2e.py b/tests/test_query_e2e.py index 3baa81f10..cdee7d376 100644 --- a/tests/test_query_e2e.py +++ b/tests/test_query_e2e.py @@ -243,7 +243,7 @@ def run_step(step): # noqa: PLR0912 @pytest.mark.e2e -def test_query_e2e(tmp_dir, catalog): +def test_query_e2e(tmp_dir, catalog_tmpfile): """End-to-end CLI Query Test""" for step in E2E_STEPS: run_step(step)