From 2402d4f257a289e51ece2e736b4a6dacb2dba5dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Sat, 7 Sep 2024 12:51:47 +0545 Subject: [PATCH 1/3] avoid setting script_output and query_script in the dataset and dataset_version --- src/datachain/catalog/catalog.py | 14 -------------- tests/func/test_catalog.py | 1 - tests/func/test_datasets.py | 2 -- tests/func/test_query.py | 4 ---- 4 files changed, 21 deletions(-) diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index d525efa2a..9af3713ce 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -1886,20 +1886,6 @@ def _get_dataset_versions_by_job_id(): "No dataset found after running Query script", output=output, ) from e - - dr = self.update_dataset( - dr, - script_output=output, - query_script=query_script, - ) - self.update_dataset_version_with_warehouse_info( - dr, - dv.version, - script_output=output, - query_script=query_script, - job_id=job_id, - is_job_result=True, - ) return QueryResult(dataset=dr, version=dv.version, output=output) def run_query( diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py index c47b87668..634804b0b 100644 --- a/tests/func/test_catalog.py +++ b/tests/func/test_catalog.py @@ -927,7 +927,6 @@ def test_query(cloud_test_catalog, mock_popen_dataset_created): "dog4", }, ) - assert result.dataset.query_script == query_script assert result.dataset.sources == "" diff --git a/tests/func/test_datasets.py b/tests/func/test_datasets.py index a5d9bf567..cf0304eae 100644 --- a/tests/func/test_datasets.py +++ b/tests/func/test_datasets.py @@ -57,7 +57,6 @@ def test_create_dataset_no_version_specified(cloud_test_catalog, create_rows): dataset_version = dataset.get_version(1) assert dataset.name == name - assert dataset.query_script == "script" assert dataset_version.query_script == "script" assert dataset.schema["similarity"] == Float32 assert dataset_version.schema["similarity"] == Float32 @@ -87,7 +86,6 @@ def test_create_dataset_with_explicit_version(cloud_test_catalog, create_rows): dataset_version = dataset.get_version(1) assert dataset.name == name - assert dataset.query_script == "script" assert dataset_version.query_script == "script" assert dataset.schema["similarity"] == Float32 assert dataset_version.schema["similarity"] == Float32 diff --git a/tests/func/test_query.py b/tests/func/test_query.py index 34d593055..4b90f0358 100644 --- a/tests/func/test_query.py +++ b/tests/func/test_query.py @@ -189,7 +189,6 @@ def test_query( assert result.version == 1 assert result.dataset.versions_values == [1] - assert result.dataset.query_script == query_script assert_row_names( catalog, result.dataset, @@ -255,7 +254,6 @@ def test_query_where_last_command_is_call_on_save_which_returns_attached_dataset result = catalog.query(query_script, save=True) assert not result.dataset.name.startswith(QUERY_DATASET_PREFIX) - assert result.dataset.query_script == query_script assert result.version == 1 assert result.dataset.versions_values == [1] assert_row_names( @@ -293,7 +291,6 @@ def test_query_where_last_command_is_attached_dataset_query_created_from_save( result = catalog.query(query_script, save=True) assert result.dataset.name == "dogs" - assert result.dataset.query_script == query_script assert result.version == 1 assert result.dataset.versions_values == [1] assert_row_names( @@ -331,7 +328,6 @@ def 
test_query_where_last_command_is_attached_dataset_query_created_from_query( result = catalog.query(query_script, save=True) assert result.dataset.name == "dogs" - assert result.dataset.query_script == query_script assert result.version == 1 assert result.dataset.versions_values == [1] assert_row_names( From ec06ad1b3d9ff7049631f1999ec49f3802c62498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Sat, 7 Sep 2024 14:05:57 +0545 Subject: [PATCH 2/3] avoid returning latest dataset, let the caller do the work --- src/datachain/catalog/catalog.py | 152 ++++++++----------------------- tests/func/test_catalog.py | 51 ++++------- tests/func/test_query.py | 10 +- 3 files changed, 61 insertions(+), 152 deletions(-) diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 9af3713ce..833c93c19 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -12,7 +12,6 @@ import time import traceback from collections.abc import Iterable, Iterator, Mapping, Sequence -from contextlib import contextmanager, nullcontext from copy import copy from dataclasses import dataclass from functools import cached_property, reduce @@ -58,7 +57,6 @@ PendingIndexingError, QueryScriptCancelError, QueryScriptCompileError, - QueryScriptDatasetNotFound, QueryScriptRunError, ) from datachain.listing import Listing @@ -115,38 +113,19 @@ def noop(_: str): pass -@contextmanager -def print_and_capture( - stream: "IO[bytes]|IO[str]", callback: Callable[[str], None] = noop -) -> "Iterator[list[str]]": - lines: list[str] = [] - append = lines.append +def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None: + buffer = b"" + while byt := stream.read(1): # Read one byte at a time + buffer += byt - def loop() -> None: - buffer = b"" - while byt := stream.read(1): # Read one byte at a time - buffer += byt.encode("utf-8") if isinstance(byt, str) else byt - - if byt in (b"\n", b"\r"): # Check for newline or carriage return - line = buffer.decode("utf-8") - print(line, end="") - callback(line) - append(line) - buffer = b"" # Clear buffer for next line - - if buffer: # Handle any remaining data in the buffer + if byt in (b"\n", b"\r"): # Check for newline or carriage return line = buffer.decode("utf-8") - print(line, end="") callback(line) - append(line) - - thread = Thread(target=loop, daemon=True) - thread.start() + buffer = b"" # Clear buffer for next line - try: - yield lines - finally: - thread.join() + if buffer: # Handle any remaining data in the buffer + line = buffer.decode("utf-8") + callback(line) class QueryResult(NamedTuple): @@ -651,11 +630,6 @@ def attach_query_wrapper(self, code_ast): code_ast.body[-1:] = new_expressions return code_ast - def compile_query_script(self, script: str) -> str: - code_ast = ast.parse(script) - code_ast = self.attach_query_wrapper(code_ast) - return ast.unparse(code_ast) - def parse_url(self, uri: str, **config: Any) -> tuple[Client, str]: config = config or self.client_config return Client.parse_url(uri, self.cache, **config) @@ -1806,13 +1780,13 @@ def query( self, query_script: str, envs: Optional[Mapping[str, str]] = None, - python_executable: Optional[str] = None, + python_executable: str = sys.executable, save: bool = False, capture_output: bool = True, output_hook: Callable[[str], None] = noop, params: Optional[dict[str, str]] = None, job_id: Optional[str] = None, - ) -> QueryResult: + ) -> None: """ Method to run custom user 
Python script to run a query and, as result, creates new dataset from the results of a query. @@ -1835,76 +1809,15 @@ def query( C.size > 1000 ) """ - if not job_id: - python_version = f"{sys.version_info.major}.{sys.version_info.minor}" - job_id = self.metastore.create_job( - name="", - query=query_script, - params=params, - python_version=python_version, - ) - - lines, proc = self.run_query( - python_executable or sys.executable, - query_script, - envs, - capture_output, - output_hook, - params, - save, - job_id, - ) - output = "".join(lines) - - if proc.returncode: - if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE: - raise QueryScriptCancelError( - "Query script was canceled by user", - return_code=proc.returncode, - output=output, - ) - raise QueryScriptRunError( - f"Query script exited with error code {proc.returncode}", - return_code=proc.returncode, - output=output, - ) - - def _get_dataset_versions_by_job_id(): - for dr, dv, job in self.list_datasets_versions(): - if job and str(job.id) == job_id: - yield dr, dv - try: - dr, dv = max( - _get_dataset_versions_by_job_id(), key=lambda x: x[1].created_at - ) - except ValueError as e: - if not save: - return QueryResult(dataset=None, version=None, output=output) - - raise QueryScriptDatasetNotFound( - "No dataset found after running Query script", - output=output, - ) from e - return QueryResult(dataset=dr, version=dv.version, output=output) - - def run_query( - self, - python_executable: str, - query_script: str, - envs: Optional[Mapping[str, str]], - capture_output: bool, - output_hook: Callable[[str], None], - params: Optional[dict[str, str]], - save: bool, - job_id: Optional[str], - ) -> tuple[list[str], subprocess.Popen]: - try: - query_script_compiled = self.compile_query_script(query_script) + code_ast = ast.parse(query_script) + code_ast = self.attach_query_wrapper(code_ast) + query_script_compiled = ast.unparse(code_ast) except Exception as exc: raise QueryScriptCompileError( f"Query script failed to compile, reason: {exc}" ) from exc + envs = dict(envs or os.environ) envs.update( { @@ -1915,19 +1828,34 @@ def run_query( "DATACHAIN_JOB_ID": job_id or "", }, ) - with subprocess.Popen( # noqa: S603 + popen_kwargs = {} + if capture_output: + popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT} + + with subprocess.Popen( # type: ignore[call-overload] # noqa: S603 [python_executable, "-c", query_script_compiled], env=envs, - stdout=subprocess.PIPE if capture_output else None, - stderr=subprocess.STDOUT if capture_output else None, - bufsize=1, - text=False, + **popen_kwargs, ) as proc: - out = proc.stdout - _lines: list[str] = [] - ctx = print_and_capture(out, output_hook) if out else nullcontext(_lines) - with ctx as lines: - return lines, proc + if capture_output: + thread = Thread( + target=_process_stream, + daemon=True, + args=(proc.stdout, output_hook), + ) + thread.start() + thread.join() + + if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE: + raise QueryScriptCancelError( + "Query script was canceled by user", + return_code=proc.returncode, + ) + if proc.returncode: + raise QueryScriptRunError( + f"Query script exited with error code {proc.returncode}", + return_code=proc.returncode, + ) def cp( self, diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py index 634804b0b..5db667b4e 100644 --- a/tests/func/test_catalog.py +++ b/tests/func/test_catalog.py @@ -12,7 +12,6 @@ from datachain.cli import garbage_collect from datachain.error import ( QueryScriptCompileError, - 
QueryScriptDatasetNotFound, QueryScriptRunError, StorageNotFoundError, ) @@ -59,7 +58,7 @@ def mock_popen_dataset_created( mocker, monkeypatch, mock_popen, cloud_test_catalog, listed_bucket ): # create dataset which would be created in subprocess - ds_name = cloud_test_catalog.catalog.generate_query_dataset_name() + ds_name = "my-ds" job_id = cloud_test_catalog.catalog.metastore.create_job(name="", query="") mocker.patch.object( cloud_test_catalog.catalog.metastore, "create_job", return_value=job_id @@ -910,24 +909,16 @@ def test_query(cloud_test_catalog, mock_popen_dataset_created): query_script = f"""\ from datachain.query import C, DatasetQuery - DatasetQuery({src_uri!r}) + DatasetQuery({src_uri!r}).save("my-ds") """ query_script = dedent(query_script) - result = catalog.query(query_script, save=True) - assert result.dataset - assert_row_names( - catalog, - result.dataset, - result.version, - { - "dog1", - "dog2", - "dog3", - "dog4", - }, - ) - assert result.dataset.sources == "" + catalog.query(query_script) + + dataset = catalog.get_dataset("my-ds") + assert dataset + assert dataset.versions_values == [1] + assert_row_names(catalog, dataset, 1, {"dog1", "dog2", "dog3", "dog4"}) def test_query_save_size(cloud_test_catalog, mock_popen_dataset_created): @@ -936,12 +927,17 @@ def test_query_save_size(cloud_test_catalog, mock_popen_dataset_created): query_script = f"""\ from datachain.query import C, DatasetQuery - DatasetQuery({src_uri!r}) + DatasetQuery({src_uri!r}).save("my-ds") """ query_script = dedent(query_script) - result = catalog.query(query_script, save=True) - dataset_version = result.dataset.get_version(result.version) + catalog.query(query_script) + + dataset = catalog.get_dataset("my-ds") + assert dataset + assert dataset.versions_values == [1] + + dataset_version = dataset.get_version(1) assert dataset_version.num_objects == 4 assert dataset_version.size == 15 @@ -970,21 +966,6 @@ def test_query_subprocess_wrong_return_code(mock_popen, cloud_test_catalog): assert str(exc_info.value).startswith("Query script exited with error code 1") -def test_query_dataset_not_returned(mock_popen, cloud_test_catalog): - mock_popen.configure_mock(stdout=io.StringIO("random str")) - catalog = cloud_test_catalog.catalog - src_uri = cloud_test_catalog.src_uri - - query_script = f""" -from datachain.query import DatasetQuery, C -DatasetQuery('{src_uri}') - """ - - with pytest.raises(QueryScriptDatasetNotFound) as e: - catalog.query(query_script, save=True) - assert e.value.output == "random str" - - @pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True) def test_storage_stats(cloud_test_catalog): catalog = cloud_test_catalog.catalog diff --git a/tests/func/test_query.py b/tests/func/test_query.py index 4b90f0358..9a768a3e6 100644 --- a/tests/func/test_query.py +++ b/tests/func/test_query.py @@ -136,16 +136,16 @@ def test_query_cli_without_dataset_query_as_a_last_statement( query_script = f"""\ from datachain.query import DatasetQuery - DatasetQuery({src_uri!r}, catalog=catalog).save("temp") + DatasetQuery({src_uri!r}, catalog=catalog).save("my-ds") print("test") """ query_script = setup_catalog(query_script, catalog_info_filepath) - result = catalog.query(query_script) - assert result.dataset - assert result.dataset.name == "temp" - assert result.version == 1 + catalog.query(query_script) + dataset = catalog.get_dataset("my-ds") + assert dataset + assert dataset.versions_values == [1] out, err = capsys.readouterr() assert "test" in out From 
01fcceb39ea5df03723761dc6d171b4fbb6ad4f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Mon, 9 Sep 2024 10:17:48 +0545 Subject: [PATCH 3/3] disable wrapping last statement by default --- src/datachain/catalog/catalog.py | 50 +++---- src/datachain/error.py | 4 - tests/func/test_catalog.py | 2 +- tests/func/test_query.py | 246 +++---------------------------- 4 files changed, 45 insertions(+), 257 deletions(-) diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 833c93c19..f0eda2359 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -22,7 +22,6 @@ TYPE_CHECKING, Any, Callable, - NamedTuple, NoReturn, Optional, Union, @@ -128,12 +127,6 @@ def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> Non callback(line) -class QueryResult(NamedTuple): - dataset: Optional[DatasetRecord] - version: Optional[int] - output: str - - class DatasetRowsFetcher(NodesThreadPool): def __init__( self, @@ -1779,13 +1772,14 @@ def apply_udf( def query( self, query_script: str, - envs: Optional[Mapping[str, str]] = None, + env: Optional[Mapping[str, str]] = None, python_executable: str = sys.executable, save: bool = False, capture_output: bool = True, output_hook: Callable[[str], None] = noop, params: Optional[dict[str, str]] = None, job_id: Optional[str] = None, + _execute_last_expression: bool = False, ) -> None: """ Method to run custom user Python script to run a query and, as result, @@ -1809,17 +1803,21 @@ def query( C.size > 1000 ) """ - try: - code_ast = ast.parse(query_script) - code_ast = self.attach_query_wrapper(code_ast) - query_script_compiled = ast.unparse(code_ast) - except Exception as exc: - raise QueryScriptCompileError( - f"Query script failed to compile, reason: {exc}" - ) from exc + if _execute_last_expression: + try: + code_ast = ast.parse(query_script) + code_ast = self.attach_query_wrapper(code_ast) + query_script_compiled = ast.unparse(code_ast) + except Exception as exc: + raise QueryScriptCompileError( + f"Query script failed to compile, reason: {exc}" + ) from exc + else: + query_script_compiled = query_script + assert not save - envs = dict(envs or os.environ) - envs.update( + env = dict(env or os.environ) + env.update( { "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}), "PYTHONPATH": os.getcwd(), # For local imports @@ -1832,19 +1830,13 @@ def query( if capture_output: popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT} - with subprocess.Popen( # type: ignore[call-overload] # noqa: S603 - [python_executable, "-c", query_script_compiled], - env=envs, - **popen_kwargs, - ) as proc: + cmd = [python_executable, "-c", query_script_compiled] + with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # type: ignore[call-overload] # noqa: S603 if capture_output: - thread = Thread( - target=_process_stream, - daemon=True, - args=(proc.stdout, output_hook), - ) + args = (proc.stdout, output_hook) + thread = Thread(target=_process_stream, args=args, daemon=True) thread.start() - thread.join() + thread.join() # wait for the reader thread if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE: raise QueryScriptCancelError( diff --git a/src/datachain/error.py b/src/datachain/error.py index ffc5d653d..4a7096c3a 100644 --- a/src/datachain/error.py +++ b/src/datachain/error.py @@ -42,10 +42,6 @@ def __init__(self, message: str, return_code: int = 0, output: str = ""): super().__init__(self.message) 
-class QueryScriptDatasetNotFound(QueryScriptRunError): # noqa: N818 - pass - - class QueryScriptCancelError(QueryScriptRunError): pass diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py index 5db667b4e..0373917d6 100644 --- a/tests/func/test_catalog.py +++ b/tests/func/test_catalog.py @@ -948,7 +948,7 @@ def test_query_fail_to_compile(cloud_test_catalog): query_script = "syntax error" with pytest.raises(QueryScriptCompileError): - catalog.query(query_script) + catalog.query(query_script, _execute_last_expression=True) def test_query_subprocess_wrong_return_code(mock_popen, cloud_test_catalog): diff --git a/tests/func/test_query.py b/tests/func/test_query.py index 9a768a3e6..4f618afe3 100644 --- a/tests/func/test_query.py +++ b/tests/func/test_query.py @@ -1,16 +1,17 @@ -import json import os.path from textwrap import dedent -from typing import Optional +from typing import TYPE_CHECKING import dill import pytest -from datachain.catalog import QUERY_DATASET_PREFIX from datachain.cli import query from datachain.data_storage import AbstractDBMetastore, JobQueryType, JobStatus from tests.utils import assert_row_names +if TYPE_CHECKING: + from datachain.job import Job + @pytest.fixture def catalog_info_filepath(cloud_test_catalog_tmpfile, tmp_path): @@ -67,26 +68,11 @@ def setup_catalog(query: str, catalog_info_filepath: str) -> str: def get_latest_job( metastore: AbstractDBMetastore, -) -> Optional[tuple[str, str, int, int, str, str]]: +) -> "Job": j = metastore._jobs - - latest_jobs_query = ( - metastore._jobs_select( - j.c.id, - j.c.name, - j.c.status, - j.c.query_type, - j.c.error_message, - j.c.error_stack, - j.c.metrics, - ) - .order_by(j.c.created_at.desc()) - .limit(1) - ) - latest_jobs = list(metastore.db.execute(latest_jobs_query)) - if len(latest_jobs) == 0: - return None - return latest_jobs[0] + query = metastore._jobs_select().order_by(j.c.created_at.desc()).limit(1) + (row,) = metastore.db.execute(query) + return metastore._parse_job(row) @pytest.mark.parametrize("cloud_type,version_aware", [("file", False)], indirect=True) @@ -119,80 +105,31 @@ def test_query_cli(cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath, latest_job = get_latest_job(catalog.metastore) assert latest_job - assert str(latest_job[0]) == str(result_job_id) - assert latest_job[1] == os.path.basename(filepath) - assert latest_job[2] == JobStatus.COMPLETE - assert latest_job[3] == JobQueryType.PYTHON - assert latest_job[4] == "" - assert latest_job[5] == "" + assert str(latest_job.id) == str(result_job_id) + assert latest_job.name == os.path.basename(filepath) + assert latest_job.status == JobStatus.COMPLETE + assert latest_job.query_type == JobQueryType.PYTHON + assert latest_job.error_message == "" + assert latest_job.error_stack == "" -def test_query_cli_without_dataset_query_as_a_last_statement( - cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath, capsys -): +def test_query(cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath): catalog = cloud_test_catalog_tmpfile.catalog src_uri = cloud_test_catalog_tmpfile.src_uri query_script = f"""\ from datachain.query import DatasetQuery - DatasetQuery({src_uri!r}, catalog=catalog).save("my-ds") - - print("test") """ query_script = setup_catalog(query_script, catalog_info_filepath) - catalog.query(query_script) + dataset = catalog.get_dataset("my-ds") - assert dataset assert dataset.versions_values == [1] - - out, err = capsys.readouterr() - assert "test" in out - assert not err - - -@pytest.mark.parametrize( - "save", - 
(True, False), -) -@pytest.mark.parametrize("save_dataset", (None, "new-dataset")) -def test_query( - save, - save_dataset, - cloud_test_catalog_tmpfile, - tmp_path, - catalog_info_filepath, -): - catalog = cloud_test_catalog_tmpfile.catalog - src_uri = cloud_test_catalog_tmpfile.src_uri - - query_script = f"""\ - from datachain.query import DatasetQuery - ds = DatasetQuery({src_uri!r}, catalog=catalog) - if {save_dataset!r}: - ds = ds.save({save_dataset!r}) - ds - """ - query_script = setup_catalog(query_script, catalog_info_filepath) - - result = catalog.query(query_script, save=save) - if save_dataset: - assert result.dataset.name == save_dataset - assert catalog.get_dataset(save_dataset) - elif save: - assert result.dataset.name.startswith(QUERY_DATASET_PREFIX) - else: - assert result.dataset is None - assert result.version is None - return - - assert result.version == 1 - assert result.dataset.versions_values == [1] assert_row_names( catalog, - result.dataset, - result.version, + dataset, + 1, { "cat1", "cat2", @@ -205,146 +142,8 @@ def test_query( ) -@pytest.mark.parametrize( - "params,count", - ( - (None, 7), - ({"limit": 1}, 1), - ({"limit": 5}, 5), - ), -) -def test_query_params( - params, count, cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath -): - catalog = cloud_test_catalog_tmpfile.catalog - src_uri = cloud_test_catalog_tmpfile.src_uri - - query_script = f"""\ - from datachain.query import DatasetQuery, param - - ds = DatasetQuery({src_uri!r}, catalog=catalog) - if param("limit"): - ds = ds.limit(int(param("limit"))) - ds - """ - query_script = setup_catalog(query_script, catalog_info_filepath) - - result = catalog.query(query_script, save=True, params=params) - assert ( - len(list(catalog.ls_dataset_rows(result.dataset.name, result.version))) == count - ) - - -def test_query_where_last_command_is_call_on_save_which_returns_attached_dataset( - cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath -): - """ - Testing use case where last command is call on DatasetQuery save which returns - attached instance to underlying saved dataset - """ - catalog = cloud_test_catalog_tmpfile.catalog - src_uri = cloud_test_catalog_tmpfile.src_uri - - query_script = f"""\ - from datachain.query import C, DatasetQuery - - DatasetQuery({src_uri!r}, catalog=catalog).filter(C.path.glob("*dog*")).save("dogs") - """ - query_script = setup_catalog(query_script, catalog_info_filepath) - - result = catalog.query(query_script, save=True) - assert not result.dataset.name.startswith(QUERY_DATASET_PREFIX) - assert result.version == 1 - assert result.dataset.versions_values == [1] - assert_row_names( - catalog, - result.dataset, - result.version, - { - "dog1", - "dog2", - "dog3", - "dog4", - }, - ) - - -def test_query_where_last_command_is_attached_dataset_query_created_from_save( - cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath -): - """ - Testing use case where last command is instance of DatasetQuery which is - attached to underlying dataset by calling save just before - """ - catalog = cloud_test_catalog_tmpfile.catalog - src_uri = cloud_test_catalog_tmpfile.src_uri - - query_script = f"""\ - from datachain.query import C, DatasetQuery - - ds = DatasetQuery( - {src_uri!r}, catalog=catalog - ).filter(C.path.glob("*dog*")).save("dogs") - ds - """ - query_script = setup_catalog(query_script, catalog_info_filepath) - - result = catalog.query(query_script, save=True) - assert result.dataset.name == "dogs" - assert result.version == 1 - assert result.dataset.versions_values == 
[1] - assert_row_names( - catalog, - result.dataset, - result.version, - { - "dog1", - "dog2", - "dog3", - "dog4", - }, - ) - - -def test_query_where_last_command_is_attached_dataset_query_created_from_query( - cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath -): - """ - Testing use case where last command is instance of DatasetQuery which is - attached to underlying dataset by creating query pointing to it - """ - catalog = cloud_test_catalog_tmpfile.catalog - src_uri = cloud_test_catalog_tmpfile.src_uri - - query_script = f"""\ - from datachain.query import C, DatasetQuery - - ds = DatasetQuery( - {src_uri!r}, catalog=catalog - ).filter(C.path.glob("*dog*")).save("dogs") - DatasetQuery(name="dogs", version=1, catalog=catalog) - """ - query_script = setup_catalog(query_script, catalog_info_filepath) - - result = catalog.query(query_script, save=True) - assert result.dataset.name == "dogs" - assert result.version == 1 - assert result.dataset.versions_values == [1] - assert_row_names( - catalog, - result.dataset, - result.version, - { - "dog1", - "dog2", - "dog3", - "dog4", - }, - ) - - @pytest.mark.parametrize("cloud_type,version_aware", [("file", False)], indirect=True) -def test_query_params_metrics( +def test_cli_query_params_metrics( cloud_test_catalog_tmpfile, tmp_path, catalog_info_filepath, capsys ): catalog = cloud_test_catalog_tmpfile.catalog @@ -357,7 +156,7 @@ def test_query_params_metrics( metrics.set("count", ds.count()) - ds + ds.save("my-ds") """ query_script = setup_catalog(query_script, catalog_info_filepath) @@ -369,5 +168,6 @@ def test_query_params_metrics( latest_job = get_latest_job(catalog.metastore) assert latest_job - assert latest_job[2] == JobStatus.COMPLETE - assert json.loads(latest_job[6]) == {"count": 7} + assert latest_job.status == JobStatus.COMPLETE + assert latest_job.params == {"url": src_uri} + assert latest_job.metrics == {"count": 7}
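
Usage note for this series: after PATCH 3, `catalog.query()` no longer wraps the script's
last expression and no longer returns a `QueryResult` — it runs the script in a subprocess,
returns `None`, and raises `QueryScriptRunError` (or `QueryScriptCancelError`) on failure.
A query script must therefore save its dataset explicitly, and the caller looks the result
up afterwards. A minimal sketch of the new calling convention, mirroring the updated tests;
it assumes a configured `catalog` (as constructed in the tests above), and the bucket URI
"s3://my-bucket" and dataset name "my-ds" are placeholders:

    from textwrap import dedent

    # The script saves its result under an explicit name instead of relying on
    # the (now disabled) last-expression wrapping.
    query_script = dedent(
        """\
        from datachain.query import C, DatasetQuery

        DatasetQuery("s3://my-bucket").filter(C.path.glob("*dog*")).save("my-ds")
        """
    )

    # Runs the script in a subprocess; returns None, raises on non-zero exit.
    catalog.query(query_script)

    # The caller fetches the saved dataset itself.
    dataset = catalog.get_dataset("my-ds")
    assert dataset.versions_values == [1]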