Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sketches Schema UI view #1015

Merged
merged 6 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 5 additions & 29 deletions ui/sdk/src/hamilton_sdk/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def post_node_execute(
if success:
task_run.status = Status.SUCCESS
task_run.result_type = type(result)
result_summary = runs.process_result(result, node_)
result_summary, schema, additional_attributes = runs.process_result(result, node_)
if result_summary is None:
result_summary = {
"observability_type": "observability_failure",
Expand All @@ -270,18 +270,7 @@ def post_node_execute(
"value": "Failed to process result.",
},
}
# NOTE This is a temporary hack to make process_result() able to return
# more than one object that will be used as UI "task attributes".
# There's a conflict between `TaskRun.result_summary` that expect a single
# dict from process_result() and the `HamiltonTracker.post_node_execute()`
# that can more freely handle "stats" to create multiple "task attributes"
elif isinstance(result_summary, dict):
result_summary = result_summary
elif isinstance(result_summary, list):
other_results = [obj for obj in result_summary[1:]]
result_summary = result_summary[0]
else:
raise TypeError("`process_result()` needs to return a dict or sequence of dict")
other_results = ([schema] if schema is not None else []) + additional_attributes

task_run.result_summary = result_summary
task_attr = dict(
Expand Down Expand Up @@ -546,12 +535,13 @@ async def post_node_execute(
task_run = self.task_runs[run_id][node_.name]
tracking_state = self.tracking_states[run_id]
task_run.end_time = datetime.datetime.now(timezone.utc)

other_results = []

if success:
task_run.status = Status.SUCCESS
task_run.result_type = type(result)
result_summary = runs.process_result(result, node_) # add node
result_summary, schema, additional = runs.process_result(result, node_) # add node
other_results = ([schema] if schema is not None else []) + additional
if result_summary is None:
result_summary = {
"observability_type": "observability_failure",
Expand All @@ -561,19 +551,6 @@ async def post_node_execute(
"value": "Failed to process result.",
},
}
# NOTE This is a temporary hack to make process_result() able to return
# more than one object that will be used as UI "task attributes".
# There's a conflict between `TaskRun.result_summary` that expect a single
# dict from process_result() and the `HamiltonTracker.post_node_execute()`
# that can more freely handle "stats" to create multiple "task attributes"
elif isinstance(result_summary, dict):
result_summary = result_summary
elif isinstance(result_summary, list):
other_results = [obj for obj in result_summary[1:]]
result_summary = result_summary[0]
else:
raise TypeError("`process_result()` needs to return a dict or sequence of dict")

task_run.result_summary = result_summary
task_attr = dict(
node_name=get_node_name(node_, task_id),
Expand Down Expand Up @@ -603,7 +580,6 @@ async def post_node_execute(
attribute_role="error",
)

# `result_summary` or "error" is first because the order influences UI display order
attributes = [task_attr]
for i, other_result in enumerate(other_results):
other_attr = dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
import json
from functools import singledispatch
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional

import pandas as pd
from hamilton_sdk.tracking import sql_utils

StatsType = Dict[str, Any]
# Multiple observations per node are allowed.
ObservationType = Dict[str, Any]


@singledispatch
def compute_schema(result, node_name: str, node_tags: dict) -> Optional[ObservationType]:
    """Compute a "schema" observation for a node result.

    This default implementation returns None, meaning "no schema available";
    callers are expected to filter None out. We can polymorphically implement
    this for different result types by registering an implementation with
    `@compute_schema.register` (e.g. there is one for pandas DataFrames).

    :param result: value produced by the node; ignored by this default.
    :param node_name: name of the node that produced the result.
    :param node_tags: tags of the node that produced the result.
    :return: None by default; registered implementations may return a dict.
    """
    return None


@singledispatch
def compute_stats(result, node_name: str, node_tags: dict) -> Optional[ObservationType]:
"""This is the default implementation for computing stats on a result.

All other implementations should be registered with the `@compute_stats.register` decorator.
Expand All @@ -29,11 +43,24 @@ def compute_stats(result, node_name: str, node_tags: dict) -> Union[StatsType, L
}


@singledispatch
def compute_additional_results(result, node_name: str, node_tags: dict) -> List[ObservationType]:
    """Compute extra observations ("task attributes") for a node result.

    This default implementation returns an empty list, i.e. no additional
    observations. We can polymorphically implement this for different result
    types by registering an implementation with
    `@compute_additional_results.register`.

    :param result: value produced by the node; ignored by this default.
    :param node_name: name of the node that produced the result.
    :param node_tags: tags of the node that produced the result.
    :return: a list of observation dicts; empty by default.
    """
    return []


@compute_stats.register(str)
@compute_stats.register(int)
@compute_stats.register(float)
@compute_stats.register(bool)
def compute_stats_primitives(result, node_name: str, node_tags: dict) -> StatsType:
def compute_stats_primitives(result, node_name: str, node_tags: dict) -> ObservationType:
return {
"observability_type": "primitive",
"observability_value": {
Expand All @@ -45,7 +72,7 @@ def compute_stats_primitives(result, node_name: str, node_tags: dict) -> StatsTy


@compute_stats.register(dict)
def compute_stats_dict(result: dict, node_name: str, node_tags: dict) -> StatsType:
def compute_stats_dict(result: dict, node_name: str, node_tags: dict) -> ObservationType:
"""call summary stats on the values in the dict"""
try:
# if it's JSON serializable, take it.
Expand Down Expand Up @@ -94,7 +121,7 @@ def compute_stats_dict(result: dict, node_name: str, node_tags: dict) -> StatsTy


@compute_stats.register(tuple)
def compute_stats_tuple(result: tuple, node_name: str, node_tags: dict) -> StatsType:
def compute_stats_tuple(result: tuple, node_name: str, node_tags: dict) -> ObservationType:
if "hamilton.data_loader" in node_tags and node_tags["hamilton.data_loader"] is True:
# assumption it's a tuple
if isinstance(result[1], dict):
Expand Down Expand Up @@ -141,7 +168,7 @@ def compute_stats_tuple(result: tuple, node_name: str, node_tags: dict) -> Stats


@compute_stats.register(list)
def compute_stats_list(result: list, node_name: str, node_tags: dict) -> StatsType:
def compute_stats_list(result: list, node_name: str, node_tags: dict) -> ObservationType:
"""call summary stats on the values in the list"""
try:
# if it's JSON serializable, take it.
Expand Down
6 changes: 3 additions & 3 deletions ui/sdk/src/hamilton_sdk/tracking/ibis_stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict

from hamilton_sdk.tracking import stats
from hamilton_sdk.tracking import data_observation
from ibis.expr.datatypes import core

# import ibis.expr.types as ir
Expand Down Expand Up @@ -73,11 +73,11 @@ def _introspect(table: relations.Table) -> Dict[str, Any]:
}


@stats.compute_stats.register
@data_observation.compute_schema.register
def compute_stats_ibis_table(
result: relations.Table, node_name: str, node_tags: dict
) -> Dict[str, Any]:
# TODO: create custom type instead of dict for UI
# TODO: use the schema type
o_value = _introspect(result)
return {
"observability_type": "dict",
Expand Down
12 changes: 6 additions & 6 deletions ui/sdk/src/hamilton_sdk/tracking/langchain_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from typing import Any, Dict

from hamilton_sdk.tracking import stats
from hamilton_sdk.tracking import data_observation
from langchain_core import documents as lc_documents
from langchain_core import messages as lc_messages


@stats.compute_stats.register(lc_messages.BaseMessage)
@data_observation.compute_stats.register(lc_messages.BaseMessage)
def compute_stats_lc_messages(
result: lc_messages.BaseMessage, node_name: str, node_tags: dict
) -> Dict[str, Any]:
Expand All @@ -22,12 +22,12 @@ def compute_stats_lc_messages(
}


@stats.compute_stats.register(lc_documents.Document)
@data_observation.compute_stats.register(lc_documents.Document)
def compute_stats_lc_docs(
result: lc_documents.Document, node_name: str, node_tags: dict
) -> Dict[str, Any]:
if hasattr(result, "to_document"):
return stats.compute_stats(result.to_document(), node_name, node_tags)
return data_observation.compute_stats(result.to_document(), node_name, node_tags)
else:
# d.page_content # hack because not all documents are serializable
result = {"content": result.page_content, "metadata": result.metadata}
Expand All @@ -43,7 +43,7 @@ def compute_stats_lc_docs(
from langchain_core import messages

msg = messages.BaseMessage(content="Hello, World!", type="greeting")
print(stats.compute_stats(msg, "greeting", {}))
print(data_observation.compute_stats(msg, "greeting", {}))

doc = lc_documents.Document(page_content="Hello, World!", metadata={"source": "local_dir"})
print(stats.compute_stats(doc, "document", {}))
print(data_observation.compute_stats(doc, "document", {}))
4 changes: 2 additions & 2 deletions ui/sdk/src/hamilton_sdk/tracking/numpy_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import numpy as np
import pandas as pd
from hamilton_sdk.tracking import pandas_stats, stats
from hamilton_sdk.tracking import data_observation, pandas_stats

"""Module that houses functions to compute statistics on numpy objects
Notes:
- we should assume numpy v1.0+ so that we have a string type
"""


@stats.compute_stats.register
@data_observation.compute_stats.register
def compute_stats_numpy(result: np.ndarray, node_name: str, node_tags: dict) -> Dict[str, Any]:
try:
df = pd.DataFrame(result) # hack - reuse pandas stuff
Expand Down
44 changes: 23 additions & 21 deletions ui/sdk/src/hamilton_sdk/tracking/pandas_stats.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Any, Dict, List, Union
from typing import Any, Dict, Optional, Union

import pandas as pd
from hamilton_sdk.tracking import data_observation
from hamilton_sdk.tracking import pandas_col_stats as pcs
from hamilton_sdk.tracking import stats

from hamilton import driver

Expand Down Expand Up @@ -84,34 +84,36 @@ def execute_col(
return stats


@data_observation.compute_stats.register
def compute_stats_df(
    result: pd.DataFrame, node_name: str, node_tags: dict
) -> data_observation.ObservationType:
    """Build the "dagworks_describe" observation for a pandas DataFrame.

    :param result: the DataFrame produced by the node.
    :param node_name: name of the node that produced the result; unused here.
    :param node_tags: tags of the node that produced the result; unused here.
    :return: a single observation dict wrapping the per-column summary stats.
    """
    return {
        "observability_type": "dagworks_describe",
        "observability_value": _compute_stats(result),
        "observability_schema_version": "0.0.3",
    }


@data_observation.compute_schema.register
def compute_schema(
    result: pd.DataFrame, node_name: str, node_tags: dict
) -> Optional[data_observation.ObservationType]:
    """Compute the pyarrow-based schema observation for a pandas DataFrame.

    :param result: the DataFrame produced by the node.
    :param node_name: name of the node; recorded in the schema metadata.
    :param node_tags: tags of the node that produced the result; unused here.
    :return: a "schema" observation dict, or None when the hamilton schema
        plugin (`h_schema`) is unavailable.
    """
    if h_schema is None:
        return None
    schema = h_schema._get_arrow_schema(result)
    # pyarrow schemas are immutable: `with_metadata` returns a NEW schema.
    # The previous code discarded that return value, so the node name never
    # made it into the serialized schema -- bind the result instead.
    schema = schema.with_metadata(dict(name=node_name))
    return {
        "observability_type": "schema",
        "observability_value": h_schema.pyarrow_schema_to_json(schema),
        "observability_schema_version": "0.0.1",
        "name": "Schema",
    }


@stats.compute_stats.register
@data_observation.compute_stats.register
def compute_stats_series(result: pd.Series, node_name: str, node_tags: dict) -> Dict[str, Any]:
col_name = result.name if result.name else node_name
return {
Expand Down
6 changes: 3 additions & 3 deletions ui/sdk/src/hamilton_sdk/tracking/polars_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

if not hasattr(pl, "Series"):
raise ImportError("Polars is not installed")
from hamilton_sdk.tracking import data_observation
from hamilton_sdk.tracking import polars_col_stats as pls
from hamilton_sdk.tracking import stats

from hamilton import driver

Expand Down Expand Up @@ -83,7 +83,7 @@ def execute_col(target_output: str, col: pl.Series, name: str, position: int) ->
return stats


@stats.compute_stats.register
@data_observation.compute_stats.register
def compute_stats_df(result: pl.DataFrame, node_name: str, node_tags: dict) -> Dict[str, Any]:
return {
"observability_type": "dagworks_describe",
Expand All @@ -92,7 +92,7 @@ def compute_stats_df(result: pl.DataFrame, node_name: str, node_tags: dict) -> D
}


@stats.compute_stats.register
@data_observation.compute_stats.register
def compute_stats_series(result: pl.Series, node_name: str, node_tags: dict) -> Dict[str, Any]:
return {
"observability_type": "dagworks_describe",
Expand Down
4 changes: 2 additions & 2 deletions ui/sdk/src/hamilton_sdk/tracking/pydantic_stats.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, Dict

import pydantic
from hamilton_sdk.tracking import stats
from hamilton_sdk.tracking import data_observation


@stats.compute_stats.register
@data_observation.compute_stats.register
def compute_stats_pydantic(
result: pydantic.BaseModel, node_name: str, node_tags: dict
) -> Dict[str, Any]:
Expand Down
Loading
Loading