Skip to content

Commit

Permalink
WIP: improve logging
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Jan 2, 2025
1 parent 39a4b73 commit e4342e8
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 25 deletions.
5 changes: 5 additions & 0 deletions examples/build_graph/simple_kg_builder_from_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import asyncio
import logging

import neo4j
from neo4j_graphrag.embeddings import OpenAIEmbeddings
Expand All @@ -20,6 +21,10 @@
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.llm.openai_llm import OpenAILLM

logging.basicConfig()
logging.getLogger("neo4j_graphrag").setLevel(logging.DEBUG)


# Neo4j db infos
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
from neo4j_graphrag.generation.prompts import ERExtractionTemplate, PromptTemplate
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.utils import prettyfier

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -221,8 +222,9 @@ async def extract_for_chunk(
)
else:
logger.error(
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}"
f"LLM response is not valid JSON for chunk_index={chunk.index}"
)
logger.debug(f"Invalid JSON: {llm_result.content}")
result = {"nodes": [], "relationships": []}
try:
chunk_graph = Neo4jGraph(**result)
Expand All @@ -233,8 +235,9 @@ async def extract_for_chunk(
)
else:
logger.error(
f"LLM response has improper format {result} for chunk_index={chunk.index}"
f"LLM response has improper format for chunk_index={chunk.index}"
)
logger.debug(f"Invalid JSON format: {result}")
chunk_graph = Neo4jGraph()
return chunk_graph

Expand Down Expand Up @@ -336,5 +339,5 @@ async def run(
]
chunk_graphs: list[Neo4jGraph] = list(await asyncio.gather(*tasks))
graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs)
logger.debug(f"{self.__class__.__name__}: {graph}")
logger.debug(f"Extracted graph: {prettyfier(graph)}")
return graph
25 changes: 23 additions & 2 deletions src/neo4j_graphrag/experimental/components/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
from __future__ import annotations

import uuid
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, TYPE_CHECKING

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, RootModel

from neo4j_graphrag.experimental.pipeline.component import DataModel


if TYPE_CHECKING:
from pydantic._internal import _repr


class DocumentInfo(DataModel):
"""A document loaded by a DataLoader.
Expand Down Expand Up @@ -75,6 +79,20 @@ class TextChunks(DataModel):
chunks: list[TextChunk]


# class Embeddings(RootModel):
# """A wrapper around list[float] to represent embeddings.
# Used to improve logging of vectors by not showing the full vector.
# """
# root: list[float]
#
# # def __rep_str__(self, sep: str = ", ") -> str:
# # return f"<Embeddings: dimension={len(self.root)}, vector[:3]={self.root[:3]}>"
#
# def __repr_args__(self) -> _repr.ReprArgs:
# yield 'dimension', len(self.root)
# yield 'vector', self.root[:3]
#

class Neo4jNode(BaseModel):
"""Represents a Neo4j node.
Expand Down Expand Up @@ -129,6 +147,9 @@ class Neo4jGraph(DataModel):
nodes: list[Neo4jNode] = []
relationships: list[Neo4jRelationship] = []

# def __str__(self) -> str:
# return f"<Neo4jGraph: {len(self.nodes)} nodes, {len(self.relationships)} relationships>"


class ResolutionStats(DataModel):
number_of_nodes_to_resolve: int
Expand Down
3 changes: 3 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/config/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class PipelineConfigWrapper(BaseModel):
] = Field(discriminator=Discriminator(_get_discriminator_value))

def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition:
logger.debug("PIPELINE_CONFIG: start parsing config...")
return self.config.parse(resolved_data)

def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -101,10 +102,12 @@ def from_config(
cls, config: AbstractPipelineConfig | dict[str, Any], do_cleaning: bool = False
) -> Self:
wrapper = PipelineConfigWrapper.model_validate({"config": config})
logger.debug(f"PIPELINE_RUNNER: instantiate Pipeline from config type: {wrapper.config.template_}")
return cls(wrapper.parse(), config=wrapper.config, do_cleaning=do_cleaning)

@classmethod
def from_config_file(cls, file_path: Union[str, Path]) -> Self:
logger.info(f"PIPELINE_RUNNER: reading config file from {file_path}")
if not isinstance(file_path, str):
file_path = str(file_path)
data = ConfigReader().read(file_path)
Expand Down
24 changes: 13 additions & 11 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from timeit import default_timer
from typing import Any, AsyncGenerator, Optional

from neo4j_graphrag.utils import prettyfier

try:
import pygraphviz as pgv
except ImportError:
Expand Down Expand Up @@ -90,21 +92,19 @@ async def execute(self, **kwargs: Any) -> RunResult | None:
if the task run successfully, None if the status update
was unsuccessful.
"""
logger.debug(f"Running component {self.name} with {kwargs}")
start_time = default_timer()
component_result = await self.component.run(**kwargs)
run_result = RunResult(
result=component_result,
)
end_time = default_timer()
logger.debug(f"Component {self.name} finished in {end_time - start_time}s")
return run_result

async def run(self, inputs: dict[str, Any]) -> RunResult | None:
"""Main method to execute the task."""
logger.debug(f"TASK START {self.name=} {inputs=}")
logger.debug(f"TASK START {self.name=} input={prettyfier(inputs)}")
start_time = default_timer()
res = await self.execute(**inputs)
logger.debug(f"TASK RESULT {self.name=} {res=}")
end_time = default_timer()
logger.debug(f"TASK FINISHED {self.name} in {end_time - start_time} res={prettyfier(res)}")
return res


