diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 8f30a96eaa8d..64f19186bddd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -311,6 +311,7 @@ jobs: WB_SERVER_HOST: http://wandbservice WF_CLICKHOUSE_HOST: localhost WEAVE_SERVER_DISABLE_ECOSYSTEM: 1 + DD_TRACE_ENABLED: false run: | nox -e "tests-${{ matrix.python-version-major }}.${{ matrix.python-version-minor }}(shard='${{ matrix.nox-shard }}')" -- \ -m "weave_client and not skip_clickhouse_client" \ @@ -328,6 +329,7 @@ jobs: ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + DD_TRACE_ENABLED: false run: | nox -e "tests-${{ matrix.python-version-major }}.${{ matrix.python-version-minor }}(shard='${{ matrix.nox-shard }}')" trace-tests-matrix-check: # This job does nothing and is only used for the branch protection diff --git a/docs/docs/guides/core-types/models.md b/docs/docs/guides/core-types/models.md index 34d2434108e7..989bdd3b7359 100644 --- a/docs/docs/guides/core-types/models.md +++ b/docs/docs/guides/core-types/models.md @@ -76,92 +76,6 @@ A `Model` is a combination of data (which can include configuration, trained mod model.predict('world') ``` - ## Pairwise evaluation of models - - When [scoring](../evaluation/scorers.md) models in a Weave [evaluation](../core-types/evaluations.md), absolute value metrics (e.g. `9/10` for Model A and `8/10` for Model B) are typically harder to assign than than relative ones (e.g. Model A performs better than Model B). _Pairwise evaluation_ allows you to compare the outputs of two models by ranking them relative to each other. This approach is particularly useful when you want to determine which model performs better for subjective tasks such as text generation, summarization, or question answering. With pairwise evaluation, you can obtain a relative preference ranking that reveals which model is best for specific inputs. - - The following code sample demonstrates how to implement a pairwise evaluation in Weave by creating a [class-based scorer](../evaluation/scorers.md#class-based-scorers) called `PreferenceScorer`. The `PreferenceScorer` compares two models, `ModelA` and `ModelB`, and returns a relative score of the model outputs based on explicit hints in the input text. - - ```python - from weave import Model, Evaluation, Scorer, Dataset - from weave.flow.model import ApplyModelError, apply_model_async - - class ModelA(Model): - @weave.op - def predict(self, input_text: str): - if "Prefer model A" in input_text: - return {"response": "This is a great answer from Model A"} - return {"response": "Meh, whatever"} - - class ModelB(Model): - @weave.op - def predict(self, input_text: str): - if "Prefer model B" in input_text: - return {"response": "This is a thoughtful answer from Model B"} - return {"response": "I don't know"} - - class PreferenceScorer(Scorer): - @weave.op - async def _get_other_model_output(self, example: dict) -> Any: - """Get output from the other model for comparison. - Args: - example: The input example data to run through the other model - Returns: - The output from the other model - """ - - other_model_result = await apply_model_async( - self.other_model, - example, - None, - ) - - if isinstance(other_model_result, ApplyModelError): - return None - - return other_model_result.model_output - - @weave.op - async def score(self, output: dict, input_text: str) -> dict: - """Compare the output of the primary model with the other model. - Args: - output (dict): The output from the primary model. - other_output (dict): The output from the other model being compared. - inputs (str): The input text used to generate the outputs. - Returns: - dict: A flat dictionary containing the comparison result and reason. - """ - other_output = await self._get_other_model_output( - {"input_text": inputs} - ) - if other_output is None: - return {"primary_is_better": False, "reason": "Other model failed"} - - if "Prefer model A" in input_text: - primary_is_better = True - reason = "Model A gave a great answer" - else: - primary_is_better = False - reason = "Model B is preferred for this type of question" - - return {"primary_is_better": primary_is_better, "reason": reason} - - dataset = Dataset( - rows=[ - {"input_text": "Prefer model A: Question 1"}, # Model A wins - {"input_text": "Prefer model A: Question 2"}, # Model A wins - {"input_text": "Prefer model B: Question 3"}, # Model B wins - {"input_text": "Prefer model B: Question 4"}, # Model B wins - ] - ) - - model_a = ModelA() - model_b = ModelB() - pref_scorer = PreferenceScorer(other_model=model_b) - evaluation = Evaluation(dataset=dataset, scorers=[pref_scorer]) - evaluation.evaluate(model_a) -``` - ```plaintext diff --git a/docs/docs/guides/tracking/faqs.md b/docs/docs/guides/tracking/faqs.md index 9303dc46e2e0..95a3ac202868 100644 --- a/docs/docs/guides/tracking/faqs.md +++ b/docs/docs/guides/tracking/faqs.md @@ -51,3 +51,92 @@ When your program is exiting it may appear to pause while any remaining enqueued ## How is Weave data ingestion calculated? We define ingested bytes as bytes that we receive, process, and store on your behalf. This includes trace metadata, LLM inputs/outputs, and any other information you explicitly log to Weave, but does not include communication overhead (e.g., HTTP headers) or any other data that is not placed in long-term storage. We count bytes as "ingested" only once at the time they are received and stored. + +## What is pairwise evaluation and how do I do it? + +When [scoring](../evaluation/scorers.md) models in a Weave [evaluation](../core-types/evaluations.md), absolute value metrics (e.g. `9/10` for Model A and `8/10` for Model B) are typically harder to assign than relative ones (e.g. Model A performs better than Model B). _Pairwise evaluation_ allows you to compare the outputs of two models by ranking them relative to each other. This approach is particularly useful when you want to determine which model performs better for subjective tasks such as text generation, summarization, or question answering. With pairwise evaluation, you can obtain a relative preference ranking that reveals which model is best for specific inputs. + +:::important +This approach is a workaround and may change in future releases. We are actively working on a more robust API to support pairwise evaluations. Stay tuned for updates! +::: + +The following code sample demonstrates how to implement a pairwise evaluation in Weave by creating a [class-based scorer](../evaluation/scorers.md#class-based-scorers) called `PreferenceScorer`. The `PreferenceScorer` compares two models, `ModelA` and `ModelB`, and returns a relative score of the model outputs based on explicit hints in the input text. + +```python +from weave import Model, Evaluation, Scorer, Dataset +from weave.flow.model import ApplyModelError, apply_model_async + +class ModelA(Model): + @weave.op + def predict(self, input_text: str): + if "Prefer model A" in input_text: + return {"response": "This is a great answer from Model A"} + return {"response": "Meh, whatever"} + +class ModelB(Model): + @weave.op + def predict(self, input_text: str): + if "Prefer model B" in input_text: + return {"response": "This is a thoughtful answer from Model B"} + return {"response": "I don't know"} + +class PreferenceScorer(Scorer): + @weave.op + async def _get_other_model_output(self, example: dict) -> Any: + """Get output from the other model for comparison. + Args: + example: The input example data to run through the other model + Returns: + The output from the other model + """ + + other_model_result = await apply_model_async( + self.other_model, + example, + None, + ) + + if isinstance(other_model_result, ApplyModelError): + return None + + return other_model_result.model_output + + @weave.op + async def score(self, output: dict, input_text: str) -> dict: + """Compare the output of the primary model with the other model. + Args: + output (dict): The output from the primary model. + input_text (str): The input text used to generate the outputs. + Returns: + dict: A flat dictionary containing the comparison result and reason. + """ + other_output = await self._get_other_model_output( + {"input_text": input_text} + ) + if other_output is None: + return {"primary_is_better": False, "reason": "Other model failed"} + + if "Prefer model A" in input_text: + primary_is_better = True + reason = "Model A gave a great answer" + else: + primary_is_better = False + reason = "Model B is preferred for this type of question" + + return {"primary_is_better": primary_is_better, "reason": reason} + +dataset = Dataset( + rows=[ + {"input_text": "Prefer model A: Question 1"}, # Model A wins + {"input_text": "Prefer model A: Question 2"}, # Model A wins + {"input_text": "Prefer model B: Question 3"}, # Model B wins + {"input_text": "Prefer model B: Question 4"}, # Model B wins + ] +) + +model_a = ModelA() +model_b = ModelB() +pref_scorer = PreferenceScorer(other_model=model_b) +evaluation = Evaluation(dataset=dataset, scorers=[pref_scorer]) +evaluation.evaluate(model_a) +``` diff --git a/noxfile.py b/noxfile.py index 342c9815330a..06930b88d3a5 100644 --- a/noxfile.py +++ b/noxfile.py @@ -72,6 +72,7 @@ def tests(session, shard): "WB_SERVER_HOST", "WF_CLICKHOUSE_HOST", "WEAVE_SERVER_DISABLE_ECOSYSTEM", + "DD_TRACE_ENABLED", ] } # Add the GOOGLE_API_KEY environment variable for the "google" shard diff --git a/pyproject.toml b/pyproject.toml index 2612023fd2d9..34e08c430ee4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ # this to a separate package. Note, when that happens, we will need to pull along some of the #default dependencies as well. trace_server = [ + "ddtrace>=2.7.0", # BYOB - S3 "boto3>=1.34.0", # BYOB - Azure diff --git a/tests/conftest.py b/tests/conftest.py index 6c042aca048a..132daf76f218 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,6 +32,53 @@ os.environ["WANDB_ERROR_REPORTING"] = "false" +@pytest.fixture(autouse=True) +def disable_datadog(): + """ + Disables Datadog logging and tracing for tests. + + This prevents Datadog from polluting test logs with messages like + 'failed to send, dropping 1 traces to intake at...' + """ + # Save original values to restore later + original_dd_env = os.environ.get("DD_ENV") + original_dd_trace = os.environ.get("DD_TRACE_ENABLED") + + # Disable Datadog + os.environ["DD_ENV"] = "none" + os.environ["DD_TRACE_ENABLED"] = "false" + + # Silence Datadog loggers + dd_loggers = [ + "ddtrace", + "ddtrace.writer", + "ddtrace.api", + "ddtrace.internal", + "datadog", + "datadog.dogstatsd", + "datadog.api", + ] + + original_levels = {} + for logger_name in dd_loggers: + logger = logging.getLogger(logger_name) + original_levels[logger_name] = logger.level + logger.setLevel(logging.CRITICAL) # Only show critical errors + + yield + + # Restore original values + if original_dd_env is not None: + os.environ["DD_ENV"] = original_dd_env + elif "DD_ENV" in os.environ: + del os.environ["DD_ENV"] + + if original_dd_trace is not None: + os.environ["DD_TRACE_ENABLED"] = original_dd_trace + elif "DD_TRACE_ENABLED" in os.environ: + del os.environ["DD_TRACE_ENABLED"] + + def pytest_addoption(parser): parser.addoption( "--weave-server", @@ -362,14 +409,14 @@ def emit(self, record): self.log_records[curr_test] = [] self.log_records[curr_test].append(record) - def get_error_logs(self): + def _get_logs(self, levelname: str): curr_test = get_test_name() logs = self.log_records.get(curr_test, []) return [ record for record in logs - if record.levelname == "ERROR" + if record.levelname == levelname and record.name.startswith("weave") # (Tim) For some reason that i cannot figure out, there is some test that # a) is trying to connect to the PROD trace server @@ -386,13 +433,22 @@ def get_error_logs(self): and not "legacy" in record.name ] + def get_error_logs(self): + return self._get_logs("ERROR") + + def get_warning_logs(self): + return self._get_logs("WARNING") + @pytest.fixture -def log_collector(): +def log_collector(request): handler = InMemoryWeaveLogCollector() logger = logging.getLogger() # Get your specific logger here if needed logger.addHandler(handler) - logger.setLevel(logging.ERROR) # Set the level to capture all logs + if hasattr(request, "param") and request.param == "warning": + logger.setLevel(logging.WARNING) + else: + logger.setLevel(logging.ERROR) yield handler logger.removeHandler(handler) # Clean up after the test diff --git a/tests/trace/test_client_server_caching.py b/tests/trace/test_client_server_caching.py index 2e8662c80e57..6774f5a1fb9a 100644 --- a/tests/trace/test_client_server_caching.py +++ b/tests/trace/test_client_server_caching.py @@ -58,8 +58,17 @@ def test_server_caching(client): gotten_dataset = client.get(ref) assert caching_server.get_cache_recorder() == { "hits": 0, - # 1 obj read for the dataset + # get the ref + "misses": 1, + "errors": 0, + "skips": 0, + } + caching_server.reset_cache_recorder() + rows = list(gotten_dataset) + assert caching_server.get_cache_recorder() == { + "hits": 0, # 1 table read for the rows + # 1 table_query_stats for len(rows) # 5 images "misses": 7, "errors": 0, @@ -71,7 +80,7 @@ def test_server_caching(client): caching_server.reset_cache_recorder() compare_datasets(client.get(ref), dataset) assert caching_server.get_cache_recorder() == { - "hits": 8, + "hits": 7, "misses": 0, "errors": 0, "skips": 0, @@ -86,7 +95,7 @@ def test_server_caching(client): "hits": 0, "misses": 0, "errors": 0, - "skips": 8, + "skips": 7, } diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index ecb34ace7e2f..7b8aa0ea6266 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -123,7 +123,12 @@ def test_dataset(client): d = Dataset(rows=[{"a": 5, "b": 6}, {"a": 7, "b": 10}]) ref = weave.publish(d) d2 = weave.ref(ref.uri()).get() + + # This might seem redundant, but it is useful to ensure that the + # dataset can be re-iterated over multiple times and equality is preserved. assert list(d2.rows) == list(d2.rows) + assert list(d.rows) == list(d2.rows) + assert list(d.rows) == list(d.rows) def test_trace_server_call_start_and_end(client): diff --git a/tests/trace/test_custom_objs.py b/tests/trace/test_custom_objs.py index 497271f42cdb..09ef34e12e47 100644 --- a/tests/trace/test_custom_objs.py +++ b/tests/trace/test_custom_objs.py @@ -1,6 +1,6 @@ from PIL import Image -from weave.trace.custom_objs import decode_custom_obj, encode_custom_obj +from weave.trace.serialization.custom_objs import decode_custom_obj, encode_custom_obj def test_decode_custom_obj_known_type(client): diff --git a/tests/trace/test_dataset.py b/tests/trace/test_dataset.py index 3054bef1a4c0..bd6c7ed925db 100644 --- a/tests/trace/test_dataset.py +++ b/tests/trace/test_dataset.py @@ -1,6 +1,8 @@ import pytest import weave +from tests.trace.test_evaluate import Dataset +from weave.trace.context.tests_context import raise_on_captured_errors def test_basic_dataset_lifecycle(client): @@ -35,6 +37,81 @@ def test_pythonic_access(client): ds[-1] +def _top_level_logs(log): + """Strip out internal logs from the log list""" + return [l for l in log if not l.startswith("_")] + + +def test_dataset_laziness(client): + """ + The intention of this test is to show that local construction of + a dataset does not trigger any remote operations. + """ + dataset = Dataset(rows=[{"input": i} for i in range(300)]) + log = client.server.attribute_access_log + assert _top_level_logs(log) == ["ensure_project_exists"] + client.server.attribute_access_log = [] + + length = len(dataset) + log = client.server.attribute_access_log + assert _top_level_logs(log) == [] + + length2 = len(dataset) + log = client.server.attribute_access_log + assert _top_level_logs(log) == [] + + assert length == length2 + + for row in dataset: + log = client.server.attribute_access_log + assert _top_level_logs(log) == [] + + +def test_published_dataset_laziness(client): + """ + The intention of this test is to show that publishing a dataset, + then iterating through the "gotten" version of the dataset has + minimal remote operations - and importantly delays the fetching + of the rows until they are actually needed. + """ + dataset = Dataset(rows=[{"input": i} for i in range(300)]) + log = client.server.attribute_access_log + assert _top_level_logs(log) == ["ensure_project_exists"] + client.server.attribute_access_log = [] + + ref = weave.publish(dataset) + log = client.server.attribute_access_log + assert _top_level_logs(log) == ["table_create", "obj_create"] + client.server.attribute_access_log = [] + + dataset = ref.get() + log = client.server.attribute_access_log + assert _top_level_logs(log) == ["obj_read"] + client.server.attribute_access_log = [] + + length = len(dataset) + log = client.server.attribute_access_log + assert _top_level_logs(log) == ["table_query_stats"] + client.server.attribute_access_log = [] + + length2 = len(dataset) + log = client.server.attribute_access_log + assert _top_level_logs(log) == [] + + assert length == length2 + + for i, row in enumerate(dataset): + log = client.server.attribute_access_log + # This is the critical part of the test - ensuring that + # the rows are only fetched when they are actually needed. + # + # In a future improvement, we might eagerly fetch the next + # page of results, which would result in this assertion changing + # in that there would always be one more "table_query" than + # the number of pages. + assert _top_level_logs(log) == ["table_query"] * ((i // 100) + 1) + + def test_dataset_from_calls(client): @weave.op def greet(name: str, age: int) -> str: @@ -54,3 +131,13 @@ def greet(name: str, age: int) -> str: assert rows[1]["inputs"]["name"] == "Bob" assert rows[1]["inputs"]["age"] == 25 assert rows[1]["output"] == "Hello Bob, you are 25!" + + +def test_dataset_caching(client): + ds = weave.Dataset(rows=[{"a": i} for i in range(200)]) + ref = weave.publish(ds) + + ds2 = ref.get() + + with raise_on_captured_errors(): + assert len(ds2) == 200 diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index ff0bd63a8d48..b81daa6a020b 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -1,5 +1,7 @@ import asyncio +import os import time +from unittest.mock import patch import pytest @@ -359,3 +361,67 @@ def score(target, output): result = asyncio.run(evaluation.evaluate(model)) assert result == expected_eval_result + + +def test_evaluate_table_lazy_iter(client): + """ + The intention of this test is to show that an evaluation harness + lazily fetches rows from a table rather than eagerly fetching all + rows up front. + """ + dataset = Dataset(rows=[{"input": i} for i in range(300)]) + ref = weave.publish(dataset) + dataset = ref.get() + + @weave.op() + async def model_predict(input) -> int: + return input * 1 + + @weave.op() + def score_simple(input, output): + return input == output + + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == [ + "ensure_project_exists", + "table_create", + "obj_create", + "obj_read", + ] + client.server.attribute_access_log = [] + + evaluation = Evaluation( + dataset=dataset, + scorers=[score_simple], + ) + log = client.server.attribute_access_log + assert [l for l in log if not l.startswith("_")] == [] + + # Make sure we have deterministic results + with patch.dict(os.environ, {"WEAVE_PARALLELISM": "1"}): + result = asyncio.run(evaluation.evaluate(model_predict)) + assert result["output"] == {"mean": 149.5} + assert result["score_simple"] == {"true_count": 300, "true_fraction": 1.0} + + log = client.server.attribute_access_log + log = [l for l in log if not l.startswith("_")] + + # Make sure that the length was figured out deterministically + assert "table_query_stats" in log + + counts_split_by_table_query = [0] + for log_entry in log: + if log_entry == "table_query": + counts_split_by_table_query.append(0) + else: + counts_split_by_table_query[-1] += 1 + + # Note: these exact numbers might change if we change the way eval traces work. + # However, the key part is that we have basically X + 2 splits, with the middle X + # being equal. We want to ensure that the table_query is not called in sequence, + # but rather lazily after each batch. + assert counts_split_by_table_query[0] <= 13 + # Note: if this test suite is ran in a different order, then the low level eval ops will already be saved + # so the first count can be different. + count = counts_split_by_table_query[0] + assert counts_split_by_table_query == [count, 700, 700, 700, 5], log diff --git a/tests/trace/test_op_decorator_behaviour.py b/tests/trace/test_op_decorator_behaviour.py index b29fcc8d7e13..d9b556de93f6 100644 --- a/tests/trace/test_op_decorator_behaviour.py +++ b/tests/trace/test_op_decorator_behaviour.py @@ -4,8 +4,7 @@ import pytest import weave -from weave.trace import errors -from weave.trace.op import is_op, op +from weave.trace.op import OpCallError, is_op, op from weave.trace.refs import ObjectRef, parse_uri from weave.trace.vals import MissingSelfInstanceError from weave.trace.weave_client import Call @@ -140,7 +139,7 @@ def test_sync_method_call(client, weave_obj, py_obj): with pytest.raises(MissingSelfInstanceError): weave_obj_method2 = weave_obj_method_ref.get() - with pytest.raises(errors.OpCallError): + with pytest.raises(OpCallError): res2, call2 = py_obj.method.call(1) @@ -175,7 +174,7 @@ async def test_async_method_call(client, weave_obj, py_obj): with pytest.raises(MissingSelfInstanceError): weave_obj_amethod2 = weave_obj_amethod_ref.get() - with pytest.raises(errors.OpCallError): + with pytest.raises(OpCallError): res2, call2 = await py_obj.amethod.call(1) diff --git a/tests/trace/test_serialize.py b/tests/trace/test_serialize.py index 1c51028ee571..251455ca4471 100644 --- a/tests/trace/test_serialize.py +++ b/tests/trace/test_serialize.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from weave.trace.serialize import ( +from weave.trace.serialization.serialize import ( dictify, fallback_encode, is_pydantic_model_class, diff --git a/tests/trace/test_weave_client.py b/tests/trace/test_weave_client.py index a624a59b943e..f28b06f2b816 100644 --- a/tests/trace/test_weave_client.py +++ b/tests/trace/test_weave_client.py @@ -28,7 +28,10 @@ TABLE_ROW_ID_EDGE_NAME, DeletedRef, ) -from weave.trace.serializer import get_serializer_for_obj, register_serializer +from weave.trace.serialization.serializer import ( + get_serializer_for_obj, + register_serializer, +) from weave.trace_server.clickhouse_trace_server_batched import NotFoundError from weave.trace_server.constants import MAX_DISPLAY_NAME_LENGTH from weave.trace_server.sqlite_trace_server import ( diff --git a/tests/trace_server_bindings/test_async_batch_processor.py b/tests/trace_server_bindings/test_async_batch_processor.py new file mode 100644 index 000000000000..e6c363810653 --- /dev/null +++ b/tests/trace_server_bindings/test_async_batch_processor.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import time +from unittest.mock import MagicMock, call + +from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor + + +def test_max_batch_size(): + processor_fn = MagicMock() + processor = AsyncBatchProcessor(processor_fn, max_batch_size=2) + + # Queue up 2 batches of 3 items + processor.enqueue([1, 2, 3]) + processor.stop_accepting_new_work_and_flush_queue() + + # But the max batch size is 2, so the batch is split apart + processor_fn.assert_has_calls( + [ + call([1, 2]), + call([3]), + ] + ) + + +def test_min_batch_interval(): + processor_fn = MagicMock() + processor = AsyncBatchProcessor( + processor_fn, max_batch_size=100, min_batch_interval=1 + ) + + # Queue up batches of 3 items within the min_batch_interval + processor.enqueue([1, 2, 3]) + time.sleep(0.1) + processor.enqueue([4, 5, 6]) + time.sleep(0.1) + processor.enqueue([7, 8, 9]) + processor.stop_accepting_new_work_and_flush_queue() + + # Processor should batch them all together + processor_fn.assert_called_once_with([1, 2, 3, 4, 5, 6, 7, 8, 9]) + + +def test_wait_until_all_processed(): + processor_fn = MagicMock() + processor = AsyncBatchProcessor( + processor_fn, max_batch_size=100, min_batch_interval=0.01 + ) + + processor.enqueue([1, 2, 3]) + processor.stop_accepting_new_work_and_flush_queue() + + # Despite queueing extra items, they will never get flushed because the processor is + # already shut down. + processor.enqueue([4, 5, 6]) + processor.stop_accepting_new_work_and_flush_queue() + processor.enqueue([7, 8, 9]) + processor.stop_accepting_new_work_and_flush_queue() + + # We should only see the first batch. Everything else is stuck in the queue. + processor_fn.assert_has_calls([call([1, 2, 3])]) + assert processor.queue.qsize() == 6 diff --git a/tests/trace_server_bindings/test_remote_http_trace_server.py b/tests/trace_server_bindings/test_remote_http_trace_server.py new file mode 100644 index 000000000000..a4b7105fe430 --- /dev/null +++ b/tests/trace_server_bindings/test_remote_http_trace_server.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +import datetime +import json +from queue import Full +from types import MethodType +from unittest.mock import MagicMock, patch + +import pytest +import requests +import tenacity + +from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.ids import generate_id +from weave.trace_server_bindings.remote_http_trace_server import ( + Batch, + EndBatchItem, + RemoteHTTPTraceServer, + StartBatchItem, +) + + +def generate_start(id: str | None = None) -> tsi.StartedCallSchemaForInsert: + return tsi.StartedCallSchemaForInsert( + project_id="test", + id=id or generate_id(), + op_name="test_name", + trace_id="test_trace_id", + parent_id="test_parent_id", + started_at=datetime.datetime.now(tz=datetime.timezone.utc), + attributes={"a": 5}, + inputs={"b": 5}, + ) + + +def generate_end(id: str | None = None) -> tsi.EndedCallSchemaForInsert: + return tsi.EndedCallSchemaForInsert( + project_id="test", + id=id or generate_id(), + ended_at=datetime.datetime.now(tz=datetime.timezone.utc) + + datetime.timedelta(seconds=1), + outputs={"c": 5}, + error=None, + summary={"result": "Test summary"}, + ) + + +def generate_call_start_end_pair( + id: str | None = None, +) -> tuple[tsi.CallStartReq, tsi.CallEndReq]: + start = generate_start(id) + end = generate_end(id) + return tsi.CallStartReq(start=start), tsi.CallEndReq(end=end) + + +@pytest.fixture +def success_response(): + response = MagicMock() + response.status_code = 200 + response.json.return_value = {"id": "test_id", "trace_id": "test_trace_id"} + return response + + +@pytest.fixture +def server(request): + _server = RemoteHTTPTraceServer("http://example.com", should_batch=True) + + if request.param == "normal": + _server._send_batch_to_server = MagicMock() + elif request.param == "small_limit": + _server.remote_request_bytes_limit = 1024 # 1kb + _server._send_batch_to_server = MagicMock() + elif request.param == "fast_retrying": + fast_retry = tenacity.retry( + wait=tenacity.wait_fixed(0.1), + stop=tenacity.stop_after_attempt(2), + reraise=True, + ) + unwrapped_send_batch_to_server = MethodType( + _server._send_batch_to_server.__wrapped__, _server + ) + _server._send_batch_to_server = fast_retry(unwrapped_send_batch_to_server) + + yield _server + + if hasattr(_server, "call_processor"): + _server.call_processor.stop_accepting_new_work_and_flush_queue() + + +@pytest.mark.parametrize("server", ["small_limit"], indirect=True) +def test_large_batch_is_split_into_multiple_smaller_batches(server): + # Create a large batch with many items to exceed the size limit + batch = [] + for _ in range(20): + start, end = generate_call_start_end_pair() + batch.append(StartBatchItem(req=start)) + batch.append(EndBatchItem(req=end)) + + # Verify the batch is actually large enough to trigger splitting + data = Batch(batch=batch).model_dump_json() + encoded_data = data.encode("utf-8") + assert len(encoded_data) > server.remote_request_bytes_limit + + # Process the batch and verify _send_batch_to_server was called more than once, + # implying the batch was split into smaller chunks + server._flush_calls(batch) + assert server._send_batch_to_server.call_count > 1 + + # Verify all items were sent + total_items_sent = 0 + for call in server._send_batch_to_server.call_args_list: + called_data = call[0][0] + decoded_batch = json.loads(called_data.decode("utf-8")) + total_items_sent += len(decoded_batch["batch"]) + + assert total_items_sent == len(batch) + + +@pytest.mark.parametrize("server", ["normal"], indirect=True) +def test_small_batch_is_sent_in_one_request(server): + """Test that a small batch is sent without splitting.""" + # Create and process a single item + start, _ = generate_call_start_end_pair() + batch = [StartBatchItem(req=start)] + server._flush_calls(batch) + + # Verify _send_batch_to_server was called once with the entire batch + assert server._send_batch_to_server.call_count == 1 + called_data = server._send_batch_to_server.call_args[0][0] + decoded_batch = json.loads(called_data.decode("utf-8")) + assert len(decoded_batch["batch"]) == 1 + + +@pytest.mark.parametrize("server", ["normal"], indirect=True) +def test_empty_batch_is_noop(server): + batch = [] + server._flush_calls(batch) + + # Verify _send_batch_to_server was not called + assert server._send_batch_to_server.call_count == 0 + + +@pytest.mark.parametrize("server", ["small_limit"], indirect=True) +def test_oversized_item_will_error_without_sending(server): + """Test that a single item that's too large raises an error.""" + # Create a single item with a very large payload + start = generate_start() + start.attributes = { + "large_data": "x" * server.remote_request_bytes_limit, + } + batch = [StartBatchItem(req=tsi.CallStartReq(start=start))] + + # Verify the single item is actually large enough to trigger the error + data = Batch(batch=batch).model_dump_json() + encoded_data = data.encode("utf-8") + assert len(encoded_data) > server.remote_request_bytes_limit + + # Process the batch and expect an error + with pytest.raises(ValueError) as excinfo: + server._flush_calls(batch) + + # Verify the error message + assert "Single call size" in str(excinfo.value) + assert "is too large to send" in str(excinfo.value) + + # Verify _send_batch_to_server was not called + assert server._send_batch_to_server.call_count == 0 + + +@pytest.mark.parametrize("server", ["small_limit"], indirect=True) +def test_multi_level_recursive_splitting(server): + """Test that a very large batch is recursively split multiple times.""" + # Create a very large batch with many items to force multiple levels of splitting. + # Some items are larger than others to test non-uniform sizes. + batch = [] + for i in range(50): + start = generate_start() + end = generate_end() + if i % 5 == 0: + start.attributes = {"data": "x" * 500} + batch.append(StartBatchItem(req=tsi.CallStartReq(start=start))) + batch.append(EndBatchItem(req=tsi.CallEndReq(end=end))) + + # Process the batch + server._flush_calls(batch) + + # Verify _send_batch_to_server was called multiple times + # The exact number depends on the batch sizes, but it should be more than just 1 split + assert server._send_batch_to_server.call_count > 2 + + # Verify all items were sent + total_items_sent = 0 + for call in server._send_batch_to_server.call_args_list: + called_data = call[0][0] + decoded_batch = json.loads(called_data.decode("utf-8")) + total_items_sent += len(decoded_batch["batch"]) + + assert total_items_sent == len(batch) + + +@pytest.mark.parametrize("server", ["normal"], indirect=True) +def test_dynamic_batch_size_adjustment(server): + """Test that max_batch_size is dynamically adjusted based on item sizes.""" + # Create a batch with consistent item sizes + batch = [] + for _ in range(10): + start, end = generate_call_start_end_pair() + batch.append(StartBatchItem(req=start)) + + # Initial max_batch_size should be the default + original_max_batch_size = server.call_processor.max_batch_size + + # Process the batch + server._flush_calls(batch) + + # Verify max_batch_size was updated + new_max_batch_size = server.call_processor.max_batch_size + assert new_max_batch_size != original_max_batch_size + + # The new max_batch_size should be based on the average item size + data = Batch(batch=batch).model_dump_json() + encoded_bytes = len(data.encode("utf-8")) + estimated_bytes_per_item = encoded_bytes / len(batch) + expected_max_batch_size = max( + 1, int(server.remote_request_bytes_limit // estimated_bytes_per_item) + ) + + assert new_max_batch_size == expected_max_batch_size + + +@pytest.mark.parametrize("server", ["small_limit"], indirect=True) +def test_non_uniform_batch_items(server): + """Test batch with extremely non-uniform item sizes.""" + # Create a batch with vastly different sized items + batch = [] + + # Add several small items + for _ in range(5): + start, _ = generate_call_start_end_pair() + batch.append(StartBatchItem(req=start)) + + # Add one medium item + start = generate_start() + start.attributes = {"medium_data": "y" * 300} + batch.append(StartBatchItem(req=tsi.CallStartReq(start=start))) + + # Add one large item (but still under the limit) + start = generate_start() + start.attributes = { + "large_data": "z" * (server.remote_request_bytes_limit // 2), + } + batch.append(StartBatchItem(req=tsi.CallStartReq(start=start))) + + # Process the batch + server._flush_calls(batch) + + # The batch should have been split to accommodate the different sized items + assert server._send_batch_to_server.call_count >= 2 + + # Verify all items were sent + total_items_sent = 0 + for call in server._send_batch_to_server.call_args_list: + called_data = call[0][0] + decoded_batch = json.loads(called_data.decode("utf-8")) + total_items_sent += len(decoded_batch["batch"]) + + assert total_items_sent == len(batch) + + +@patch("weave.trace_server.requests.post") +def test_timeout_retry_mechanism(mock_post, success_response): + """Test that timeouts trigger the retry mechanism.""" + server = RemoteHTTPTraceServer("http://example.com", should_batch=True) + + # Mock server to raise errors twice, then succeed + mock_post.side_effect = [ + requests.exceptions.Timeout("Connection timed out"), + requests.exceptions.HTTPError("500 Server Error"), + success_response, + ] + + # Trying to send a batch should fail 2 times, then succeed + server.call_start(tsi.CallStartReq(start=generate_start())) + server.call_processor.stop_accepting_new_work_and_flush_queue() + + # Verify that requests.post was called 3 times + assert mock_post.call_count == 3 + + +@pytest.mark.disable_logging_error_check +@pytest.mark.parametrize("server", ["fast_retrying"], indirect=True) +@patch("weave.trace_server.requests.post") +def test_post_timeout(mock_post, success_response, server, log_collector): + """Test that we can still send new batches even if one batch times out. + + This test modifies the retry mechanism to use a short wait time and limited retries + to verify behavior when retries are exhausted. + """ + # Configure mock to timeout twice to exhaust retries + mock_post.side_effect = [ + # First batch times out twice + requests.exceptions.Timeout("Connection timed out"), + requests.exceptions.Timeout("Connection timed out"), + # Second batch times out once, but then succeeds + requests.exceptions.Timeout("Connection timed out"), + success_response, + ] + + # Phase 1: Try but fail to process the first batch + server.call_start(tsi.CallStartReq(start=generate_start())) + server.call_processor.stop_accepting_new_work_and_flush_queue() + logs = log_collector.get_error_logs() + assert len(logs) == 1 + assert logs[0].msg == "Error processing batch: Connection timed out" + + server.call_processor.accept_new_work() + + # Phase 2: Try and succeed with the second batch + server.call_start(tsi.CallStartReq(start=generate_start())) + server.call_processor.stop_accepting_new_work_and_flush_queue() + assert len(logs) == 1 # No new errors + + +@pytest.mark.disable_logging_error_check +@pytest.mark.parametrize("server", ["normal"], indirect=True) +@pytest.mark.parametrize("log_collector", ["warning"], indirect=True) +def test_drop_data_when_queue_is_full(server, log_collector): + """Test that items are dropped when the queue is full.""" + # Replace the real queue with a mock that raises Full when put_nowait is called + mock_queue = MagicMock() + mock_queue.put_nowait.side_effect = Full + server.call_processor.queue = mock_queue + + server.call_start(tsi.CallStartReq(start=generate_start())) + + # Verify that the put_nowait method was called (meaning we tried to enqueue the item) + mock_queue.put_nowait.assert_called_once() + + # We can still check logs as a secondary verification + logs = log_collector.get_warning_logs() + assert len(logs) == 1 + assert "Queue is full" in logs[0].msg + assert "Dropping item" in logs[0].msg diff --git a/tests/utils/test_iterators.py b/tests/utils/test_iterators.py new file mode 100644 index 000000000000..1bb42c6f29fb --- /dev/null +++ b/tests/utils/test_iterators.py @@ -0,0 +1,142 @@ +import threading + +import pytest + +from weave.utils.iterators import ThreadSafeLazyList + + +def test_basic_sequence_operations(): + # Test basic sequence operations + iterator = ThreadSafeLazyList(iter(range(10))) + assert len(iterator) == 10 + assert iterator[0] == 0 + assert iterator[1:3] == [1, 2] + assert list(iterator) == list(range(10)) + + +def test_empty_iterator(): + # Test behavior with empty iterator + iterator = ThreadSafeLazyList(iter([])) + assert len(iterator) == 0 + with pytest.raises(IndexError): + _ = iterator[0] + assert list(iterator) == [] + + +def test_known_length(): + # Test initialization with known length + iterator = ThreadSafeLazyList(iter(range(5)), known_length=5) + assert len(iterator) == 5 # Should not need to exhaust iterator + assert iterator[4] == 4 # Access last element + + +def test_multiple_iterations(): + # Test multiple iterations return same results + data = list(range(5)) + iterator = ThreadSafeLazyList(iter(data)) + + assert list(iterator) == data # First iteration + assert list(iterator) == data # Second iteration + assert list(iterator) == data # Third iteration + + +def test_concurrent_access(): + # Test thread-safe concurrent access + data = list(range(1000)) + iterator = ThreadSafeLazyList(iter(data)) + results = [] + + def reader_thread(): + results.append(list(iterator)) + + threads = [threading.Thread(target=reader_thread) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All threads should see the same data + assert all(r == data for r in results) + + +def test_slicing(): + # Test various slicing operations + iterator = ThreadSafeLazyList(iter(range(10))) + + assert iterator[2:5] == [2, 3, 4] + assert iterator[:3] == [0, 1, 2] + assert iterator[7:] == [7, 8, 9] + assert iterator[::2] == [0, 2, 4, 6, 8] + assert iterator[::-1] == [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] + + +def test_random_access(): + # Test random access patterns + iterator = ThreadSafeLazyList(iter(range(10))) + + assert iterator[5] == 5 # Middle access + assert iterator[1] == 1 # Earlier access + assert iterator[8] == 8 # Later access + assert iterator[0] == 0 # First element + assert iterator[9] == 9 # Last element + + +def test_concurrent_mixed_operations(): + # Test concurrent mixed operations (reads, slices, iterations) + data = list(range(100)) + iterator = ThreadSafeLazyList(iter(data)) + results = [] + + def mixed_ops_thread(): + local_results = [] + local_results.append(iterator[10]) # Single element access + local_results.append(list(iterator[20:25])) # Slice access + local_results.append(iterator[50]) # Another single element + local_results.extend(iterator[90:]) # End slice + results.append(local_results) + + threads = [threading.Thread(target=mixed_ops_thread) for _ in range(3)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Verify all threads got the same results + expected = [10, list(range(20, 25)), 50] + list(range(90, 100)) + assert all(r == expected for r in results) + + +def test_index_out_of_range(): + # Test index out of range behavior + iterator = ThreadSafeLazyList(iter(range(5))) + + with pytest.raises(IndexError): + _ = iterator[10] + + assert iterator[-1] == 4 # Negative indices are supported + + +def test_iterator_exhaustion(): + # Test behavior when iterator is exhausted + class CountingIterator: + def __init__(self): + self.count = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.count < 5: + self.count += 1 + return self.count - 1 + raise StopIteration + + iterator = ThreadSafeLazyList(CountingIterator()) + + # Access beyond iterator length should raise IndexError + assert len(iterator) == 5 + with pytest.raises(IndexError): + _ = iterator[10] + + # Verify original data is still accessible + assert list(iterator) == [0, 1, 2, 3, 4] diff --git a/weave-js/src/components/Form/TextField.tsx b/weave-js/src/components/Form/TextField.tsx index 8f5dd1171ff0..045fd9b37677 100644 --- a/weave-js/src/components/Form/TextField.tsx +++ b/weave-js/src/components/Form/TextField.tsx @@ -24,6 +24,7 @@ type TextFieldProps = { onChange?: (value: string) => void; onKeyDown?: (key: string, e: React.KeyboardEvent) => void; onBlur?: (value: string) => void; + onFocus?: () => void; autoFocus?: boolean; disabled?: boolean; icon?: IconName; @@ -48,6 +49,7 @@ export const TextField = ({ onChange, onKeyDown, onBlur, + onFocus, autoFocus, disabled, icon, @@ -133,6 +135,7 @@ export const TextField = ({ onChange={handleChange} onKeyDown={handleKeyDown} onBlur={handleBlur} + onFocus={onFocus} autoFocus={autoFocus} disabled={disabled} readOnly={!onChange} // It would be readonly regardless but this prevents a console warning diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx index e38ebc35d127..af51231a3e48 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/AddToDatasetDrawer.tsx @@ -1,5 +1,5 @@ import {Box, Typography} from '@mui/material'; -import React, {useCallback, useEffect, useState} from 'react'; +import React, {useCallback, useEffect} from 'react'; import {toast} from 'react-toastify'; import {maybePluralize} from '../../../../../core/util/string'; @@ -8,17 +8,17 @@ import {WaveLoader} from '../../../../Loaders/WaveLoader'; import {useWeaveflowRouteContext} from '../context'; import {ResizableDrawer} from '../pages/common/ResizableDrawer'; import {useWFHooks} from '../pages/wfReactInterface/context'; -import {ObjectVersionSchema} from '../pages/wfReactInterface/wfDataModelHooksInterface'; import { - DatasetEditProvider, - useDatasetEditContext, -} from './DatasetEditorContext'; + ACTION_TYPES, + DatasetDrawerProvider, + useDatasetDrawer, +} from './DatasetDrawerContext'; import {createNewDataset, updateExistingDataset} from './datasetOperations'; import {DatasetPublishToast} from './DatasetPublishToast'; import {EditAndConfirmStep} from './EditAndConfirmStep'; -import {FieldConfig, NewDatasetSchemaStep} from './NewDatasetSchemaStep'; +import {NewDatasetSchemaStep} from './NewDatasetSchemaStep'; import {SchemaMappingStep} from './SchemaMappingStep'; -import {CallData, FieldMapping} from './schemaUtils'; +import {CallData, extractSourceSchema} from './schemaUtils'; import {SelectDatasetStep} from './SelectDatasetStep'; interface AddToDatasetDrawerProps { @@ -33,9 +33,13 @@ const typographyStyle = {fontFamily: 'Source Sans Pro'}; export const AddToDatasetDrawer: React.FC = props => { return ( - + - + ); }; @@ -46,21 +50,33 @@ export const AddToDatasetDrawerInner: React.FC = ({ project, selectedCalls, }) => { - const [currentStep, setCurrentStep] = useState(1); - const [selectedDataset, setSelectedDataset] = - useState(null); - const [datasets, setDatasets] = useState([]); - const [fieldMappings, setFieldMappings] = useState([]); - const [datasetObject, setDatasetObject] = useState(null); - const [error, setError] = useState(null); - const [drawerWidth, setDrawerWidth] = useState(800); - const [isFullscreen, setIsFullscreen] = useState(false); - const [isCreating, setIsCreating] = useState(false); - const [newDatasetName, setNewDatasetName] = useState(null); - const [fieldConfigs, setFieldConfigs] = useState([]); - const [isNameValid, setIsNameValid] = useState(false); - const [datasetKey, setDatasetKey] = useState(''); - const [isCreatingNew, setIsCreatingNew] = useState(false); + const { + state, + dispatch, + handleNext, + handleBack, + handleDatasetSelect, + handleMappingChange, + handleDatasetObjectLoaded, + resetDrawerState, + isNextDisabled, + editorContext, + } = useDatasetDrawer(); + + const { + currentStep, + selectedDataset, + newDatasetName, + datasets, + isCreatingNew, + fieldMappings, + fieldConfigs, + datasetObject, + drawerWidth, + isFullscreen, + isCreating, + error, + } = state; const {peekingRouter} = useWeaveflowRouteContext(); const {useRootObjectVersions, useTableUpdate, useObjCreate, useTableCreate} = @@ -68,30 +84,12 @@ export const AddToDatasetDrawerInner: React.FC = ({ const tableUpdate = useTableUpdate(); const objCreate = useObjCreate(); const tableCreate = useTableCreate(); - const {getRowsNoMeta, convertEditsToTableUpdateSpec, resetEditState} = - useDatasetEditContext(); - // Update dataset key when the underlying dataset selection or mappings change - useEffect(() => { - if (currentStep === 1) { - // Only update key when on the selection/mapping step - setDatasetKey( - selectedDataset - ? `${selectedDataset.objectId}-${ - selectedDataset.versionHash - }-${JSON.stringify(fieldMappings)}` - : `new-dataset-${newDatasetName}-${JSON.stringify(fieldMappings)}` - ); - } - }, [currentStep, selectedDataset, newDatasetName, fieldMappings]); - - // Reset edit state only when the dataset key changes - useEffect(() => { - if (datasetKey) { - resetEditState(); - } - }, [datasetKey, resetEditState]); + // Access edit context methods through the drawer context's editorContext property + const {getRowsNoMeta, convertEditsToTableUpdateSpec, resetEditState} = + editorContext; + // Fetch datasets on component mount const objectVersions = useRootObjectVersions( entity, project, @@ -102,89 +100,37 @@ export const AddToDatasetDrawerInner: React.FC = ({ true ); + // Update datasets when data is loaded useEffect(() => { if (objectVersions.result) { - setDatasets(objectVersions.result); + dispatch({ + type: ACTION_TYPES.SET_DATASETS, + payload: objectVersions.result, + }); } - }, [objectVersions.result]); - - const handleNext = () => { - const isNewDataset = selectedDataset === null; - if (isNewDataset) { - if (!newDatasetName?.trim()) { - setError('Please enter a dataset name'); - return; - } - if (!fieldConfigs.some(config => config.included)) { - setError('Please select at least one field to include'); - return; - } - - // Create field mappings from field configs - const newMappings = fieldConfigs - .filter(config => config.included) - .map(config => ({ - sourceField: config.sourceField, - targetField: config.targetField, - })); - setFieldMappings(newMappings); - - // Create an empty dataset object structure - const newDatasetObject = { - rows: [], - schema: fieldConfigs - .filter(config => config.included) - .map(config => ({ - name: config.targetField, - type: 'string', // You might want to infer the type from the source data - })), - }; - setDatasetObject(newDatasetObject); - } - setCurrentStep(prev => Math.min(prev + 1, 2)); - }; + }, [objectVersions.result, dispatch]); - const handleBack = () => { - setCurrentStep(prev => Math.max(prev - 1, 1)); - }; - - const handleDatasetSelect = (dataset: ObjectVersionSchema | null) => { - if (dataset?.objectId !== selectedDataset?.objectId) { - resetEditState(); - setSelectedDataset(dataset); - setIsCreatingNew(dataset === null); - } else { - setSelectedDataset(dataset); - setIsCreatingNew(dataset === null); - } - }; - - const handleMappingChange = (newMappings: FieldMapping[]) => { - if (JSON.stringify(newMappings) !== JSON.stringify(fieldMappings)) { - resetEditState(); - setFieldMappings(newMappings); - } else { - setFieldMappings(newMappings); + // Extract source schema from selected calls + useEffect(() => { + if (selectedCalls.length > 0) { + const extractedSchema = extractSourceSchema(selectedCalls); + dispatch({ + type: ACTION_TYPES.SET_SOURCE_SCHEMA, + payload: extractedSchema, + }); } - }; + }, [selectedCalls, dispatch]); const projectId = `${entity}/${project}`; - const resetDrawerState = useCallback(() => { - setCurrentStep(1); - setSelectedDataset(null); - setFieldMappings([]); - setDatasetObject(null); - setError(null); - }, []); - - const handleCreate = async () => { + const handleCreate = useCallback(async () => { if (!datasetObject) { return; } - setError(null); - setIsCreating(true); + dispatch({type: ACTION_TYPES.SET_ERROR, payload: null}); + dispatch({type: ACTION_TYPES.SET_IS_CREATING, payload: true}); + try { let result: any; const isNewDataset = selectedDataset === null; @@ -239,22 +185,39 @@ export const AddToDatasetDrawerInner: React.FC = ({ ); resetDrawerState(); + resetEditState(); onClose(); } catch (error) { console.error('Failed to create dataset version:', error); - setError( - error instanceof Error ? error.message : 'An unexpected error occurred' - ); + dispatch({ + type: ACTION_TYPES.SET_ERROR, + payload: + error instanceof Error + ? error.message + : 'An unexpected error occurred', + }); } finally { - setIsCreating(false); + dispatch({type: ACTION_TYPES.SET_IS_CREATING, payload: false}); } - }; - - const isNextDisabled = - currentStep === 1 && - ((selectedDataset === null && (!newDatasetName?.trim() || !isNameValid)) || - (selectedDataset === null && - !fieldConfigs.some(config => config.included))); + }, [ + datasetObject, + selectedDataset, + newDatasetName, + projectId, + entity, + project, + selectedCalls.length, + getRowsNoMeta, + tableCreate, + objCreate, + peekingRouter, + convertEditsToTableUpdateSpec, + tableUpdate, + dispatch, + resetDrawerState, + resetEditState, + onClose, + ]); const renderStepContent = () => { const isNewDataset = selectedDataset === null; @@ -269,10 +232,27 @@ export const AddToDatasetDrawerInner: React.FC = ({ setSelectedDataset={handleDatasetSelect} datasets={datasets} newDatasetName={newDatasetName} - setNewDatasetName={setNewDatasetName} - onValidationChange={setIsNameValid} + setNewDatasetName={name => + dispatch({ + type: ACTION_TYPES.SET_NEW_DATASET_NAME, + payload: name, + }) + } + onValidationChange={isValid => + dispatch({ + type: ACTION_TYPES.SET_IS_NAME_VALID, + payload: isValid, + }) + } entity={entity} project={project} + isCreatingNew={isCreatingNew} + setIsCreatingNew={isCreatingValue => + dispatch({ + type: ACTION_TYPES.SET_IS_CREATING_NEW, + payload: isCreatingValue, + }) + } /> {showSchemaConfig && ( <> @@ -285,14 +265,19 @@ export const AddToDatasetDrawerInner: React.FC = ({ fieldMappings={fieldMappings} datasetObject={datasetObject} onMappingChange={handleMappingChange} - onDatasetObjectLoaded={setDatasetObject} + onDatasetObjectLoaded={handleDatasetObjectLoaded} /> )} {isNewDataset && ( + dispatch({ + type: ACTION_TYPES.SET_FIELD_CONFIGS, + payload: configs, + }) + } /> )} @@ -318,7 +303,10 @@ export const AddToDatasetDrawerInner: React.FC = ({ open={open} onClose={onClose} defaultWidth={isFullscreen ? window.innerWidth - 73 : drawerWidth} - setWidth={width => !isFullscreen && setDrawerWidth(width)}> + setWidth={width => + !isFullscreen && + dispatch({type: ACTION_TYPES.SET_DRAWER_WIDTH, payload: width}) + }> {isCreating ? ( = ({ {currentStep === 2 && ( + + + + + + + + +
+ {state.addedRowsDirty.toString()} +
+
+ {state.processedRows.size.toString()} +
+ + ); +}; + +describe('DatasetDrawerContext', () => { + test('provides initial state', () => { + render( + {}} + entity="test-entity" + project="test-project"> + + + ); + + const stateElement = screen.getByTestId('state'); + const state = JSON.parse(stateElement.textContent || '{}'); + + expect(state.currentStep).toBe(1); + expect(state.selectedDataset).toBeNull(); + expect(state.fieldMappings).toEqual([]); + expect(state.isCreatingNew).toBe(false); + }); + + test('updates state when selecting a dataset', () => { + render( + {}} + entity="test-entity" + project="test-project"> + + + ); + + fireEvent.click(screen.getByTestId('select-dataset')); + + const stateElement = screen.getByTestId('state'); + const state = JSON.parse(stateElement.textContent || '{}'); + + expect(state.selectedDataset).toMatchObject({ + objectId: 'test-dataset', + entity: 'test-entity', + project: 'test-project', + versionHash: 'abc123', + versionIndex: 1, + }); + }); + + test('updates field mappings', () => { + render( + {}} + entity="test-entity" + project="test-project"> + + + ); + + fireEvent.click(screen.getByTestId('set-mappings')); + + const stateElement = screen.getByTestId('state'); + const state = JSON.parse(stateElement.textContent || '{}'); + + expect(state.fieldMappings).toEqual([ + {sourceField: 'text', targetField: 'text'}, + {sourceField: 'inputs.prompt', targetField: 'prompt'}, + ]); + }); + + test('navigates between steps', () => { + render( + {}} + entity="test-entity" + project="test-project"> + + + ); + + fireEvent.click(screen.getByTestId('next-step')); + + const stateElement = screen.getByTestId('state'); + const state = JSON.parse(stateElement.textContent || '{}'); + + expect(state.currentStep).toBe(2); + }); + + test('resets state', () => { + render( + {}} + entity="test-entity" + project="test-project"> + + + ); + + // First select a dataset and set mappings + fireEvent.click(screen.getByTestId('select-dataset')); + fireEvent.click(screen.getByTestId('set-mappings')); + + // Then reset + fireEvent.click(screen.getByTestId('reset')); + + const stateElement = screen.getByTestId('state'); + const state = JSON.parse(stateElement.textContent || '{}'); + + expect(state.selectedDataset).toBeNull(); + expect(state.fieldMappings).toEqual([]); + expect(state.currentStep).toBe(1); + }); + + test('handles create new mode correctly', () => { + render( + {}} + entity="test-entity" + project="test-project"> + + + ); + + // First select a dataset + fireEvent.click(screen.getByTestId('select-dataset')); + + // Then switch to create new mode + fireEvent.click(screen.getByTestId('set-creating-new')); + + const stateElement = screen.getByTestId('state'); + const state = JSON.parse(stateElement.textContent || '{}'); + + expect(state.isCreatingNew).toBe(true); + expect(state.selectedDataset).toBeNull(); + }); + + test('suggests field mappings based on schema similarity', () => { + render( + {}} + entity="test-entity" + project="test-project"> + + + ); + + // Set up schemas that should produce mapping suggestions + fireEvent.click(screen.getByTestId('setup-schemas-for-mapping')); + + // Get the updated state + const stateElement = screen.getByTestId('state'); + const state = JSON.parse(stateElement.textContent || '{}'); + + // Verify the current behavior of field mapping suggestions + + // 1. Should have the correct number of mappings + expect(state.fieldMappings.length).toBe(2); + + // 2. Should handle exact matches (text to text) + expect(state.fieldMappings).toContainEqual({ + targetField: 'text', + sourceField: 'text', + }); + + // 3. Should handle substring matches (prompt in inputs.prompt) + const promptMapping = state.fieldMappings.find( + (m: any) => m.targetField === 'prompt' + ); + expect(promptMapping).toBeDefined(); + expect(promptMapping?.sourceField).toBe('inputs.prompt'); + + // 4. Should NOT currently map model_type to inputs.model + const modelTypeMapping = state.fieldMappings.find( + (m: any) => m.targetField === 'model_type' + ); + expect(modelTypeMapping).toBeUndefined(); + + // 5. Should NOT currently map result to output + const resultMapping = state.fieldMappings.find( + (m: any) => m.targetField === 'result' + ); + expect(resultMapping).toBeUndefined(); + }); + + test('does not reprocess rows when navigating between steps unless mappings are modified', () => { + // Create mock calls data + const mockCalls = [ + { + id: 'call1', + inputs: {text: 'Sample input'}, + output: 'Sample output', + }, + ]; + + render( + {}} + entity="test-entity" + project="test-project"> + + + ); + + // Set up the necessary state for the test + fireEvent.click(screen.getByTestId('select-dataset')); + fireEvent.click(screen.getByTestId('set-dataset-object')); + + // Verify initial state + expect(screen.getByTestId('added-rows-dirty').textContent).toBe('true'); + + // Mark mappings as dirty to trigger row processing + fireEvent.click(screen.getByTestId('set-added-rows-dirty')); + + // Navigate to step 2 - this should process rows since addedRowsDirty is true + fireEvent.click(screen.getByTestId('next-step')); + + // Verify that addedRowsDirty was reset after processing + expect(screen.getByTestId('added-rows-dirty').textContent).toBe('false'); + + // Navigate back to step 1 + fireEvent.click(screen.getByTestId('prev-step')); + + // Navigate to step 2 again - this should NOT process rows since addedRowsDirty is false + const initialProcessedRowsSize = screen.getByTestId( + 'processed-rows-size' + ).textContent; + fireEvent.click(screen.getByTestId('next-step')); + + // Verify that processed rows size hasn't changed + expect(screen.getByTestId('processed-rows-size').textContent).toBe( + initialProcessedRowsSize + ); + + // Now modify mappings which should set addedRowsDirty to true + fireEvent.click(screen.getByTestId('prev-step')); + fireEvent.click(screen.getByTestId('set-mappings')); + + // Verify that addedRowsDirty is now true + expect(screen.getByTestId('added-rows-dirty').textContent).toBe('true'); + + // Navigate to step 2 again - this should process rows since addedRowsDirty is true + fireEvent.click(screen.getByTestId('next-step')); + + // Verify that addedRowsDirty was reset after processing + expect(screen.getByTestId('added-rows-dirty').textContent).toBe('false'); + }); +}); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/schemaUtils.test.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/schemaUtils.test.ts index 24a74f9e185c..f19d9382b28e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/schemaUtils.test.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/schemaUtils.test.ts @@ -64,6 +64,16 @@ describe('flattenObject', () => { const obj = {created: date}; expect(flattenObject(obj)).toEqual([{name: 'created', type: 'date'}]); }); + + test('handles null values', () => { + const obj = {value: null}; + expect(flattenObject(obj)).toEqual([{name: 'value', type: 'null'}]); + }); + + test('handles undefined values', () => { + const obj = {value: undefined}; + expect(flattenObject(obj)).toEqual([{name: 'value', type: 'undefined'}]); + }); }); describe('inferSchema', () => { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/schemaUtils.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/schemaUtils.ts index 9a18e6592665..c4b9e4556f2c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/schemaUtils.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/datasets/schemaUtils.ts @@ -1,5 +1,3 @@ -import {v4 as uuidv4} from 'uuid'; - import {TraceCallSchema} from '../pages/wfReactInterface/traceServerClientTypes'; export interface SchemaField { @@ -26,13 +24,13 @@ export const inferType = (value: any): string => { export const flattenObject = (obj: any, prefix = ''): SchemaField[] => { let fields: SchemaField[] = []; + // Return empty array for null or undefined inputs + if (obj == null) { + return fields; + } + // Special handling for __ref__ and __val__ pattern - if ( - obj != null && - typeof obj === 'object' && - '__ref__' in obj && - '__val__' in obj - ) { + if (typeof obj === 'object' && '__ref__' in obj && '__val__' in obj) { return flattenObject(obj.__val__, prefix); } @@ -92,10 +90,42 @@ export interface FieldMapping { targetField: string; } +/** + * Get a nested value from an object using a path array + * @param obj The object to extract value from + * @param path Array of property names to traverse + * @returns The value at the specified path or undefined if path doesn't exist + */ +export const getNestedValue = (obj: any, path: string[]): any => { + let current = obj; + for (const part of path) { + if (current == null) { + return undefined; + } + if (typeof current === 'object' && '__val__' in current) { + current = current.__val__; + } + if (typeof current !== 'object') { + return current; + } + current = current[part]; + } + return current; +}; + export const extractSourceSchema = (calls: CallData[]): SchemaField[] => { const allFields: SchemaField[] = []; + if (!calls || !Array.isArray(calls)) { + return allFields; + } + calls.forEach(call => { + // Skip if call or call.val is undefined + if (!call || !call.val) { + return; + } + if (call.val.inputs) { allFields.push(...flattenObject(call.val.inputs, 'inputs')); } @@ -195,7 +225,7 @@ export const suggestMappings = ( * Each returned row will: * - Be formatted for use in MUI DataGrid components * - Include a ___weave namespace containing metadata used by Weave's custom hooks and callbacks: - * - id: A unique identifier prefixed with "new-" + * - id: The digest of the call * - isNew: Flag indicating this is a newly created row * - Include mapped values from the call's inputs/outputs based on fieldMappings * - Only include fields where the source value is defined @@ -248,10 +278,182 @@ export const mapCallsToDatasetRows = ( return { ___weave: { - id: `new-${uuidv4()}`, + id: call.digest, isNew: true, }, ...row, }; }); }; + +/** + * Filters row data for new datasets based on target fields. + * + * @param mappedRows - The rows mapped from calls + * @param targetFields - Set of target field names to include + * @returns Filtered rows containing only the specified target fields + */ +export function filterRowsForNewDataset( + mappedRows: Array<{ + ___weave: {id: string; isNew: boolean}; + [key: string]: any; + }>, + targetFields: Set +): Array<{___weave: {id: string; isNew: boolean}; [key: string]: any}> { + return mappedRows + .map(row => { + try { + if (!row || typeof row !== 'object' || !row.___weave) { + return undefined; + } + + const {___weave, ...rest} = row; + const filteredData = Object.fromEntries( + Object.entries(rest).filter(([key]) => targetFields.has(key)) + ); + return { + ___weave, + ...filteredData, + }; + } catch (rowError) { + console.error('Error processing row:', rowError); + return undefined; + } + }) + .filter(row => row !== undefined) as Array<{ + ___weave: {id: string; isNew: boolean}; + [key: string]: any; + }>; +} + +/** + * Creates a map of processed rows with schema-based filtering. + * + * @param mappedRows - The rows mapped from calls + * @param datasetObject - The dataset object containing schema information + * @returns A Map of row IDs to processed row data + */ +export function createProcessedRowsMap( + mappedRows: Array<{ + ___weave: {id: string; isNew: boolean}; + [key: string]: any; + }>, + datasetObject: any +): Map { + return new Map( + mappedRows + .filter(row => row && row.___weave && row.___weave.id) + .map(row => { + // If datasetObject has a schema, filter row properties to match schema fields + if (datasetObject?.schema && Array.isArray(datasetObject.schema)) { + const schemaFields = new Set( + datasetObject.schema.map((f: {name: string}) => f.name) + ); + const {___weave, ...rest} = row; + + // Only include fields that are in the schema + const filteredData = Object.fromEntries( + Object.entries(rest).filter(([key]) => schemaFields.has(key)) + ); + + return [ + row.___weave.id, + { + ...filteredData, + ___weave: {...row.___weave, serverValue: filteredData}, + }, + ]; + } + + // Default case - keep all fields + return [ + row.___weave.id, + {...row, ___weave: {...row.___weave, serverValue: row}}, + ]; + }) + ); +} + +/** + * Suggests field mappings between source and target schemas. + * + * This function attempts to match fields between schemas using various strategies: + * 1. Preserves existing mappings if the fields still exist + * 2. Matches fields with identical names + * 3. Matches fields where one name contains the other + * + * @param sourceSchema - Array of fields in the source schema + * @param targetSchema - Array of fields in the target schema + * @param existingMappings - Optional array of existing mappings to preserve + * @returns Array of suggested field mappings + */ +export const suggestFieldMappings = ( + sourceSchema: any[], + targetSchema: any[], + existingMappings: FieldMapping[] = [] +): FieldMapping[] => { + if (!sourceSchema.length || !targetSchema.length) { + return existingMappings; + } + + // Create mapping table of existing mappings for quick lookup + const existingMappingsMap = new Map(); + existingMappings.forEach(mapping => { + existingMappingsMap.set(mapping.targetField, mapping.sourceField); + }); + + // Create a new array of suggested mappings + const newMappings: FieldMapping[] = []; + + // Attempt to match fields by name + targetSchema.forEach(targetField => { + // If there's already a mapping for this target field, keep it + if (existingMappingsMap.has(targetField.name)) { + newMappings.push({ + targetField: targetField.name, + sourceField: existingMappingsMap.get(targetField.name)!, + }); + return; + } + + // Try to find a matching source field by exact name + const exactMatch = sourceSchema.find( + sourceField => sourceField.name === targetField.name + ); + if (exactMatch) { + newMappings.push({ + targetField: targetField.name, + sourceField: exactMatch.name, + }); + return; + } + + // Try to find a matching source field by name containing the target field name + const containsMatch = sourceSchema.find(sourceField => + sourceField.name.toLowerCase().includes(targetField.name.toLowerCase()) + ); + if (containsMatch) { + newMappings.push({ + targetField: targetField.name, + sourceField: containsMatch.name, + }); + return; + } + + // Try to find a matching source field where the target field name contains the source field name + const reverseContainsMatch = sourceSchema.find(sourceField => + targetField.name.toLowerCase().includes(sourceField.name.toLowerCase()) + ); + if (reverseContainsMatch) { + newMappings.push({ + targetField: targetField.name, + sourceField: reverseContainsMatch.name, + }); + return; + } + + // No matches found, leave this target field unmapped + }); + + return newMappings; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/ObjectVersionPage.tsx index f57c951b17e4..0e94af8a7546 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/ObjectVersionPage.tsx @@ -62,6 +62,7 @@ const OBJECT_ICONS: Record = { Model: 'model', Dataset: 'table', Evaluation: 'baseline-alt', + EvaluationResults: 'baseline-alt', Leaderboard: 'benchmark-square', Scorer: 'type-number-alt', ActionSpec: 'rocket-launch', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/ObjectVersionsTable.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/ObjectVersionsTable.tsx index ec468459bafc..39fd9e018b78 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/ObjectVersionsTable.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectsPage/ObjectVersionsTable.tsx @@ -162,6 +162,9 @@ export const ObjectVersionsTable: React.FC<{ // with the dynamic fields added below. basicField('weave__object_version_link', props.objectTitle ?? 'Object', { hideable: false, + valueGetter: (unused: any, row: any) => { + return row.obj.objectId; + }, renderCell: cellParams => { // Icon to indicate navigation to the object version const obj: ObjectVersionSchema = cellParams.row.obj; @@ -244,7 +247,7 @@ export const ObjectVersionsTable: React.FC<{ if (!props.hideCategoryColumn) { cols.push( basicField('baseObjectClass', 'Category', { - width: 120, + width: 132, display: 'flex', valueGetter: (unused: any, row: any) => { return row.obj.baseObjectClass; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpsPage/OpVersionsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpsPage/OpVersionsPage.tsx index fe8c44b41ce4..5ca76bcbf26c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpsPage/OpVersionsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpsPage/OpVersionsPage.tsx @@ -109,6 +109,9 @@ export const FilterableOpVersionsTable: React.FC<{ const columns: GridColDef[] = [ basicField('op', 'Op', { hideable: false, + valueGetter: (unused: any, row: any) => { + return row.obj.opId; + }, renderCell: cellParams => { // Icon to indicate navigation to the object version const obj: OpVersionSchema = cellParams.row.obj; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatInput.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatInput.tsx index 2a9e0e6b8dd9..7e7b0b2e61af 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatInput.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatInput.tsx @@ -140,6 +140,7 @@ export const PlaygroundChatInput: React.FC = ({ variant="secondary" size="medium" startIcon="add-new" + disabled={isLoading || chatText.trim() === ''} onClick={() => handleAdd(addMessageRole, chatText)}> Add diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx index 23c30c7e5584..31968610a5da 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx @@ -67,17 +67,19 @@ export const useChatCompletionFunctions = ( setIsLoading(true); const newMessageContent = content || chatText; const newMessage = createMessage(role, newMessageContent, toolCallId); - const updatedStates = playgroundStates.map((state, index) => { - if (callIndex !== undefined && callIndex !== index) { - return state; - } - const updatedState = appendChoiceToMessages(state); - // If the new message is not empty, add it to the messages - if (newMessageContent && updatedState.traceCall?.inputs?.messages) { - updatedState.traceCall.inputs.messages.push(newMessage); - } - return updatedState; - }); + const updatedStates = filterNullMessagesFromStates( + playgroundStates.map((state, index) => { + if (callIndex !== undefined && callIndex !== index) { + return state; + } + const updatedState = appendChoiceToMessages(state); + // If the new message is not empty, add it to the messages + if (newMessageContent && updatedState.traceCall?.inputs?.messages) { + updatedState.traceCall.inputs.messages.push(newMessage); + } + return updatedState; + }) + ); setPlaygroundStates(updatedStates); setChatText(''); @@ -105,20 +107,25 @@ export const useChatCompletionFunctions = ( ) => { try { setIsLoading(true); - const updatedStates = playgroundStates.map((state, index) => { - if (index === callIndex) { - if (choiceIndex !== undefined) { - return appendChoiceToMessages(state, choiceIndex); - } - const updatedState = JSON.parse(JSON.stringify(state)); - if (updatedState.traceCall?.inputs?.messages) { - updatedState.traceCall.inputs.messages = - updatedState.traceCall.inputs.messages.slice(0, messageIndex + 1); + const updatedStates = filterNullMessagesFromStates( + playgroundStates.map((state, index) => { + if (index === callIndex) { + if (choiceIndex !== undefined) { + return appendChoiceToMessages(state, choiceIndex); + } + const updatedState = JSON.parse(JSON.stringify(state)); + if (updatedState.traceCall?.inputs?.messages) { + updatedState.traceCall.inputs.messages = + updatedState.traceCall.inputs.messages.slice( + 0, + messageIndex + 1 + ); + } + return updatedState; } - return updatedState; - } - return state; - }); + return state; + }) + ); const response = await makeCompletionRequest(callIndex, updatedStates); await handleErrorsAndUpdate( @@ -230,3 +237,56 @@ const appendChoiceToMessages = ( } return updatedState; }; + +/** + * Filters out null messages from a PlaygroundState + * + * @param state The PlaygroundState to filter + * @returns A new PlaygroundState with null messages filtered out + */ +export const filterNullMessages = (state: PlaygroundState): PlaygroundState => { + if ( + !state.traceCall || + !state.traceCall.inputs || + !state.traceCall.inputs.messages + ) { + return state; + } + + const messages = state.traceCall.inputs.messages as Message[]; + const filteredMessages = messages.filter( + message => + message !== null && + typeof message === 'object' && + (message.content !== null || message.tool_calls !== null) && + message.content !== '' + ); + + // Only create a new state if messages were actually filtered out + if (filteredMessages.length === messages.length) { + return state; + } + + return { + ...state, + traceCall: { + ...state.traceCall, + inputs: { + ...state.traceCall.inputs, + messages: filteredMessages, + }, + }, + }; +}; + +/** + * Filters out null messages from an array of PlaygroundStates + * + * @param states Array of PlaygroundStates to filter + * @returns A new array of PlaygroundStates with null messages filtered out + */ +export const filterNullMessagesFromStates = ( + states: PlaygroundState[] +): PlaygroundState[] => { + return states.map(filterNullMessages); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ZodSchemaForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ZodSchemaForm.tsx index 1a03082e9c47..b96b94fbc2d1 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ZodSchemaForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/ZodSchemaForm.tsx @@ -146,6 +146,49 @@ const DiscriminatedUnionField: React.FC<{ ); }; +const useNestedField = ( + config: Record, + setConfig: (config: Record) => void, + path: string[], + keyName: string +) => { + const currentPath = useMemo(() => [...path, keyName], [path, keyName]); + const [currentValue, setCurrentValue] = useState( + getNestedValue(config, currentPath) + ); + + // The text field interacts very poorly with the "auto-refresh" feature + // in the main app, causing cursor jumps and deletion. + // We handle state updates in the following way: + // 1. when the field is focused, update the local state to the current value + // from the parent config + // 2. only update the parent config when the field is blurred + // not on every keystroke. This prevents the cursor from being reset + // to the front of the field on external component updates. + const [isFocused, setIsFocused] = useState(false); + const handleBlur = useCallback(() => { + setIsFocused(false); + if (currentValue !== getNestedValue(config, currentPath)) { + updateConfig(currentPath, currentValue, config, setConfig); + } + }, [currentValue, currentPath, config, setConfig]); + const handleFocus = useCallback(() => { + setIsFocused(true); + }, []); + useEffect(() => { + if (!isFocused) { + setCurrentValue(getNestedValue(config, currentPath)); + } + }, [config, currentPath, isFocused]); + + // Update the local state when the user types + const handleChange = useCallback((value: string) => { + setCurrentValue(value); + }, []); + + return {currentPath, currentValue, handleBlur, handleFocus, handleChange}; +}; + const NestedForm: React.FC<{ keyName: string; fieldSchema: z.ZodTypeAny; @@ -163,25 +206,8 @@ const NestedForm: React.FC<{ hideLabel, autoFocus, }) => { - const currentPath = useMemo(() => [...path, keyName], [path, keyName]); - const [currentValue, setCurrentValue] = useState( - getNestedValue(config, currentPath) - ); - - // Only update parent config on blur, for string fields - const handleBlur = useCallback(() => { - if (currentValue !== getNestedValue(config, currentPath)) { - updateConfig(currentPath, currentValue, config, setConfig); - } - }, [currentValue, currentPath, config, setConfig]); - const handleChange = useCallback((value: string) => { - setCurrentValue(value); - }, []); - - // set current value for non-string fields - useEffect(() => { - setCurrentValue(getNestedValue(config, currentPath)); - }, [config, currentPath]); + const {currentPath, currentValue, handleBlur, handleFocus, handleChange} = + useNestedField(config, setConfig, path, keyName); const unwrappedSchema = unwrapSchema(fieldSchema); const isOptional = fieldSchema instanceof z.ZodOptional; @@ -316,6 +342,7 @@ const NestedForm: React.FC<{ value={currentValue ?? ''} onChange={handleChange} onBlur={handleBlur} + onFocus={handleFocus} autoFocus={autoFocus} /> ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx index 9ba8dce1d819..02817b0c8ab0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/TypeVersionCategoryChip.tsx @@ -13,6 +13,7 @@ const colorMap: Record = { ActionSpec: 'sienna', AnnotationSpec: 'magenta', SavedView: 'magenta', + EvaluationResults: 'moon', }; export const TypeVersionCategoryChip: React.FC<{ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts index 331e507a3d94..6ce5c7993d3a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/constants.ts @@ -29,4 +29,5 @@ export const KNOWN_BASE_OBJECT_CLASSES = [ 'ActionSpec', 'AnnotationSpec', 'SavedView', + 'EvaluationResults', ] as const; diff --git a/weave-js/src/components/Panel2/ImageWithOverlays.tsx b/weave-js/src/components/Panel2/ImageWithOverlays.tsx index 27039a4ec7ab..48b1156a12bc 100644 --- a/weave-js/src/components/Panel2/ImageWithOverlays.tsx +++ b/weave-js/src/components/Panel2/ImageWithOverlays.tsx @@ -45,6 +45,7 @@ interface CardImageProps { path: string; width: number; height: number; + caption?: string; }; imageFileNode: Node; masks?: Array<{loadedFrom: Node; path: string}>; @@ -77,6 +78,9 @@ export const CardImage: FC = ({ position: 'absolute', height: '100%', width: '100%', + top: 0, + left: 0, + objectFit: 'contain', } as const; return ( @@ -84,57 +88,77 @@ export const CardImage: FC = ({ data-test="card-image" style={{ height: '100%', - width: '100%', - position: 'relative', + overflow: 'auto', }}> {signedUrl == null ? (
) : ( <> {!hideImage && ( - {image.path} + <> +
+ {image.path} + {masks != null && + maskControls?.map((maskControl, i) => { + const mask = masks[i]; + if (maskControl != null) { + const classSet = (classSets ?? {})[ + maskControl.classSetID + ]; + return ( + + ); + } + return undefined; + })} + + {boundingBoxes != null && + boxControls?.map((boxControl, i) => { + if (boxControl != null) { + const classSet = classSets?.[boxControl.classSetID]; + return ( + + ); + } + return undefined; + })} +
+ + {image.caption && ( +
+ {image.caption} +
+ )} + )} - {masks != null && - maskControls?.map((maskControl, i) => { - const mask = masks[i]; - if (maskControl != null) { - const classSet = (classSets ?? {})[maskControl.classSetID]; - return ( - - ); - } - return undefined; - })} - - {boundingBoxes != null && - boxControls?.map((boxControl, i) => { - if (boxControl != null) { - const classSet = classSets?.[boxControl.classSetID]; - return ( - - ); - } - return undefined; - })} )}
diff --git a/weave-js/src/components/Panel2/PanelImage.tsx b/weave-js/src/components/Panel2/PanelImage.tsx index 07148ced750c..6464d5c8711c 100644 --- a/weave-js/src/components/Panel2/PanelImage.tsx +++ b/weave-js/src/components/Panel2/PanelImage.tsx @@ -209,6 +209,7 @@ const PanelImage: FC = ({config, input}) => { loadedFrom: imageArtifact, width: image?.width, height: image?.height, + caption: image?.caption, }; if (tileLayout === 'MASKS_NEXT_TO_IMAGE') { diff --git a/weave-js/src/components/Panel2/PanelRunHistoryTablesStepper/index.tsx b/weave-js/src/components/Panel2/PanelRunHistoryTablesStepper/index.tsx index 42a915a4eab3..604655bc7d17 100644 --- a/weave-js/src/components/Panel2/PanelRunHistoryTablesStepper/index.tsx +++ b/weave-js/src/components/Panel2/PanelRunHistoryTablesStepper/index.tsx @@ -1,5 +1,6 @@ import SliderInput from '@wandb/weave/common/components/elements/SliderInput'; import {TABLE_FILE_TYPE} from '@wandb/weave/common/types/file'; +import {getTableKeysFromNodeType} from '@wandb/weave/common/util/table'; import { constFunction, constNumber, @@ -44,6 +45,7 @@ type PanelRunHistoryTablesStepperProps = Panel2.PanelProps< const getTableKeysFromRunsHistoryPropertyType = ( runsHistoryPropertyType: Type | undefined ) => { + // Case where keys across runs vary (hence the union) if ( runsHistoryPropertyType && isUnion(nullableTaggableValue(listObjectType(runsHistoryPropertyType))) @@ -62,8 +64,16 @@ const getTableKeysFromRunsHistoryPropertyType = ( }, [] ); + return [...new Set(tableKeys)].sort(); } + + // Case where keys across runs are the same + if (runsHistoryPropertyType) { + const {tableKeys} = getTableKeysFromNodeType(runsHistoryPropertyType); + return tableKeys.sort(); + } + return []; }; @@ -123,7 +133,7 @@ const PanelRunHistoryTablesStepperConfig: React.FC< const PanelRunHistoryTablesStepper: React.FC< PanelRunHistoryTablesStepperProps > = props => { - const [currentStep, setCurrentStep] = useState(0); + const [currentStep, setCurrentStep] = useState(-1); const [steps, setSteps] = useState([]); const {input} = props; @@ -166,7 +176,7 @@ const PanelRunHistoryTablesStepper: React.FC< }); const exampleRowRefined = useNodeWithServerType(exampleRow); let defaultNode: NodeOrVoidNode = voidNode(); - if (currentStep) { + if (currentStep != null && currentStep >= 0 && tableHistoryKey) { // This performs the following weave expression: // runs.history.concat.filter((row) => row._step == )[].concat defaultNode = opConcat({ diff --git a/weave-js/src/core/_external/backendProviders/serverApiTest.ts b/weave-js/src/core/_external/backendProviders/serverApiTest.ts index 57d2f5946936..55096795cce3 100644 --- a/weave-js/src/core/_external/backendProviders/serverApiTest.ts +++ b/weave-js/src/core/_external/backendProviders/serverApiTest.ts @@ -1,6 +1,7 @@ import * as _ from 'lodash'; import {hash} from '../../model/graph/editing/hash'; +import type {DirMetadata, FileMetadata} from '../../model/types'; // CG imports import * as Types from '../../model/types'; import * as ServerApi from '../../serverApi'; @@ -1072,6 +1073,11 @@ const FILE_METADATA: { }, }; +// Todo: Not used in any test yet +const MEMBERSHIP_FILE_METADATA: { + [artifactId: string]: {[path: string]: Types.MetadataNode}; +} = {}; + type ArgsType = {[key: string]: any}; const toMany = ( @@ -1433,4 +1439,33 @@ export class Client implements ServerApi.ServerAPI { }, 1); }); } + + // TODO: NOT DONE + getArtifactMembershipFileMetadata( + artifactCollectionMembershipId: string, + entityName: string, + projectName: string, + collectionName: string, + artifactVersionIndex: string, + assetPath: string + ): Promise { + return new Promise(resolve => { + // Delay for testing + // TODO: Unnecessary + setTimeout(() => { + const metadata = + MEMBERSHIP_FILE_METADATA?.[artifactCollectionMembershipId]?.[ + assetPath + ]; + if (metadata == null) { + throw new Error( + `serverApiTest missing metadata for artifact membership's artifact path ${artifactCollectionMembershipId} "${assetPath}"` + ); + } + resolve(metadata); + // const contents = FILES?.[artifactId]?.[assetPath] ?? null; + // resolve({refFileId: null, contents}); + }, 1); + }); + } } diff --git a/weave-js/src/core/_external/backendProviders/serverApiTest2.ts b/weave-js/src/core/_external/backendProviders/serverApiTest2.ts index c64eebc77f01..8ea69272ed65 100644 --- a/weave-js/src/core/_external/backendProviders/serverApiTest2.ts +++ b/weave-js/src/core/_external/backendProviders/serverApiTest2.ts @@ -1,5 +1,10 @@ import * as String from '@wandb/weave/common/util/string'; -import {MetadataNode, ServerAPI} from '@wandb/weave/core'; +import { + type DirMetadata, + type FileMetadata, + MetadataNode, + ServerAPI, +} from '@wandb/weave/core'; import * as _ from 'lodash'; import * as Vega3 from '../util/vega3'; @@ -1016,6 +1021,11 @@ const FILE_METADATA: { }, }; +// Todo: Not used in any test yet +const MEMBERSHIP_FILE_METADATA: { + [artifactId: string]: {[path: string]: MetadataNode}; +} = {}; + function resolveRootProject(field: Vega3.QueryField) { const entityNameArg = field.args?.find(a => a.name === 'entityName')?.value; const projectNameArg = field.args?.find(a => a.name === 'name')?.value; @@ -1504,4 +1514,32 @@ export class Client implements ServerAPI { }, 1); }); } + + getArtifactMembershipFileMetadata( + artifactCollectionMembershipId: string, + entityName: string, + projectName: string, + collectionName: string, + artifactVersionIndex: string, + assetPath: string + ): Promise { + return new Promise(resolve => { + // Delay for testing + // TODO: Unnecessary + setTimeout(() => { + const metadata = + MEMBERSHIP_FILE_METADATA?.[artifactCollectionMembershipId]?.[ + assetPath + ]; + if (metadata == null) { + throw new Error( + `serverApiTest missing metadata for artifact membership's artifact path ${artifactCollectionMembershipId} "${assetPath}"` + ); + } + resolve(metadata); + // const contents = FILES?.[artifactId]?.[assetPath] ?? null; + // resolve({refFileId: null, contents}); + }, 1); + }); + } } diff --git a/weave-js/src/core/executeSync.ts b/weave-js/src/core/executeSync.ts index 239f2b72c657..83911161dd2c 100644 --- a/weave-js/src/core/executeSync.ts +++ b/weave-js/src/core/executeSync.ts @@ -116,6 +116,17 @@ class ThrowingPlaceholderServer implements ServerAPI { ): Promise { throw new Error(`Cannot getArtifactFileMetadata`); } + + getArtifactMembershipFileMetadata( + artifactCollectionMembershipId: string, + entityName: string, + projectName: string, + collectionName: string, + artifactVersionIndex: string, + assetPath: string + ): Promise { + throw new Error(`Cannot getArtifactMembershipFileMetadata`); + } } const syncResolverContext: ResolverContext = { diff --git a/weave-js/src/core/model/media/mediaImage.ts b/weave-js/src/core/model/media/mediaImage.ts index b87ba2c40681..efa56fe61929 100644 --- a/weave-js/src/core/model/media/mediaImage.ts +++ b/weave-js/src/core/model/media/mediaImage.ts @@ -25,6 +25,7 @@ export interface WBImage { [maskName: string]: MaskFile; }; classes?: ClassesFile; + caption?: string; } export interface ClassSet { diff --git a/weave-js/src/core/ops/domain/artifactMembership.ts b/weave-js/src/core/ops/domain/artifactMembership.ts index 456e4b519baf..e9ae34032aa7 100644 --- a/weave-js/src/core/ops/domain/artifactMembership.ts +++ b/weave-js/src/core/ops/domain/artifactMembership.ts @@ -116,9 +116,17 @@ export const opArtifactMembershipFile = makeStandardOp({ if (artifactMembership.artifact == null) { throw new Error('opArtifactMembershipFile missing artifact'); } + const artifactCollection = artifactMembership.artifactCollection; + if (artifactCollection == null) { + throw new Error('opArtifactMembershipFile missing artifactCollection'); + } try { - const result = await context.backend.getArtifactFileMetadata( - artifactMembership.artifact.id, + const result = await context.backend.getArtifactMembershipFileMetadata( + artifactMembership.id, + artifactCollection.project.entityName, + artifactCollection.project.name, + artifactCollection.name, + `v${artifactMembership.versionIndex}`, path ); if (result == null) { diff --git a/weave-js/src/core/ops/domain/gql.ts b/weave-js/src/core/ops/domain/gql.ts index 5d3696344d42..7ca411df8d15 100644 --- a/weave-js/src/core/ops/domain/gql.ts +++ b/weave-js/src/core/ops/domain/gql.ts @@ -1442,10 +1442,31 @@ export const toGqlField = ( ]; } else if (forwardOp.op.name === 'artifactMembership-file') { return [ + {name: 'versionIndex', fields: []}, { name: 'artifact', fields: gqlBasicField('id'), }, + { + name: 'artifactCollection', + fields: gqlBasicField('id').concat([ + { + name: 'defaultArtifactType', + fields: gqlBasicField('id').concat([{name: 'name', fields: []}]), + }, + {name: 'name', fields: []}, + { + name: 'project', + fields: gqlBasicField('id').concat([ + {name: 'name', fields: []}, + { + name: 'entityName', + fields: [], + }, + ]), + }, + ]), + }, ]; } else if (forwardOp.op.name === 'artifact-memberships') { return [ diff --git a/weave-js/src/core/serverApi.ts b/weave-js/src/core/serverApi.ts index 6cd345c1c15c..6793bf1637fa 100644 --- a/weave-js/src/core/serverApi.ts +++ b/weave-js/src/core/serverApi.ts @@ -27,6 +27,15 @@ export interface ServerAPI { assetPath: string ): Promise; + getArtifactMembershipFileMetadata( + artifactCollectionMembershipId: string, + entityName: string, + projectName: string, + collectionName: string, + artifactVersionIndex: string, + assetPath: string + ): Promise; + getRunFileContents( projectName: string, runName: string, diff --git a/weave/flow/dataset.py b/weave/flow/dataset.py index 7a3e9e7d1544..93f5be1e19dd 100644 --- a/weave/flow/dataset.py +++ b/weave/flow/dataset.py @@ -1,11 +1,13 @@ from collections.abc import Iterable, Iterator -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union from pydantic import field_validator from typing_extensions import Self import weave from weave.flow.obj import Object +from weave.flow.util import short_str +from weave.trace.isinstance import weave_isinstance from weave.trace.objectify import register_object from weave.trace.vals import WeaveObject, WeaveTable from weave.trace.weave_client import Call @@ -14,13 +16,6 @@ import pandas as pd -def short_str(obj: Any, limit: int = 25) -> str: - str_val = str(obj) - if len(str_val) > limit: - return str_val[:limit] + "..." - return str_val - - @register_object class Dataset(Object): """ @@ -47,7 +42,7 @@ class Dataset(Object): ``` """ - rows: weave.Table + rows: Union[weave.Table, WeaveTable] @classmethod def from_obj(cls, obj: WeaveObject) -> Self: @@ -77,11 +72,11 @@ def to_pandas(self) -> "pd.DataFrame": return pd.DataFrame(self.rows) @field_validator("rows", mode="before") - def convert_to_table(cls, rows: Any) -> weave.Table: + def convert_to_table(cls, rows: Any) -> Union[weave.Table, WeaveTable]: + if weave_isinstance(rows, WeaveTable): + return rows if not isinstance(rows, weave.Table): table_ref = getattr(rows, "table_ref", None) - if isinstance(rows, WeaveTable): - rows = list(rows) rows = weave.Table(rows) if table_ref: rows.table_ref = table_ref @@ -105,8 +100,7 @@ def __iter__(self) -> Iterator[dict]: return iter(self.rows) def __len__(self) -> int: - # TODO: This can be slow for large datasets... - return len(list(self.rows)) + return len(self.rows) def __getitem__(self, key: int) -> dict: if key < 0: diff --git a/weave/flow/eval.py b/weave/flow/eval.py index 4703000ce68e..c6ef49d0dbab 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -2,6 +2,7 @@ import logging import traceback from datetime import datetime +from itertools import chain, repeat from typing import Any, Callable, Literal, Optional, Union from pydantic import PrivateAttr @@ -28,9 +29,8 @@ ) from weave.flow.util import make_memorable_name, transpose from weave.trace.env import get_weave_parallelism -from weave.trace.errors import OpCallError from weave.trace.objectify import register_object -from weave.trace.op import CallDisplayNameFunc, Op, as_op, is_op +from weave.trace.op import CallDisplayNameFunc, Op, OpCallError, as_op, is_op from weave.trace.vals import WeaveObject from weave.trace.weave_client import Call, get_ref @@ -211,12 +211,14 @@ async def eval_example(example: dict) -> dict: n_complete = 0 dataset = self.dataset _rows = dataset.rows - trial_rows = list(_rows) * self.trials + num_rows = len(_rows) * self.trials + + trial_rows = chain.from_iterable(repeat(_rows, self.trials)) async for example, eval_row in util.async_foreach( trial_rows, eval_example, get_weave_parallelism() ): n_complete += 1 - print(f"Evaluated {n_complete} of {len(trial_rows)} examples") + print(f"Evaluated {n_complete} of {num_rows} examples") if eval_row is None: eval_row = {self._output_key: None, "scores": {}} else: diff --git a/weave/flow/model.py b/weave/flow/model.py index 0e7e9ba9e16a..06f385c08106 100644 --- a/weave/flow/model.py +++ b/weave/flow/model.py @@ -8,9 +8,8 @@ from rich import print from weave.flow.obj import Object -from weave.trace.errors import OpCallError from weave.trace.isinstance import weave_isinstance -from weave.trace.op import Op, as_op, is_op +from weave.trace.op import Op, OpCallError, as_op, is_op from weave.trace.op_caller import async_call_op from weave.trace.weave_client import Call diff --git a/weave/flow/scorer.py b/weave/flow/scorer.py index 6bcf8c04a012..75abb5673bb7 100644 --- a/weave/flow/scorer.py +++ b/weave/flow/scorer.py @@ -10,9 +10,8 @@ import weave from weave.flow.obj import Object -from weave.trace.errors import OpCallError from weave.trace.isinstance import weave_isinstance -from weave.trace.op import Op, as_op, is_op +from weave.trace.op import Op, OpCallError, as_op, is_op from weave.trace.op_caller import async_call_op from weave.trace.weave_client import Call, sanitize_object_name @@ -354,7 +353,7 @@ async def apply_scorer_async( scorer argument names: {score_arg_names} dataset keys: {example.keys()} - scorer.column_map: {getattr(scorer, 'column_map', '{}')} + scorer.column_map: {getattr(scorer, "column_map", "{}")} Options for resolving: a. if using the `Scorer` weave class, you can set the `scorer.column_map` attribute to map scorer argument names to dataset column names or diff --git a/weave/flow/util.py b/weave/flow/util.py index 3e61d7a12540..a324887c2cf9 100644 --- a/weave/flow/util.py +++ b/weave/flow/util.py @@ -242,3 +242,10 @@ def make_memorable_name() -> str: adj = random.choice(adjectives) noun = random.choice(nouns) return f"{adj}-{noun}" + + +def short_str(obj: Any, limit: int = 25) -> str: + str_val = str(obj) + if len(str_val) > limit: + return str_val[:limit] + "..." + return str_val diff --git a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py index 704605a837f5..18f2db85a6d2 100644 --- a/weave/integrations/google_ai_studio/google_ai_studio_sdk.py +++ b/weave/integrations/google_ai_studio/google_ai_studio_sdk.py @@ -8,7 +8,7 @@ from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.serialize import dictify +from weave.trace.serialization.serialize import dictify from weave.trace.weave_client import Call if TYPE_CHECKING: diff --git a/weave/integrations/huggingface/huggingface_inference_client_sdk.py b/weave/integrations/huggingface/huggingface_inference_client_sdk.py index da825608534e..9a6b0bda731e 100644 --- a/weave/integrations/huggingface/huggingface_inference_client_sdk.py +++ b/weave/integrations/huggingface/huggingface_inference_client_sdk.py @@ -6,7 +6,7 @@ from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.serialize import dictify +from weave.trace.serialization.serialize import dictify if TYPE_CHECKING: from huggingface_hub.inference._generated.types.chat_completion import ( diff --git a/weave/integrations/vertexai/vertexai_sdk.py b/weave/integrations/vertexai/vertexai_sdk.py index 75bc5f7881e6..a2a7e3b48715 100644 --- a/weave/integrations/vertexai/vertexai_sdk.py +++ b/weave/integrations/vertexai/vertexai_sdk.py @@ -8,7 +8,7 @@ from weave.integrations.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.serialize import dictify +from weave.trace.serialization.serialize import dictify from weave.trace.weave_client import Call if TYPE_CHECKING: diff --git a/weave/trace/context/weave_client_context.py b/weave/trace/context/weave_client_context.py index 03582a50cd27..d03e8cf8ef23 100644 --- a/weave/trace/context/weave_client_context.py +++ b/weave/trace/context/weave_client_context.py @@ -3,8 +3,6 @@ import threading from typing import TYPE_CHECKING -from weave.trace.errors import WeaveInitError - if TYPE_CHECKING: from weave.trace.weave_client import WeaveClient @@ -38,6 +36,9 @@ def get_weave_client() -> WeaveClient | None: return _global_weave_client +class WeaveInitError(Exception): ... + + def require_weave_client() -> WeaveClient: if (client := get_weave_client()) is None: raise WeaveInitError("You must call `weave.init()` first") diff --git a/weave/trace/errors.py b/weave/trace/errors.py deleted file mode 100644 index 412f1e854eaf..000000000000 --- a/weave/trace/errors.py +++ /dev/null @@ -1,28 +0,0 @@ -class Error(Exception): ... - - -class InternalError(Error): ... - - -class OpCallError(Error): ... - - -class WeaveTypeError(Error): ... - - -class WeaveSerializeError(Error): ... - - -class WeaveOpSerializeError(WeaveSerializeError): ... - - -class WeaveInitError(Error): ... - - -class WeaveDefinitionError(Error): ... - - -class WeaveWandbAuthenticationException(Error): ... - - -class WeaveConfigurationError(Error): ... diff --git a/weave/trace/op.py b/weave/trace/op.py index d5231ad466b3..97c98d452ab6 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -33,7 +33,6 @@ tracing_disabled, ) from weave.trace.context.tests_context import get_raise_on_captured_errors -from weave.trace.errors import OpCallError from weave.trace.refs import ObjectRef from weave.trace.util import log_once @@ -216,6 +215,9 @@ def _is_unbound_method(func: Callable) -> bool: return bool(is_method) +class OpCallError(Exception): ... + + def _default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedInputs: try: sig = inspect.signature(func) @@ -255,7 +257,7 @@ def _create_call( parent_call = call_context.get_current_call() attributes = call_attributes.get() - from weave.trace.serialize import dictify + from weave.trace.serialization.serialize import dictify attributes = dictify(attributes) diff --git a/weave/trace/serialization/README.md b/weave/trace/serialization/README.md new file mode 100644 index 000000000000..3fd3831ef558 --- /dev/null +++ b/weave/trace/serialization/README.md @@ -0,0 +1,176 @@ +# Serialization Patterns Audit in Weave (for Weave devs) + +## Overview + +This doc covers the current serialization patterns in Weave. For the purposes of this doc, contents are grouped by the files they are implemented in. Some concepts are currently split across multiple files, which is a limitation that should be addressed in the future. + +## Core Serialization Components + +### 1. Serialization Entrypoint (`serialize.py`) + +The entrypoint for serialization is `weave/trace/serialization/serialize.py`. This file principally contains: + +1. `to_json()`, which converts Python objects to JSON-serializable formats +2. `from_json()`, which converts JSON data back to Python objects + +#### Serialization Flow Diagram + +```mermaid +flowchart LR + UserCode1["User Code"] -->|Python Objects| WeaveClient + + subgraph WeaveClient["Weave Client"] + ClientSave["client.save()"] --> ToJson["to_json()"] + ToJson -->|JSON Data| ServerStorage["Server Storage"] + ServerStorage -->|JSON Data| FromJson["from_json()"] + FromJson --> ClientGet["client.get()"] + end + + WeaveClient -->|Python Objects| UserCode2["User Code"] +``` + +Today, serialization handles a variety of types, including: +| Type | Reversibility | +|------|---------------| +| Python primitives (int, float, str, bool, None) | Reversible | +| Collections (list, tuple, dict) | Reversible | +| Registered custom objects (via the `custom_objs` module) | Reversible | +| Arbitrary python objects | Not reversible | +| "Dictifiable" objects | Partially reversible (back into a dict) | + +NOTE: Not all serialization is reversible, which can be surprising and frustrating. We also handle weave-specific concepts like `ObjectRecord` and refs (`ObjectRef`, `TableRef`, etc.) + +### 2. Custom Object Serialization (`custom_objs.py` and `serializer.py`) + +NOTE: Custom object serialization is WIP and subject to change! + +The entrypoint for custom object serialization is `weave/trace/serialization/custom_objs.py`. This file primarily contains: + +1. `encode_custom_obj()`, which encodes custom objects using registered serializers +2. `decode_custom_obj()`, which decodes custom objects using the appropriate serializer + +#### Custom Object Serialization Flow + +```mermaid +flowchart TD + CustomObj["Custom Object (PIL.Image.Image, wave.Wave_read, etc.)"] --> ToJson["to_json()"] + ToJson --> EncodeCustomObj["encode_custom_obj()"] + + subgraph CustomSerialization["Custom Object Serialization Process"] + EncodeCustomObj --> Serializer["Registered Serializer (save method)"] + Serializer -->|Write to files| MemArtifact["MemTraceFilesArtifact"] + + MemArtifact -->|Read from files| Deserializer["Registered Serializer (load method)"] + Deserializer --> DecodeCustomObj["decode_custom_obj()"] + end + + DecodeCustomObj --> FromJson["from_json()"] + FromJson --> ReconstructedObj["Reconstructed Custom Object"] + + RegisterSerializer["register_serializer()"] -.->|Registers| Serializer + RegisterSerializer -.->|Registers| Deserializer +``` + +Custom serializers can be registered with `register_serializer`. This allows users to specify custom `save` and `load` methods for custom types. For portability, weave also packages the `load` function as an op in the saved object so it can try to be loaded even if the serializer is not registered in the target runtime (this is done on a best-effort basis). See more in the `Adding Custom Types` section. + +Weave also ships with a set of first-class serializers for common types, defined in `weave/type_handlers/`, including: + +- Images (`PIL.Image.Image`) +- Audio (`wave.Wave_read`) +- Op (`weave.Op`) + +These `KNOWN_TYPES` are implemented using `register_serializer`. Unlike other types, these will always be loaded using the SDK's current `load` function (instead of the one packaged with the object). + +#### File-based serialization + +Today, all custom object serialization is file-based via `MemTraceFilesArtifact`. Technicaly this supports many files per artifact, but in practice today we use just a single file (obj.py) for each artifact. Elements of this are hardcoded in both the SDK and App layers. + +#### Inline serialization + +Notably, this system is missing the ability to register inline serializers for custom objects. For example, the user may have a simple type like a custom datetime. This is a known limitation and will be addressed in the future. + +#### Other limitations + +1. File I/O can block the main thread for particularly large files. +2. There's not an easy way to know what types are registered, or their relative order. This may become more relevant as users begin to bring their own types, and especially if there is a conflicting type hierarchy (e.g. a registered `torch.Tensor` serializer + a subclass of `torch.Tensor` serializer.) +3. The current file-based approach also requires a network request for each file. This can be a problematic, especially in the dataset case, where a row may contain many custom objects that each have to reach out to the server to load. +4. There is no file-deduplication for custom objects. If a custom object is saved multiple times, it will use multiple files. This might be surprising, e.g. in the dataset case where a single image may be present in multiple rows. + +### Fallback serialization + +For objects that are not explicitly registered, there are a few (lossy) fallback mechanisms, including: + +1. `Dictifiable` -- objects that can implement `to_dict() -> dict[str, Any]` will attempt to be serialized as a dict. The actual mechanics are up to the object's implementation, but usually this involves dumping the object's public attributes to a dict. +2. `stringify` -- in the worst case, python objects will be serialized as the object's repr string. + +## Adding Custom Types + +Users who want to add their own custom types to Weave can do so by registering custom serializers. This section provides a guide for users who are not Weave developers but want to ensure their custom types can be properly serialized and deserialized. + +### 1. Creating Serialization Functions + +To add support for a custom type, you need to define two functions: + +1. **Save Function**: Responsible for serializing the object to files +2. **Load Function**: Responsible for deserializing the object from files + +Example: + +```python +from weave.trace.serialization import serializer +from weave.trace.serialization.custom_objs import MemTraceFilesArtifact + +class MyCustomType: + def __init__(self, value): + self.value = value + +def save_my_type(obj: MyCustomType, artifact: MemTraceFilesArtifact, name: str) -> None: + # Save the object's data to a file in the artifact + with artifact.new_file(f"{name}.txt") as f: + f.write(str(obj.value)) + +def load_my_type(artifact: MemTraceFilesArtifact, name: str) -> MyCustomType: + # Load the object's data from the file in the artifact + with artifact.open(f"{name}.txt") as f: + value = f.read() + return MyCustomType(value) +``` + +### 2. Registering the Serializer + +Once you have defined the save and load functions, you need to register them with Weave: + +```python +from weave.trace.serialization import serializer + +# Register the serializer for your custom type +serializer.register_serializer(MyCustomType, save_my_type, load_my_type) +``` + +### 3. Using Custom Types with Weave + +After registering the serializer, you can use your custom type with Weave operations: + +```python +import weave + +# Create an instance of your custom type +my_obj = MyCustomType("Hello, Weave!") + +# Save the object to Weave +client = weave.init() +ref = client.save(my_obj, "my-custom-object") + +# Retrieve the object from Weave +retrieved_obj = client.get(ref) +``` + +### 4. Cross-Runtime Compatibility + +When a custom object is saved, Weave also saves the load function as an op. This allows the object to be deserialized in a Python runtime that does not have the serializer registered, as long as the necessary dependencies are available. This is done on a best-effort basis, and the object will not be loadable if any of the dependencies are not available. + +We don't currently save a lock file of the dependencies, but we might want to if we want to provide better portability in future. + +## Server-Side Object Registration (WIP) + +Weave is also in the process of adding support for server-side object registration. Currently, we have limited support for server-side object registration through the `BUILTIN_OBJECT_REGISTRY` diff --git a/weave/trace/custom_objs.py b/weave/trace/serialization/custom_objs.py similarity index 93% rename from weave/trace/custom_objs.py rename to weave/trace/serialization/custom_objs.py index b04c9ba50a0d..9cc5e4ef823f 100644 --- a/weave/trace/custom_objs.py +++ b/weave/trace/serialization/custom_objs.py @@ -3,12 +3,17 @@ from collections.abc import Mapping from typing import Any, Callable -from weave.trace import op_type # noqa: F401, Must import this to register op save/load from weave.trace.context.weave_client_context import require_weave_client -from weave.trace.mem_artifact import MemTraceFilesArtifact from weave.trace.op import Op, op from weave.trace.refs import ObjectRef, parse_uri -from weave.trace.serializer import get_serializer_by_id, get_serializer_for_obj +from weave.trace.serialization import ( + op_type, # noqa: F401, Must import this to register op save/load +) +from weave.trace.serialization.mem_artifact import MemTraceFilesArtifact +from weave.trace.serialization.serializer import ( + get_serializer_by_id, + get_serializer_for_obj, +) class DecodeCustomObjectError(Exception): diff --git a/weave/trace/mem_artifact.py b/weave/trace/serialization/mem_artifact.py similarity index 96% rename from weave/trace/mem_artifact.py rename to weave/trace/serialization/mem_artifact.py index 456ffeb04c00..5e90593224ec 100644 --- a/weave/trace/mem_artifact.py +++ b/weave/trace/serialization/mem_artifact.py @@ -6,7 +6,9 @@ from collections.abc import Generator, Iterator, Mapping from io import BytesIO, StringIO -from weave.trace import op_type # noqa: F401, Must import this to register op save/load +from weave.trace.serialization import ( + op_type, # noqa: F401, Must import this to register op save/load +) # This uses the older weave query_service's Artifact interface. We could # probably simplify a lot at this point by removing the internal requirement diff --git a/weave/trace/op_type.py b/weave/trace/serialization/op_type.py similarity index 98% rename from weave/trace/op_type.py rename to weave/trace/serialization/op_type.py index 089d849224c1..8a90d88ca888 100644 --- a/weave/trace/op_type.py +++ b/weave/trace/serialization/op_type.py @@ -13,17 +13,18 @@ from _ast import AsyncFunctionDef, ExceptHandler from typing import Any, Callable, TypedDict, get_args, get_origin -from weave.trace import serializer, settings +from weave.trace import settings from weave.trace.context.weave_client_context import get_weave_client from weave.trace.ipython import ( ClassNotFoundError, get_class_source, is_running_interactively, ) -from weave.trace.mem_artifact import MemTraceFilesArtifact from weave.trace.op import Op, as_op, is_op from weave.trace.refs import ObjectRef from weave.trace.sanitize import REDACTED_VALUE, should_redact +from weave.trace.serialization import serializer +from weave.trace.serialization.mem_artifact import MemTraceFilesArtifact from weave.trace_server.trace_server_interface_util import str_digest WEAVE_OP_PATTERN = re.compile(r"@weave\.op(\(\))?") @@ -425,7 +426,7 @@ def _get_code_deps( if (client := get_weave_client()) is None: raise ValueError("Weave client not found") - from weave.trace.serialize import to_json + from weave.trace.serialization.serialize import to_json # Redact sensitive values if should_redact(var_name): diff --git a/weave/trace/serialize.py b/weave/trace/serialization/serialize.py similarity index 99% rename from weave/trace/serialize.py rename to weave/trace/serialization/serialize.py index 521f1d8a8205..85c60632d32e 100644 --- a/weave/trace/serialize.py +++ b/weave/trace/serialization/serialize.py @@ -7,10 +7,10 @@ from pydantic import BaseModel -from weave.trace import custom_objs from weave.trace.object_record import ObjectRecord from weave.trace.refs import ObjectRef, TableRef, parse_uri from weave.trace.sanitize import REDACTED_VALUE, should_redact +from weave.trace.serialization import custom_objs from weave.trace.serialization.dictifiable import try_to_dict from weave.trace_server.interface.builtin_object_classes.builtin_object_registry import ( BUILTIN_OBJECT_REGISTRY, diff --git a/weave/trace/serializer.py b/weave/trace/serialization/serializer.py similarity index 100% rename from weave/trace/serializer.py rename to weave/trace/serialization/serializer.py diff --git a/weave/trace/vals.py b/weave/trace/vals.py index f46d5badb05b..fdeb9bc63e16 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -3,7 +3,7 @@ import logging import operator import typing -from collections.abc import Generator, Iterator +from collections.abc import Generator, Iterator, Sequence from copy import deepcopy from typing import Any, Literal, Optional, SupportsIndex, Union @@ -13,7 +13,6 @@ from weave.trace import box from weave.trace.context.tests_context import get_raise_on_captured_errors from weave.trace.context.weave_client_context import get_weave_client -from weave.trace.errors import InternalError from weave.trace.object_record import ObjectRecord from weave.trace.op import is_op, maybe_bind_method from weave.trace.refs import ( @@ -26,15 +25,17 @@ RefWithExtra, TableRef, ) -from weave.trace.serialize import from_json +from weave.trace.serialization.serialize import from_json from weave.trace.table import Table from weave.trace_server.errors import ObjectDeletedError from weave.trace_server.trace_server_interface import ( ObjReadReq, TableQueryReq, + TableQueryStatsReq, TableRowFilter, TraceServerInterface, ) +from weave.utils.iterators import ThreadSafeLazyList logger = logging.getLogger(__name__) @@ -263,6 +264,11 @@ def __eq__(self, other: Any) -> bool: class WeaveTable(Traceable): filter: TableRowFilter + _known_length: Optional[int] = None + _rows: Optional[Sequence[dict]] = None + # _prefetched_rows is a local cache of rows that can be used to + # avoid a remote call. Should only be used by internal code. + _prefetched_rows: Optional[list[dict]] = None def __init__( self, @@ -279,14 +285,9 @@ def __init__( self.server = server self.root = root or self self.parent = parent - self._rows: Optional[list[dict]] = None - - # _prefetched_rows is a local cache of rows that can be used to - # avoid a remote call. Should only be used by internal code. - self._prefetched_rows: Optional[list[dict]] = None @property - def rows(self) -> list[dict]: + def rows(self) -> Sequence[dict]: if self._rows is None: should_local_iter = ( self.ref is not None @@ -295,19 +296,32 @@ def rows(self) -> list[dict]: and self._prefetched_rows is not None ) if should_local_iter: - self._rows = list(self._local_iter_with_remote_fallback()) + self._rows = ThreadSafeLazyList(self._local_iter_with_remote_fallback()) else: - self._rows = list(self._remote_iter()) + self._rows = ThreadSafeLazyList(self._remote_iter()) return self._rows @rows.setter - def rows(self, value: list[dict]) -> None: + def rows(self, value: Sequence[dict]) -> None: if not all(isinstance(row, dict) for row in value): raise ValueError("All table rows must be dicts") self._rows = value self._mark_dirty() + def _inefficiently_materialize_rows_as_list(self) -> list[dict]: + # This method is named `inefficiently` to warn callers that + # it should be avoided. We have this nasty paradigm where sometimes + # a WeaveTable needs to act like a list, but it is actually a remote + # table. This method will force iteration through the remote data + # and materialize it into a list. Any uses of this are signs of a design + # problem arising from a remote table clashing with the need to feel like + # a local list. + if not isinstance(self.rows, list): + self._rows = list(iter(self.rows)) + self._known_length = len(self._rows) + return typing.cast(list[dict], self.rows) + def set_prefetched_rows(self, prefetched_rows: list[dict]) -> None: """Sets the rows to a local cache of rows that can be used to avoid a remote call. Should only be used by internal code. @@ -323,14 +337,54 @@ def set_prefetched_rows(self, prefetched_rows: list[dict]) -> None: self._prefetched_rows = prefetched_rows def __len__(self) -> int: - return len(self.rows) + # This should be a single query + if self._known_length is not None: + return self._known_length + + # Condition 1: we already have all the rows in memory + if self._prefetched_rows is not None: + self._known_length = len(self._prefetched_rows) + return self._known_length + + # Condition 2: we have the row digests and they are a list + if ( + self.table_ref is not None + and self.table_ref._row_digests is not None + and isinstance(self.table_ref._row_digests, list) + ): + self._known_length = len(self.table_ref._row_digests) + return self._known_length + + # Condition 3: We don't know the length, in which case we can get it from the server + if self.table_ref is not None: + self._known_length = self._fetch_remote_length() + return self._known_length + + # Finally, if we have no table ref, we can still get the length + # by materializing the rows as a list. I actually think this + # can never happen, but it is here for completeness. + rows_as_list = self._inefficiently_materialize_rows_as_list() + return len(rows_as_list) + + def _fetch_remote_length(self) -> int: + if self.table_ref is None: + raise ValueError("Cannot fetch remote length of table without table ref") + + response = self.server.table_query_stats( + TableQueryStatsReq( + project_id=self.table_ref.project_id, digest=self.table_ref.digest + ) + ) + return response.count def __eq__(self, other: Any) -> bool: - return self.rows == other + rows = self._inefficiently_materialize_rows_as_list() + return rows == other def _mark_dirty(self) -> None: self.table_ref = None self._prefetched_rows = None + self._known_length = None super()._mark_dirty() def _local_iter_with_remote_fallback(self) -> Generator[dict, None, None]: @@ -398,15 +452,27 @@ def _remote_iter(self) -> Generator[dict, None, None]: ) ) - if self._prefetched_rows is not None and len(response.rows) != len( - self._prefetched_rows - ): - if get_raise_on_captured_errors(): - raise - logger.error( - f"Expected length of response rows ({len(response.rows)}) to match prefetched rows ({len(self._prefetched_rows)}). Ignoring prefetched rows." - ) - self._prefetched_rows = None + # When paginating through large datasets, we need special handling for prefetched rows + # on the first page. This is because prefetched_rows contains ALL rows, while each + # response page contains at most page_size rows. + if page_index == 0 and self._prefetched_rows is not None: + response_rows_len = len(response.rows) + prefetched_rows_len = len(self._prefetched_rows) + + # There are two valid scenarios: + # 1. The response rows exactly match prefetched rows (small dataset, no pagination needed) + # 2. We're paginating a large dataset (response has page_size rows, prefetched has more) + # + # Any other mismatch indicates an inconsistency that should be handled by + # discarding the prefetched rows and relying solely on server responses. + if response_rows_len != prefetched_rows_len and not ( + response_rows_len == page_size and prefetched_rows_len > page_size + ): + msg = f"Expected length of response rows ({response_rows_len}) to match prefetched rows ({prefetched_rows_len}). Ignoring prefetched rows." + if get_raise_on_captured_errors(): + raise ValueError(msg) + logger.debug(msg) + self._prefetched_rows = None for i, item in enumerate(response.rows): new_ref = self.ref.with_item(item.digest) if self.ref else None @@ -419,7 +485,7 @@ def _remote_iter(self) -> Generator[dict, None, None]: val = ( item.val if self._prefetched_rows is None - else self._prefetched_rows[i] + else self._prefetched_rows[page_index * page_size + i] ) res = from_json(val, self.table_ref.project_id, self.server) res = make_trace_obj(res, new_ref, self.server, self.root) @@ -431,7 +497,11 @@ def _remote_iter(self) -> Generator[dict, None, None]: page_index += 1 def __getitem__(self, key: Union[int, slice, str]) -> Any: - rows = self.rows + # TODO: ideally we would have some sort of intelligent + # LRU style caching that allows us to minimize materialization + # of the rows as a list. + rows = self._inefficiently_materialize_rows_as_list() + if isinstance(key, (int, slice)): return rows[key] @@ -445,14 +515,16 @@ def __iter__(self) -> Iterator[dict]: return iter(self.rows) def append(self, val: dict) -> None: + rows = self._inefficiently_materialize_rows_as_list() if not isinstance(val, dict): raise TypeError("Can only append dicts to tables") self._mark_dirty() - self.rows.append(val) + rows.append(val) def pop(self, index: int) -> None: + rows = self._inefficiently_materialize_rows_as_list() self._mark_dirty() - self.rows.pop(index) + rows.pop(index) class WeaveList(Traceable, list): @@ -613,6 +685,9 @@ def __eq__(self, other: Any) -> bool: return True +class InternalError(Exception): ... + + def make_trace_obj( val: Any, new_ref: Optional[RefWithExtra], # Can this actually be None? diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index bc5db6bc6806..6500a141151f 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -58,8 +58,12 @@ parse_uri, ) from weave.trace.sanitize import REDACTED_VALUE, should_redact -from weave.trace.serialize import from_json, isinstance_namedtuple, to_json -from weave.trace.serializer import get_serializer_for_obj +from weave.trace.serialization.serialize import ( + from_json, + isinstance_namedtuple, + to_json, +) +from weave.trace.serialization.serializer import get_serializer_for_obj from weave.trace.settings import ( client_parallelism, should_capture_client_info, @@ -1987,7 +1991,7 @@ def _flush(self) -> None: # _server_is_flushable and only call this if we know the server is # flushable. The # type: ignore is safe because we check the type # first. - self.server.call_processor.wait_until_all_processed() # type: ignore + self.server.call_processor.stop_accepting_new_work_and_flush_queue() # type: ignore def _send_file_create(self, req: FileCreateReq) -> Future[FileCreateRes]: if self.future_executor_fastlane: diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index f09837d2b3f8..573cf0e0a6b4 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -1,6 +1,6 @@ from __future__ import annotations -from weave.trace import autopatch, errors, init_message, trace_sentry, weave_client +from weave.trace import autopatch, init_message, trace_sentry, weave_client from weave.trace.context import weave_client_context as weave_client_context from weave.trace.settings import should_redact_pii, use_server_cache from weave.trace_server import sqlite_trace_server @@ -33,6 +33,9 @@ def get_username() -> str | None: return None +class WeaveWandbAuthenticationException(Exception): ... + + def get_entity_project_from_project_name(project_name: str) -> tuple[str, str]: from weave.wandb_interface import wandb_api @@ -41,7 +44,7 @@ def get_entity_project_from_project_name(project_name: str) -> tuple[str, str]: api = wandb_api.get_wandb_api_sync() entity_name = api.default_entity_name() if entity_name is None: - raise errors.WeaveWandbAuthenticationException( + raise WeaveWandbAuthenticationException( 'weave init requires wandb. Run "wandb login"' ) project_name = fields[0] diff --git a/weave/trace_server/async_batch_processor.py b/weave/trace_server/async_batch_processor.py deleted file mode 100644 index a8a183d94bfe..000000000000 --- a/weave/trace_server/async_batch_processor.py +++ /dev/null @@ -1,78 +0,0 @@ -import atexit -import logging -import time -from queue import Queue -from threading import Event, Lock, Thread -from typing import Callable, Generic, TypeVar - -from weave.trace.context.tests_context import get_raise_on_captured_errors - -T = TypeVar("T") -logger = logging.getLogger(__name__) - - -class AsyncBatchProcessor(Generic[T]): - """A class that asynchronously processes batches of items using a provided processor function.""" - - def __init__( - self, - processor_fn: Callable[[list[T]], None], - max_batch_size: int = 100, - min_batch_interval: float = 1.0, - ) -> None: - """ - Initializes an instance of AsyncBatchProcessor. - - Args: - processor_fn (Callable[[list[T]], None]): The function to process the batches of items. - max_batch_size (int, optional): The maximum size of each batch. Defaults to 100. - min_batch_interval (float, optional): The minimum interval between processing batches. Defaults to 1.0. - """ - self.processor_fn = processor_fn - self.max_batch_size = max_batch_size - self.min_batch_interval = min_batch_interval - self.queue: Queue[T] = Queue() - self.lock = Lock() - self.stop_event = Event() # Use an event to signal stopping - self.processing_thread = Thread(target=self._process_batches) - self.processing_thread.daemon = True - self.processing_thread.start() - atexit.register(self.wait_until_all_processed) # Register cleanup function - - def enqueue(self, items: list[T]) -> None: - """ - Enqueues a list of items to be processed. - - Args: - items (list[T]): The items to be processed. - """ - with self.lock: - for item in items: - self.queue.put(item) - - def _process_batches(self) -> None: - """Internal method that continuously processes batches of items from the queue.""" - while True: - current_batch: list[T] = [] - while not self.queue.empty() and len(current_batch) < self.max_batch_size: - current_batch.append(self.queue.get()) - - if current_batch: - try: - self.processor_fn(current_batch) - except Exception as e: - if get_raise_on_captured_errors(): - raise - logger.exception(f"Error processing batch: {e}") - - if self.stop_event.is_set() and self.queue.empty(): - break - - # Unless we are stopping, sleep for a the min_batch_interval - if not self.stop_event.is_set(): - time.sleep(self.min_batch_interval) - - def wait_until_all_processed(self) -> None: - """Waits until all enqueued items have been processed.""" - self.stop_event.set() - self.processing_thread.join() diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 6fad23637d30..e8683d175a65 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -35,6 +35,7 @@ from zoneinfo import ZoneInfo import clickhouse_connect +import ddtrace import emoji from clickhouse_connect.driver.client import Client as CHClient from clickhouse_connect.driver.query import QueryResult @@ -388,6 +389,7 @@ def row_to_call_schema_dict(row: tuple[Any, ...]) -> dict[str, Any]: for call in call_dicts: yield tsi.CallSchema.model_validate(call) + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._add_feedback_to_calls") def _add_feedback_to_calls( self, project_id: str, calls: list[dict[str, Any]] ) -> None: @@ -422,6 +424,7 @@ def _get_refs_to_resolve( refs_to_resolve[(i, col)] = ref return refs_to_resolve + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._expand_call_refs") def _expand_call_refs( self, project_id: str, @@ -879,6 +882,7 @@ def table_query_stream( ) yield from rows + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._table_query_stream") def _table_query_stream( self, project_id: str, @@ -972,6 +976,7 @@ def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes: return tsi.RefsReadBatchRes(vals=vals) + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._parsed_refs_read_batch") def _parsed_refs_read_batch( self, parsed_refs: ObjRefListType, @@ -1001,6 +1006,9 @@ def make_ref_cache_key(ref: ri.InternalObjectRef) -> str: # Return the final data payload return [final_result_cache[make_ref_cache_key(ref)] for ref in parsed_refs] + @ddtrace.tracer.wrap( + name="clickhouse_trace_server_batched._refs_read_batch_within_project" + ) def _refs_read_batch_within_project( self, project_id_scope: str, @@ -1250,6 +1258,7 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: self._file_create_clickhouse(req, digest) return tsi.FileCreateRes(digest=digest) + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._file_create_clickhouse") def _file_create_clickhouse(self, req: tsi.FileCreateReq, digest: str) -> None: chunks = [ req.content[i : i + FILE_CHUNK_SIZE] @@ -1282,6 +1291,7 @@ def _file_create_clickhouse(self, req: tsi.FileCreateReq, digest: str) -> None: ], ) + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._file_create_bucket") def _file_create_bucket( self, req: tsi.FileCreateReq, digest: str, base_file_storage_uri: FileStorageURI ) -> None: @@ -1713,6 +1723,7 @@ def with_new_client(self) -> Iterator[None]: # def __del__(self) -> None: # self.ch_client.close() + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._insert_call_batch") def _insert_call_batch(self, batch: list) -> None: if batch: settings = {} @@ -1778,6 +1789,7 @@ def _run_migrations(self) -> None: migrator = wf_migrator.ClickHouseTraceServerMigrator(self._mint_client()) migrator.apply_migrations(self._database) + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._query_stream") def _query_stream( self, query: str, @@ -1811,6 +1823,7 @@ def _query_stream( ) yield from stream + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._query") def _query( self, query: str, @@ -1832,6 +1845,7 @@ def _query( ) return res + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._insert") def _insert( self, table: str, @@ -1858,6 +1872,7 @@ def _insert( ) raise + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._insert_call") def _insert_call(self, ch_call: CallCHInsertable) -> None: parameters = ch_call.model_dump() row = [] @@ -1867,6 +1882,7 @@ def _insert_call(self, ch_call: CallCHInsertable) -> None: if self._flush_immediately: self._flush_calls() + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._flush_calls") def _flush_calls(self) -> None: try: self._insert_call_batch(self._call_batch) @@ -1877,6 +1893,7 @@ def _flush_calls(self) -> None: self._call_batch = [] + @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._strip_large_values") def _strip_large_values(self, batch: list[list[Any]]) -> list[list[Any]]: """ Iterate through the batch and replace large values with placeholders. diff --git a/weave/trace_server_bindings/async_batch_processor.py b/weave/trace_server_bindings/async_batch_processor.py new file mode 100644 index 000000000000..010e627a4bc9 --- /dev/null +++ b/weave/trace_server_bindings/async_batch_processor.py @@ -0,0 +1,111 @@ +import atexit +import logging +import time +from queue import Empty, Full, Queue +from threading import Event, Lock, Thread +from typing import Callable, Generic, TypeVar + +from weave.trace.context.tests_context import get_raise_on_captured_errors + +T = TypeVar("T") +logger = logging.getLogger(__name__) + + +class AsyncBatchProcessor(Generic[T]): + """A class that asynchronously processes batches of items using a provided processor function.""" + + def __init__( + self, + processor_fn: Callable[[list[T]], None], + max_batch_size: int = 100, + min_batch_interval: float = 1.0, + max_queue_size: int = 10_000, + ) -> None: + """ + Initializes an instance of AsyncBatchProcessor. + + Args: + processor_fn (Callable[[list[T]], None]): The function to process the batches of items. + max_batch_size (int, optional): The maximum size of each batch. Defaults to 100. + min_batch_interval (float, optional): The minimum interval between processing batches. Defaults to 1.0. + max_queue_size (int, optional): The maximum number of items to hold in the queue. Defaults to 10_000. 0 means no limit. + """ + self.processor_fn = processor_fn + self.max_batch_size = max_batch_size + self.min_batch_interval = min_batch_interval + self.queue: Queue[T] = Queue(maxsize=max_queue_size) + self.lock = Lock() + self.stop_accepting_work_event = Event() + self.processing_thread = Thread(target=self._process_batches) + self.processing_thread.daemon = True + self.processing_thread.start() + + # TODO: Probably should include a health check thread here. It will revive the + # processing thread if that thread dies. + + # TODO: Probably should include some sort of local write buffer. It might not need + # to be here, but it should exist. That handles 2 cases: + # 1. The queue is full, so users can sync up data later. + # 2. The system crashes for some reason, so users can resume from the local buffer. + + atexit.register(self.stop_accepting_new_work_and_flush_queue) + + def enqueue(self, items: list[T]) -> None: + """ + Enqueues a list of items to be processed. + + Args: + items (list[T]): The items to be processed. + """ + with self.lock: + for item in items: + try: + self.queue.put_nowait(item) + except Full: + # TODO: This is probably not what you want, but it will prevent OOM for now. + logger.warning( + f"Queue is full. Dropping item. Max queue size: {self.queue.maxsize}" + ) + + def _get_next_batch(self) -> list[T]: + batch: list[T] = [] + while len(batch) < self.max_batch_size: + try: + item = self.queue.get_nowait() + except Empty: + break + else: + batch.append(item) + return batch + + def _process_batches(self) -> None: + """Internal method that continuously processes batches of items from the queue.""" + while True: + if current_batch := self._get_next_batch(): + try: + self.processor_fn(current_batch) + except Exception as e: + if get_raise_on_captured_errors(): + raise + logger.exception(f"Error processing batch: {e}") + else: + for _ in current_batch: + self.queue.task_done() + + if self.stop_accepting_work_event.is_set() and self.queue.empty(): + break + + # Unless we are stopping, sleep for a the min_batch_interval + if not self.stop_accepting_work_event.is_set(): + time.sleep(self.min_batch_interval) + + def stop_accepting_new_work_and_flush_queue(self) -> None: + """Stops accepting new work and begins gracefully shutting down. + + Any new items enqueued after this call will not be processed!""" + self.stop_accepting_work_event.set() + self.processing_thread.join() + + def accept_new_work(self) -> None: + """Resumes accepting new work.""" + self.stop_accepting_work_event.clear() diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index ee2a147f5b6d..2eb1ce4dbefc 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -10,11 +10,16 @@ from weave.trace.env import weave_trace_server_url from weave.trace_server import requests from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.async_batch_processor import AsyncBatchProcessor +from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor from weave.wandb_interface import project_creator logger = logging.getLogger(__name__) +# Default timeout values (in seconds) +# DEFAULT_CONNECT_TIMEOUT = 10 +# DEFAULT_READ_TIMEOUT = 30 +# DEFAULT_TIMEOUT = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT) + class StartBatchItem(BaseModel): mode: str = "start" @@ -129,12 +134,34 @@ def set_auth(self, auth: tuple[str, str]) -> None: retry_error_callback=_log_failure, reraise=True, ) + def _send_batch_to_server(self, encoded_data: bytes) -> None: + """Send a batch of data to the server with retry logic. + + This method is separated from _flush_calls to avoid recursive retries. + """ + r = requests.post( + self.trace_server_url + "/call/upsert_batch", + data=encoded_data, # type: ignore + auth=self._auth, + # timeout=DEFAULT_TIMEOUT, + ) + if r.status_code == 413: + # handle 413 explicitly to provide actionable error message + reason = json.loads(r.text)["reason"] + raise requests.HTTPError(f"413 Client Error: {reason}", response=r) + r.raise_for_status() + def _flush_calls( self, batch: list, *, _should_update_batch_size: bool = True, ) -> None: + """Process a batch of calls, splitting if necessary and sending to the server. + + This method handles the logic of splitting batches that are too large, + but delegates the actual server communication (with retries) to _send_batch_to_server. + """ if len(batch) == 0: return @@ -150,23 +177,15 @@ def _flush_calls( ) self.call_processor.max_batch_size = max(1, target_batch_size) - # If the batch is too big, recursively split it in half + # If the batch is too big, split it in half and process each half if encoded_bytes > self.remote_request_bytes_limit and len(batch) > 1: split_idx = int(len(batch) // 2) self._flush_calls(batch[:split_idx], _should_update_batch_size=False) self._flush_calls(batch[split_idx:], _should_update_batch_size=False) return - r = requests.post( - self.trace_server_url + "/call/upsert_batch", - data=encoded_data, - auth=self._auth, - ) - if r.status_code == 413: - # handle 413 explicitly to provide actionable error message - reason = json.loads(r.text)["reason"] - raise requests.HTTPError(f"413 Client Error: {reason}", response=r) - r.raise_for_status() + # Send the batch to the server with retries + self._send_batch_to_server(encoded_data) @tenacity.retry( stop=tenacity.stop_after_delay(REMOTE_REQUEST_RETRY_DURATION), @@ -193,6 +212,7 @@ def _generic_request_executor( data=req.model_dump_json(by_alias=True).encode("utf-8"), auth=self._auth, stream=stream, + # timeout=DEFAULT_TIMEOUT, ) if r.status_code == 500: reason_val = r.text @@ -245,7 +265,10 @@ def _generic_stream_request( reraise=True, ) def server_info(self) -> ServerInfoRes: - r = requests.get(self.trace_server_url + "/server_info") + r = requests.get( + self.trace_server_url + "/server_info", + # timeout=DEFAULT_TIMEOUT, + ) r.raise_for_status() return ServerInfoRes.model_validate(r.json()) @@ -478,6 +501,7 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: auth=self._auth, data={"project_id": req.project_id}, files={"file": (req.name, req.content)}, + # timeout=DEFAULT_TIMEOUT, ) r.raise_for_status() return tsi.FileCreateRes.model_validate(r.json()) @@ -497,6 +521,7 @@ def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadR self.trace_server_url + "/files/content", json={"project_id": req.project_id, "digest": req.digest}, auth=self._auth, + # timeout=DEFAULT_TIMEOUT, ) r.raise_for_status() # TODO: Should stream to disk rather than to memory diff --git a/weave/type_handlers/Audio/audio.py b/weave/type_handlers/Audio/audio.py index 090f7d11edd0..ecb5f172b142 100644 --- a/weave/type_handlers/Audio/audio.py +++ b/weave/type_handlers/Audio/audio.py @@ -1,7 +1,7 @@ import wave -from weave.trace import serializer -from weave.trace.custom_objs import MemTraceFilesArtifact +from weave.trace.serialization import serializer +from weave.trace.serialization.custom_objs import MemTraceFilesArtifact AUDIO_FILE_NAME = "audio.wav" diff --git a/weave/type_handlers/Image/image.py b/weave/type_handlers/Image/image.py index 6e414d79044b..4a2517dfb836 100644 --- a/weave/type_handlers/Image/image.py +++ b/weave/type_handlers/Image/image.py @@ -4,8 +4,8 @@ import logging -from weave.trace import serializer -from weave.trace.custom_objs import MemTraceFilesArtifact +from weave.trace.serialization import serializer +from weave.trace.serialization.custom_objs import MemTraceFilesArtifact from weave.utils.invertable_dict import InvertableDict try: diff --git a/weave/utils/iterators.py b/weave/utils/iterators.py new file mode 100644 index 000000000000..21dddecaa31b --- /dev/null +++ b/weave/utils/iterators.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from collections.abc import Generator, Iterator, Sequence +from threading import Lock +from typing import Any, TypeVar, overload + +T = TypeVar("T") + + +class ThreadSafeLazyList(Sequence[T]): + """ + Provides a thread-safe, iterable sequence by caching results in memory. + + This class is thread-safe and supports multiple iterations over the same data and + concurrent access. + + Args: + single_use_iterator: The source iterator whose values will be cached. (must terminate!) + known_length: Optional pre-known length of the iterator. If provided, can improve + performance by avoiding the need to exhaust the iterator to determine length. + + Thread Safety: + All operations are thread-safe through the use of internal locking. + """ + + _single_use_iterator: Iterator[T] + + def __init__( + self, single_use_iterator: Iterator[T], known_length: int | None = None + ) -> None: + self._lock = Lock() + self._single_use_iterator = single_use_iterator + self._list: list[T] = [] + self._stop_reached = False + self._known_length = known_length + + def _seek_to_index(self, index: int) -> None: + """ + Advances the iterator until the specified index is reached or iterator is exhausted. + Thread-safe operation. + """ + with self._lock: + while index >= len(self._list): + try: + self._list.append(next(self._single_use_iterator)) + except StopIteration: + self._stop_reached = True + return + + def _seek_to_end(self) -> None: + """ + Exhausts the iterator, caching all remaining values. + Thread-safe operation. + """ + with self._lock: + while not self._stop_reached: + try: + self._list.append(next(self._single_use_iterator)) + except StopIteration: + self._stop_reached = True + return + + @overload + def __getitem__(self, index: int) -> T: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[T]: ... + + def __getitem__(self, index: int | slice) -> T | Sequence[T]: + """ + Returns the item at the specified index. + + Args: + index: The index of the desired item. + + Returns: + The item at the specified index. + + Raises: + IndexError: If the index is out of range. + """ + if isinstance(index, slice): + if index.stop is None: + self._seek_to_end() + else: + self._seek_to_index(index.stop) + return self._list[index] + else: + self._seek_to_index(index) + return self._list[index] + + def __len__(self) -> int: + """ + Returns the total length of the sequence. + + If known_length was provided at initialization, returns that value. + Otherwise, exhausts the iterator to determine the length. + + Returns: + The total number of items in the sequence. + """ + if self._known_length is not None: + return self._known_length + + self._seek_to_end() + return len(self._list) + + def __iter__(self) -> Iterator[T]: + """ + Returns an iterator over the sequence. + + The returned iterator is safe to use concurrently with other operations + on this sequence. + + Returns: + An iterator yielding all items in the sequence. + """ + + def _iter() -> Generator[T, None, None]: + i = 0 + while True: + try: + val = self[i] + except IndexError: + return + try: + yield val + finally: + i += 1 + + return _iter() + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Sequence): + return False + if len(self) != len(other): + return False + self._seek_to_end() + for a, b in zip(self._list, other): + if a != b: + return False + return True diff --git a/weave_query/tests/test_arrow_concat.py b/weave_query/tests/test_arrow_concat.py index 578138a93abc..20ff36eafaf7 100644 --- a/weave_query/tests/test_arrow_concat.py +++ b/weave_query/tests/test_arrow_concat.py @@ -38,28 +38,30 @@ [1, 2, 3], [ wbmedia.ImageArtifactFileRef( - UNSAVED_TEST_ARTIFACT, - "path", - "format", - 25, - 35, - "sha256", - None, - None, - None, + artifact=UNSAVED_TEST_ARTIFACT, + path="path", + format="format", + height=25, + width=35, + sha256="sha256", + caption=None, + boxes=None, + masks=None, + classes=None, ) ], ), ( [ wbmedia.ImageArtifactFileRef( - UNSAVED_TEST_ARTIFACT, - "path1", - "format1", - 25, - 35, - "sha256-1", - { + artifact=UNSAVED_TEST_ARTIFACT, + path="path1", + format="format1", + height=25, + width=35, + sha256="sha256-1", + caption=None, + boxes={ "box": [ { "position": { @@ -73,21 +75,22 @@ } ] }, - None, - None, + masks=None, + classes=None, ) ], [ wbmedia.ImageArtifactFileRef( - UNSAVED_TEST_ARTIFACT, - "path2", - "format2", - 50, - 70, - "sha256-2", - None, - None, - None, + artifact=UNSAVED_TEST_ARTIFACT, + path="path2", + format="format2", + height=50, + width=70, + sha256="sha256-2", + caption=None, + boxes=None, + masks=None, + classes=None, ) ], ), @@ -97,15 +100,16 @@ { "a": 7, "b": wbmedia.ImageArtifactFileRef( - UNSAVED_TEST_ARTIFACT, - "path2", - "format2", - 50, - 70, - "sha256-2", - None, - None, - None, + artifact=UNSAVED_TEST_ARTIFACT, + path="path2", + format="format2", + height=50, + width=70, + sha256="sha256-2", + caption=None, + boxes=None, + masks=None, + classes=None, ), } ], @@ -116,13 +120,14 @@ { "a": 7, "b": wbmedia.ImageArtifactFileRef( - UNSAVED_TEST_ARTIFACT, - "path2", - "format2", - 50, - 70, - "sha256-2", - { + artifact=UNSAVED_TEST_ARTIFACT, + path="path2", + format="format2", + height=50, + width=70, + sha256="sha256-2", + caption=None, + boxes={ "box": [ { "position": { @@ -136,8 +141,8 @@ } ] }, - None, - None, + masks=None, + classes=None, ), } ], @@ -152,13 +157,14 @@ { "a": 7, "b": wbmedia.ImageArtifactFileRef( - UNSAVED_TEST_ARTIFACT, - "path2", - "format2", - 50, - 70, - "sha256-2", - { + artifact=UNSAVED_TEST_ARTIFACT, + path="path2", + format="format2", + height=50, + width=70, + sha256="sha256-2", + caption=None, + boxes={ "box": [ { "position": { @@ -172,8 +178,8 @@ } ] }, - None, - None, + masks=None, + classes=None, ), }, ], @@ -183,13 +189,14 @@ { "a": 7, "b": wbmedia.ImageArtifactFileRef( - UNSAVED_TEST_ARTIFACT, - "path2", - "format2", - 50, - 70, - "sha256-2", - { + artifact=UNSAVED_TEST_ARTIFACT, + path="path2", + format="format2", + height=50, + width=70, + sha256="sha256-2", + caption=None, + boxes={ "box": [ { "position": { @@ -203,22 +210,23 @@ } ] }, - None, - None, + masks=None, + classes=None, ), }, { "a": 45, "b": wbmedia.ImageArtifactFileRef( - UNSAVED_TEST_ARTIFACT, - "path2", - "format2", - 50, - 70, - "sha256-25", - {"box": []}, - None, - None, + artifact=UNSAVED_TEST_ARTIFACT, + path="path2", + format="format2", + height=50, + width=70, + sha256="sha256-25", + caption=None, + boxes={"box": []}, + masks=None, + classes=None, ), }, ], @@ -227,15 +235,16 @@ { "a": 11, "b": wbmedia.ImageArtifactFileRef( - UNSAVED_TEST_ARTIFACT, - "path2", - "format2", - 67, - 14, - "sha256-3", - {"box": []}, - None, - None, + artifact=UNSAVED_TEST_ARTIFACT, + path="path2", + format="format2", + height=67, + width=14, + sha256="sha256-3", + caption=None, + boxes={"box": []}, + masks=None, + classes=None, ), }, ], @@ -358,15 +367,16 @@ def test_concat_wbtyped(l1, l2, l1_wb_type, l2_wb_type): def test_image_ref(): py = wbmedia.ImageArtifactFileRef( - "artifact1", - "path1", - "format1", - 25, - 35, - "sha256-1", - None, - None, - None, + artifact="artifact1", + path="path1", + format="format1", + height=25, + width=35, + sha256="sha256-1", + caption=None, + boxes=None, + masks=None, + classes=None, ) a = to_arrow([py]) py2 = a.to_pylist_tagged()[0] diff --git a/weave_query/tests/test_wb_data_types.py b/weave_query/tests/test_wb_data_types.py index 1e2b934399b2..d1516006eb8e 100644 --- a/weave_query/tests/test_wb_data_types.py +++ b/weave_query/tests/test_wb_data_types.py @@ -323,6 +323,7 @@ def exp_raw_data(commit_hash: str): "sha256": "00c619b2faa45fdc9ce6de014e7aef7839c9de725bf78b528ef47d279039aacf", } }, + "caption": "", }, }, { @@ -377,6 +378,7 @@ def exp_raw_data(commit_hash: str): "sha256": "00c619b2faa45fdc9ce6de014e7aef7839c9de725bf78b528ef47d279039aacf", } }, + "caption": "", }, }, { @@ -431,6 +433,7 @@ def exp_raw_data(commit_hash: str): "sha256": "00c619b2faa45fdc9ce6de014e7aef7839c9de725bf78b528ef47d279039aacf", } }, + "caption": "", }, }, { @@ -485,6 +488,7 @@ def exp_raw_data(commit_hash: str): "sha256": "00c619b2faa45fdc9ce6de014e7aef7839c9de725bf78b528ef47d279039aacf", } }, + "caption": "", }, }, ] diff --git a/weave_query/tests/test_wb_tables.py b/weave_query/tests/test_wb_tables.py index 4652855b92a2..c634240d8819 100644 --- a/weave_query/tests/test_wb_tables.py +++ b/weave_query/tests/test_wb_tables.py @@ -177,6 +177,7 @@ def test_join_table_with_images(fake_wandb): "sha256": "e7bdc527afd649f51950b4524b0c15aecaf7f484448a6cdfcdc2ecd9bba0f5a7", "boxes": {}, "masks": {}, + "caption": "", }, }, "1": {"name": "a", "score": 1.0}, @@ -196,6 +197,7 @@ def test_join_table_with_images(fake_wandb): "sha256": "61cd2467cff9f0666c730c57d065cfe834765ba26514b46f91735c750676876a", "boxes": {}, "masks": {}, + "caption": "", }, }, "1": {"name": "b", "score": 2.0}, diff --git a/weave_query/tests/test_weave_types.py b/weave_query/tests/test_weave_types.py index 917c000ad3c2..a0ac73b08bca 100644 --- a/weave_query/tests/test_weave_types.py +++ b/weave_query/tests/test_weave_types.py @@ -699,6 +699,7 @@ def test_init_image(): "_is_object": True, "boxLayers": {"a": [1, 2, 3]}, "boxScoreKeys": [], + "caption": "", "classMap": {}, "maskLayers": {}, "type": "image-file", diff --git a/weave_query/weave_query/ops_domain/table.py b/weave_query/weave_query/ops_domain/table.py index 6708d62edee9..3e26ed44ba88 100644 --- a/weave_query/weave_query/ops_domain/table.py +++ b/weave_query/weave_query/ops_domain/table.py @@ -363,6 +363,7 @@ def _create_media_type_for_cell(cell: dict) -> typing.Any: boxes=cell.get("boxes", {}), # type: ignore masks=cell.get("masks", {}), # type: ignore classes=cell.get("classes"), # type: ignore + caption=cell.get("caption", ""), ) elif file_type in [ "audio-file", diff --git a/weave_query/weave_query/ops_domain/wb_util.py b/weave_query/weave_query/ops_domain/wb_util.py index aebaa445952f..da81998e9722 100644 --- a/weave_query/weave_query/ops_domain/wb_util.py +++ b/weave_query/weave_query/ops_domain/wb_util.py @@ -110,6 +110,7 @@ def _process_run_dict_item(val, run_path: typing.Optional[RunPath] = None): width=val["width"], height=val["height"], sha256=val["sha256"], + caption=val.get("caption", ""), # boxes=val.get("boxes", {}), # masks=val.get("masks", {}), ) diff --git a/weave_query/weave_query/ops_domain/wbmedia.py b/weave_query/weave_query/ops_domain/wbmedia.py index 8f566b7c8b36..58f8a0f866c1 100644 --- a/weave_query/weave_query/ops_domain/wbmedia.py +++ b/weave_query/weave_query/ops_domain/wbmedia.py @@ -53,6 +53,7 @@ class ImageArtifactFileRefType(types.ObjectType): boxScoreKeys: typing.Union[types.Type, list] = types.List(types.UnknownType()) maskLayers: typing.Union[types.Type, dict] = types.TypedDict({}) classMap: typing.Union[types.Type, dict] = types.TypedDict({}) + caption: typing.Union[types.Type, str] = types.Const(types.String(), "") # TODO: This should probably be standard for Type! def __post_init__(self): @@ -69,6 +70,7 @@ def _to_dict(self) -> dict: d["boxScoreKeys"] = types.constliteral_type_to_json(self.boxScoreKeys) # type: ignore d["maskLayers"] = types.constliteral_type_to_json(self.maskLayers) # type: ignore d["classMap"] = types.constliteral_type_to_json(self.classMap) # type: ignore + d["caption"] = types.constliteral_type_to_json(self.caption) # type: ignore return d # TODO: This should probably be standard for Type! @@ -79,6 +81,7 @@ def from_dict(cls, d): d.get("boxScoreKeys", []), d.get("maskLayers", {}), d.get("classMap", {}), + d.get("caption", ""), ) def property_types(self) -> dict[str, types.Type]: @@ -155,6 +158,7 @@ def property_types(self) -> dict[str, types.Type]: } ) ), + "caption": types.optional(types.String()), } return res @@ -194,6 +198,7 @@ def type_of_instance(cls, obj): boxScoreKeys=boxScoreKeys, maskLayers=maskLayers, classMap={}, + caption=obj.caption, ) @@ -206,6 +211,7 @@ class ImageArtifactFileRef: height: int width: int sha256: str + caption: typing.Optional[str] = "" boxes: typing.Optional[dict[str, list[dict]]] = dataclasses.field( default_factory=dict ) # type: ignore diff --git a/weave_query/weave_query/wandb_file_manager.py b/weave_query/weave_query/wandb_file_manager.py index 6c5949de4c37..a69b088d45a2 100644 --- a/weave_query/weave_query/wandb_file_manager.py +++ b/weave_query/weave_query/wandb_file_manager.py @@ -66,6 +66,17 @@ def _local_path_and_download_url( else: # TODO: storage_region storage_region = "default" + if isinstance(art_uri, artifact_wandb.WeaveWBArtifactURI): + return file_path, "{}/artifactsV2/{}/{}/{}/{}/{}/{}/{}".format( + base_url, + storage_region, + art_uri.entity_name, + art_uri.project_name, + art_uri.name, + urllib.parse.quote(manifest_entry.get("birthArtifactID", "")), # type: ignore + md5_hex, + urllib.parse.quote(file_name), + ) # For artifactsV2 (which is all artifacts now), the file download handler ignores the entity # parameter while parsing the url, and fetches the files directly via the artifact id # Refer to: https://github.com/wandb/core/blob/7cfee1cd07ddc49fe7ba70ce3d213d2a11bd4456/services/gorilla/api/handler/artifacts.go#L179