disable execution of last query expression by default #407

Merged: 3 commits, Sep 11, 2024
192 changes: 49 additions & 143 deletions src/datachain/catalog/catalog.py
@@ -12,7 +12,6 @@
 import time
 import traceback
 from collections.abc import Iterable, Iterator, Mapping, Sequence
-from contextlib import contextmanager, nullcontext
 from copy import copy
 from dataclasses import dataclass
 from functools import cached_property, reduce
@@ -23,7 +22,6 @@
     TYPE_CHECKING,
     Any,
     Callable,
-    NamedTuple,
     NoReturn,
     Optional,
     Union,
@@ -58,7 +56,6 @@
     PendingIndexingError,
     QueryScriptCancelError,
     QueryScriptCompileError,
-    QueryScriptDatasetNotFound,
     QueryScriptRunError,
 )
 from datachain.listing import Listing
@@ -115,44 +112,19 @@
     pass


-@contextmanager
-def print_and_capture(
-    stream: "IO[bytes]|IO[str]", callback: Callable[[str], None] = noop
-) -> "Iterator[list[str]]":
-    lines: list[str] = []
-    append = lines.append
-
-    def loop() -> None:
-        buffer = b""
-        while byt := stream.read(1):  # Read one byte at a time
-            buffer += byt.encode("utf-8") if isinstance(byt, str) else byt
-
-            if byt in (b"\n", b"\r"):  # Check for newline or carriage return
-                line = buffer.decode("utf-8")
-                print(line, end="")
-                callback(line)
-                append(line)
-                buffer = b""  # Clear buffer for next line
-
-        if buffer:  # Handle any remaining data in the buffer
-            line = buffer.decode("utf-8")
-            print(line, end="")
-            callback(line)
-            append(line)
-
-    thread = Thread(target=loop, daemon=True)
-    thread.start()
-
-    try:
-        yield lines
-    finally:
-        thread.join()
-
-
-class QueryResult(NamedTuple):
-    dataset: Optional[DatasetRecord]
-    version: Optional[int]
-    output: str
+def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None:
+    buffer = b""
+    while byt := stream.read(1):  # Read one byte at a time
+        buffer += byt
+
+        if byt in (b"\n", b"\r"):  # Check for newline or carriage return
+            line = buffer.decode("utf-8")
+            callback(line)
+            buffer = b""  # Clear buffer for next line
+
+    if buffer:  # Handle any remaining data in the buffer
+        line = buffer.decode("utf-8")
+        callback(line)


 class DatasetRowsFetcher(NodesThreadPool):

Comment (mattseddon, Member) on `+        buffer += byt`:
    Seeing some errors coming through the test suite of the form:
    TypeError: can't concat str to bytes.
    Examples are here:
    https://github.com/iterative/datachain/actions/runs/10821775005/job/30024466321?pr=427#step:7:177

Reply (skshetry, Member, Author):
    Fixed tests in #431.

Comment (Contributor) on the removed `buffer += byt.encode("utf-8") if isinstance(byt, str) else byt`:
    We may still need this to fix the issue @mattseddon mentioned above.

Reply (skshetry, Member, Author):
    Fixed in #431.

Codecov / codecov/patch annotation on src/datachain/catalog/catalog.py#L126-L127:
    Added lines #L126 - L127 were not covered by tests.

Comment (skshetry, Member, Author):
    We'll no longer print to the stdout when capture_output=True.
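
An aside on the thread above: the errors appear to come from tests that mock
the child's stdout with a text-mode stream (for example,
mock_popen.configure_mock(stdout=io.StringIO(...)) in tests/func/test_catalog.py),
whose read(1) returns str, while _process_stream builds a bytes buffer. A
minimal, hypothetical repro, with the guard from the removed
print_and_capture shown for contrast:

    import io

    buffer = b""
    byt = io.StringIO("hi\n").read(1)  # "h" is a str, not b"h"
    try:
        buffer += byt  # type: ignore[operator]
    except TypeError as exc:
        print(exc)  # can't concat str to bytes

    # The removed guard accepted both stream flavors:
    buffer += byt.encode("utf-8") if isinstance(byt, str) else byt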
@@ -651,11 +623,6 @@
         code_ast.body[-1:] = new_expressions
         return code_ast

-    def compile_query_script(self, script: str) -> str:
-        code_ast = ast.parse(script)
-        code_ast = self.attach_query_wrapper(code_ast)
-        return ast.unparse(code_ast)
-
     def parse_url(self, uri: str, **config: Any) -> tuple[Client, str]:
         config = config or self.client_config
         return Client.parse_url(uri, self.cache, **config)
@@ -1805,14 +1772,15 @@
     def query(
         self,
         query_script: str,
-        envs: Optional[Mapping[str, str]] = None,
-        python_executable: Optional[str] = None,
+        env: Optional[Mapping[str, str]] = None,
+        python_executable: str = sys.executable,
         save: bool = False,
         capture_output: bool = True,
         output_hook: Callable[[str], None] = noop,
         params: Optional[dict[str, str]] = None,
         job_id: Optional[str] = None,
-    ) -> QueryResult:
+        _execute_last_expression: bool = False,
+    ) -> None:
         """
         Method to run custom user Python script to run a query and, as result,
         creates new dataset from the results of a query.
@@ -1835,92 +1803,21 @@
             C.size > 1000
         )
         """

Comment (skshetry, Member, Author) on `+        env: Optional[Mapping[str, str]] = None,`:
    Renamed to env. I don't think it should be plural.

Comment (skshetry, Member, Author) on `+        python_executable: str = sys.executable,`:
    Changed python_executable to default to sys.executable and not take a
    None value.

Comment (skshetry, Member, Author, Sep 9, 2024) on `+        _execute_last_expression: bool = False,`:
    I plan to get rid of this when we update Studio. (Well, I plan to release
    Studio with _execute_last_expression=True set first, then break
    compatibility in the next release.)
-        if not job_id:
-            python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
-            job_id = self.metastore.create_job(
-                name="",
-                query=query_script,
-                params=params,
-                python_version=python_version,
-            )
-
-        lines, proc = self.run_query(
-            python_executable or sys.executable,
-            query_script,
-            envs,
-            capture_output,
-            output_hook,
-            params,
-            save,
-            job_id,
-        )
-        output = "".join(lines)
-
-        if proc.returncode:
-            if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
-                raise QueryScriptCancelError(
-                    "Query script was canceled by user",
-                    return_code=proc.returncode,
-                    output=output,
-                )
-            raise QueryScriptRunError(
-                f"Query script exited with error code {proc.returncode}",
-                return_code=proc.returncode,
-                output=output,
-            )
-
-        def _get_dataset_versions_by_job_id():
-            for dr, dv, job in self.list_datasets_versions():
-                if job and str(job.id) == job_id:
-                    yield dr, dv
-
-        try:
-            dr, dv = max(
-                _get_dataset_versions_by_job_id(), key=lambda x: x[1].created_at
-            )
-        except ValueError as e:
-            if not save:
-                return QueryResult(dataset=None, version=None, output=output)
-
-            raise QueryScriptDatasetNotFound(
-                "No dataset found after running Query script",
-                output=output,
-            ) from e
-
-        dr = self.update_dataset(
-            dr,
-            script_output=output,
-            query_script=query_script,
-        )
-        self.update_dataset_version_with_warehouse_info(
-            dr,
-            dv.version,
-            script_output=output,
-            query_script=query_script,
-            job_id=job_id,
-            is_job_result=True,
-        )
-
-        return QueryResult(dataset=dr, version=dv.version, output=output)

Comment (skshetry, Member, Author) on removed lines 1872 to 1888 (the
dataset-version lookup and QueryScriptDatasetNotFound raise):
    This will have to be done on the caller side. And eventually removed when
    we drop _execute_last_expression support.

Comment (skshetry, Member, Author) on removed lines 1890 to 1902 (the
update_dataset / update_dataset_version_with_warehouse_info calls):
    This was not being used anywhere. Job replaces this.
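
Following the comment above, the dataset lookup moves to the caller. A rough
sketch of what that caller-side code might look like (hypothetical helper
names; the removed _get_dataset_versions_by_job_id is the model):

    from typing import Iterator, Optional

    def dataset_versions_for_job(catalog, job_id: str) -> Iterator[tuple]:
        # Same filter the removed helper applied.
        for dr, dv, job in catalog.list_datasets_versions():
            if job and str(job.id) == job_id:
                yield dr, dv

    def latest_dataset_for_job(catalog, job_id: str) -> Optional[tuple]:
        # Newest version wins, as in the removed max(..., created_at) call.
        versions = list(dataset_versions_for_job(catalog, job_id))
        return max(versions, key=lambda x: x[1].created_at) if versions else None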
+        if _execute_last_expression:
+            try:
+                code_ast = ast.parse(query_script)
+                code_ast = self.attach_query_wrapper(code_ast)
+                query_script_compiled = ast.unparse(code_ast)
+            except Exception as exc:
+                raise QueryScriptCompileError(
+                    f"Query script failed to compile, reason: {exc}"
+                ) from exc
+        else:
+            query_script_compiled = query_script
+            assert not save

Codecov / codecov/patch annotation on src/datachain/catalog/catalog.py#L1809-L1810:
    Added lines #L1809 - L1810 were not covered by tests.

Comment (Contributor) on `+            assert not save`:
    Reminder to remove this flag as well.

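For illustration, this is the kind of rewrite attach_query_wrapper performs:
the script's trailing bare expression is handed to a wrapper call instead of
being discarded. A minimal sketch; _wrapper is a stand-in name, not the
wrapper DataChain actually injects:

    import ast

    def attach_wrapper(code_ast: ast.Module, wrapper: str = "_wrapper") -> ast.Module:
        last = code_ast.body[-1]
        if isinstance(last, ast.Expr):  # script ends in a bare expression
            call = ast.Call(
                func=ast.Name(id=wrapper, ctx=ast.Load()),
                args=[last.value],
                keywords=[],
            )
            code_ast.body[-1] = ast.Expr(value=call)
        return ast.fix_missing_locations(code_ast)

    script = "from datachain.query import DatasetQuery\nDatasetQuery('s3://bucket')"
    print(ast.unparse(attach_wrapper(ast.parse(script))))
    # last line becomes: _wrapper(DatasetQuery('s3://bucket'))
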
-    def run_query(
-        self,
-        python_executable: str,
-        query_script: str,
-        envs: Optional[Mapping[str, str]],
-        capture_output: bool,
-        output_hook: Callable[[str], None],
-        params: Optional[dict[str, str]],
-        save: bool,
-        job_id: Optional[str],
-    ) -> tuple[list[str], subprocess.Popen]:
-        try:
-            query_script_compiled = self.compile_query_script(query_script)
-        except Exception as exc:
-            raise QueryScriptCompileError(
-                f"Query script failed to compile, reason: {exc}"
-            ) from exc
-        envs = dict(envs or os.environ)
-        envs.update(
+        env = dict(env or os.environ)
+        env.update(
             {
                 "DATACHAIN_QUERY_PARAMS": json.dumps(params or {}),
                 "PYTHONPATH": os.getcwd(),  # For local imports
@@ -1929,19 +1826,28 @@
                 "DATACHAIN_JOB_ID": job_id or "",
             },
         )

Comment (dreadatour, Contributor, Sep 10, 2024) on `+        env = dict(env or os.environ)`:
    Hm, not sure why `or`? Maybe something like this?

    Suggested change:
    -        env = dict(env or os.environ)
    +        env = {**os.environ, **(env or {})}

Reply (skshetry, Member, Author):
    How would you provide a way to override envvars of the current process?

    This is how subprocess.Popen works, and given this is a thin wrapper
    around it, I think it's better to mimic its API.

    Also, Studio already provides a copy of all envvars.
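
To make the trade-off above concrete: Popen's env argument replaces the child
environment wholesale, so a wrapper that mimics it lets callers drop or
override inherited variables, not just add to them. A small sketch of the two
merge styles discussed (illustrative values only):

    import os
    import subprocess
    import sys

    # Style kept in this PR (mirrors Popen): the caller owns the whole
    # mapping, so inherited variables can be overridden or omitted entirely.
    env = dict(os.environ)
    env["DATACHAIN_QUERY_PARAMS"] = "{}"

    # Suggested alternative: always inherit, then overlay the caller's vars.
    merged = {**os.environ, **{"DATACHAIN_QUERY_PARAMS": "{}"}}

    code = "import os; print(os.environ['DATACHAIN_QUERY_PARAMS'])"
    subprocess.run([sys.executable, "-c", code], env=env, check=True)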
-        with subprocess.Popen(  # noqa: S603
-            [python_executable, "-c", query_script_compiled],
-            env=envs,
-            stdout=subprocess.PIPE if capture_output else None,
-            stderr=subprocess.STDOUT if capture_output else None,
-            bufsize=1,
-            text=False,
-        ) as proc:
-            out = proc.stdout
-            _lines: list[str] = []
-            ctx = print_and_capture(out, output_hook) if out else nullcontext(_lines)
-            with ctx as lines:
-                return lines, proc
+        popen_kwargs = {}
+        if capture_output:
+            popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}
+
+        cmd = [python_executable, "-c", query_script_compiled]
+        with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc:  # type: ignore[call-overload] # noqa: S603
+            if capture_output:
+                args = (proc.stdout, output_hook)
+                thread = Thread(target=_process_stream, args=args, daemon=True)
+                thread.start()
+                thread.join()  # wait for the reader thread
+
+        if proc.returncode == QUERY_SCRIPT_CANCELED_EXIT_CODE:
+            raise QueryScriptCancelError(
+                "Query script was canceled by user",
+                return_code=proc.returncode,
+            )
+        if proc.returncode:
+            raise QueryScriptRunError(
+                f"Query script exited with error code {proc.returncode}",
+                return_code=proc.returncode,
+            )

Codecov / codecov/patch annotation on src/datachain/catalog/catalog.py#L1842:
    Added line #L1842 was not covered by tests.

     def cp(
         self,
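
A self-contained sketch of the capture pattern the new code lands on: pipe
the child's merged stdout/stderr and drain it on a daemon thread before
leaving the Popen context (names here are ours, not datachain's):

    import subprocess
    import sys
    from threading import Thread

    def on_line(line: str) -> None:
        print("captured:", line, end="")

    cmd = [sys.executable, "-c", "print('hello'); print('world')"]
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
        def pump() -> None:
            assert proc.stdout is not None  # guaranteed by stdout=PIPE
            buffer = b""
            while byt := proc.stdout.read(1):  # one byte at a time, as above
                buffer += byt
                if byt == b"\n":
                    on_line(buffer.decode("utf-8"))
                    buffer = b""

        thread = Thread(target=pump, daemon=True)
        thread.start()
        thread.join()  # drain before Popen.__exit__ waits on the child

    assert proc.returncode == 0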
4 changes: 0 additions & 4 deletions src/datachain/error.py
@@ -42,10 +42,6 @@ def __init__(self, message: str, return_code: int = 0, output: str = ""):
         super().__init__(self.message)


-class QueryScriptDatasetNotFound(QueryScriptRunError):  # noqa: N818
-    pass
-
-
 class QueryScriptCancelError(QueryScriptRunError):
     pass

Comment (skshetry, Member, Author) on the removed QueryScriptDatasetNotFound:
    No longer used in query. I will create a similar exception in Studio side.
54 changes: 17 additions & 37 deletions tests/func/test_catalog.py
@@ -12,7 +12,6 @@
 from datachain.cli import garbage_collect
 from datachain.error import (
     QueryScriptCompileError,
-    QueryScriptDatasetNotFound,
     QueryScriptRunError,
     StorageNotFoundError,
 )
@@ -59,7 +58,7 @@ def mock_popen_dataset_created(
     mocker, monkeypatch, mock_popen, cloud_test_catalog, listed_bucket
 ):
     # create dataset which would be created in subprocess
-    ds_name = cloud_test_catalog.catalog.generate_query_dataset_name()
+    ds_name = "my-ds"
     job_id = cloud_test_catalog.catalog.metastore.create_job(name="", query="")
     mocker.patch.object(
         cloud_test_catalog.catalog.metastore, "create_job", return_value=job_id
@@ -910,25 +909,16 @@ def test_query(cloud_test_catalog, mock_popen_dataset_created):

     query_script = f"""\
     from datachain.query import C, DatasetQuery
-    DatasetQuery({src_uri!r})
+    DatasetQuery({src_uri!r}).save("my-ds")
     """
     query_script = dedent(query_script)

-    result = catalog.query(query_script, save=True)
-    assert result.dataset
-    assert_row_names(
-        catalog,
-        result.dataset,
-        result.version,
-        {
-            "dog1",
-            "dog2",
-            "dog3",
-            "dog4",
-        },
-    )
-    assert result.dataset.query_script == query_script
-    assert result.dataset.sources == ""
+    catalog.query(query_script)
+
+    dataset = catalog.get_dataset("my-ds")
+    assert dataset
+    assert dataset.versions_values == [1]
+    assert_row_names(catalog, dataset, 1, {"dog1", "dog2", "dog3", "dog4"})


@@ -937,12 +927,17 @@ def test_query_save_size(cloud_test_catalog, mock_popen_dataset_created):

     query_script = f"""\
     from datachain.query import C, DatasetQuery
-    DatasetQuery({src_uri!r})
+    DatasetQuery({src_uri!r}).save("my-ds")
     """
     query_script = dedent(query_script)

-    result = catalog.query(query_script, save=True)
-    dataset_version = result.dataset.get_version(result.version)
+    catalog.query(query_script)
+
+    dataset = catalog.get_dataset("my-ds")
+    assert dataset
+    assert dataset.versions_values == [1]
+
+    dataset_version = dataset.get_version(1)
     assert dataset_version.num_objects == 4
     assert dataset_version.size == 15

@@ -953,7 +948,7 @@ def test_query_fail_to_compile(cloud_test_catalog):
     query_script = "syntax error"

     with pytest.raises(QueryScriptCompileError):
-        catalog.query(query_script)
+        catalog.query(query_script, _execute_last_expression=True)


@@ -971,21 +966,6 @@ def test_query_subprocess_wrong_return_code(mock_popen, cloud_test_catalog):
     assert str(exc_info.value).startswith("Query script exited with error code 1")


-def test_query_dataset_not_returned(mock_popen, cloud_test_catalog):
-    mock_popen.configure_mock(stdout=io.StringIO("random str"))
-    catalog = cloud_test_catalog.catalog
-    src_uri = cloud_test_catalog.src_uri
-
-    query_script = f"""
-    from datachain.query import DatasetQuery, C
-    DatasetQuery('{src_uri}')
-    """
-
-    with pytest.raises(QueryScriptDatasetNotFound) as e:
-        catalog.query(query_script, save=True)
-    assert e.value.output == "random str"
-
-
 @pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
 def test_storage_stats(cloud_test_catalog):
     catalog = cloud_test_catalog.catalog
2 changes: 0 additions & 2 deletions tests/func/test_datasets.py
@@ -57,7 +57,6 @@ def test_create_dataset_no_version_specified(cloud_test_catalog, create_rows):
     dataset_version = dataset.get_version(1)

     assert dataset.name == name
-    assert dataset.query_script == "script"
     assert dataset_version.query_script == "script"
     assert dataset.schema["similarity"] == Float32
     assert dataset_version.schema["similarity"] == Float32
@@ -87,7 +86,6 @@ def test_create_dataset_with_explicit_version(cloud_test_catalog, create_rows):
     dataset_version = dataset.get_version(1)

     assert dataset.name == name
-    assert dataset.query_script == "script"
     assert dataset_version.query_script == "script"
     assert dataset.schema["similarity"] == Float32
     assert dataset_version.schema["similarity"] == Float32