Expand Down Expand Up @@ -141,7 +141,7 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
try:
await self.set_task_status(task.name, RunStatus.RUNNING)
except PipelineStatusUpdateError:
logger.info(f"Component {task.name} already running or done")
logger.debug(f"ORCHESTRATOR: TASK ABORTED: {task.name} is already running or done, aborting")
return None
res = await task.run(inputs)
await self.set_task_status(task.name, RunStatus.DONE)
Expand Down Expand Up @@ -198,7 +198,8 @@ async def check_dependencies_complete(self, task: TaskPipelineNode) -> None:
d_status = await self.get_status_for_component(d.start)
if d_status != RunStatus.DONE:
logger.debug(
f"Missing dependency {d.start} for {task.name} (status: {d_status}). "
f"ORCHESTRATOR {self.run_id}: TASK DELAYED: Missing dependency {d.start} for {task.name} "
f"(status: {d_status}). "
"Will try again when dependency is complete."
)
raise PipelineMissingDependencyError()
Expand Down Expand Up @@ -227,6 +228,7 @@ async def next(
await self.check_dependencies_complete(next_node)
except PipelineMissingDependencyError:
continue
logger.debug(f"ORCHESTRATOR {self.run_id}: enqueuing next task: {next_node.name}")
yield next_node
return

Expand Down Expand Up @@ -315,7 +317,6 @@ async def run(self, data: dict[str, Any]) -> None:
(node without any parent). Then the callback on_task_complete
will handle the task dependencies.
"""
logger.debug(f"PIPELINE START {data=}")
tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
await asyncio.gather(*tasks)

Expand Down Expand Up @@ -624,15 +625,16 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
return True

async def run(self, data: dict[str, Any]) -> PipelineResult:
logger.debug("Starting pipeline")
logger.debug("PIPELINE START")
start_time = default_timer()
self.invalidate()
self.validate_input_data(data)
orchestrator = Orchestrator(self)
logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
await orchestrator.run(data)
end_time = default_timer()
logger.debug(
f"Pipeline {orchestrator.run_id} finished in {end_time - start_time}s"
f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
)
return PipelineResult(
run_id=orchestrator.run_id,
Expand Down
42 changes: 41 additions & 1 deletion src/neo4j_graphrag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,51 @@
# limitations under the License.
from __future__ import annotations

from typing import Optional
from typing import Optional, Any

from pydantic import BaseModel


def validate_search_query_input(
query_text: Optional[str] = None, query_vector: Optional[list[float]] = None
) -> None:
if not (bool(query_vector) ^ bool(query_text)):
raise ValueError("You must provide exactly one of query_vector or query_text.")



class Prettyfier:
"""Prettyfy object for logging.
I.e.: truncate long lists.
"""
def __init__(self, max_items_in_list: int = 5):
self.max_items_in_list = max_items_in_list

def _prettyfy_dict(self, value: dict[Any, Any]) -> dict[Any, Any]:
return {
k: self(v) # prettyfy each value
for k, v in value.items()
}

def _prettyfy_list(self, value: list[Any]) -> list[Any]:
items = [
self(v) # prettify each item
for v in value[:self.max_items_in_list]
]
remaining_items = len(value) - len(items)
if remaining_items > 0:
items.append(f"...truncated {remaining_items} items...")
return items

def __call__(self, value: Any) -> Any:
if isinstance(value, dict):
return self._prettyfy_dict(value)
if isinstance(value, BaseModel):
return self(value.model_dump())
if isinstance(value, list):
return self._prettyfy_list(value)
return value


prettyfier = Prettyfier()
4 changes: 2 additions & 2 deletions tests/e2e/test_kg_writer_component_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
if start_node.embedding_properties: # for mypy
for key, val in start_node.embedding_properties.items():
assert key in node_a.keys()
assert node_a.get(key) == [1.0, 2.0, 3.0]
assert val.root == node_a.get(key)

node_b = record["b"]
assert end_node.label in list(node_b.labels)
Expand All @@ -100,7 +100,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None:
if node_with_two_embeddings.embedding_properties: # for mypy
for key, val in node_with_two_embeddings.embedding_properties.items():
assert key in node_c.keys()
assert val == node_c.get(key)
assert val.root == node_c.get(key)


@pytest.mark.asyncio
Expand Down
10 changes: 6 additions & 4 deletions tests/unit/experimental/components/test_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

import pytest
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks
from neo4j_graphrag.experimental.components.types import (
Embeddings,
TextChunk,
TextChunks,
)


@pytest.mark.asyncio
Expand All @@ -33,6 +37,4 @@ async def test_text_chunk_embedder_run(embedder: MagicMock) -> None:
assert isinstance(chunk, TextChunk)
assert chunk.metadata is not None
assert "embedding" in chunk.metadata.keys()
assert isinstance(chunk.metadata["embedding"], list)
for i in chunk.metadata["embedding"]:
assert isinstance(i, float)
assert isinstance(chunk.metadata["embedding"], Embeddings)
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
LexicalGraphConfig,
Neo4jNode,
TextChunk,
TextChunks,
TextChunks, Embeddings,
)


Expand Down Expand Up @@ -64,7 +64,7 @@ def test_lexical_graph_builder_create_chunk_node_metadata_embedding() -> None:
assert isinstance(node, Neo4jNode)
assert node.id is not None
assert node.properties == {"index": 0, "text": "text chunk", "status": "ok"}
assert node.embedding_properties == {"embedding": [1, 2, 3]}
assert node.embedding_properties == {"embedding": Embeddings([1, 2, 3])}


@pytest.mark.asyncio
Expand Down

0 comments on commit e4342e8

Please sign in to comment.