Skip to content

Commit

Permalink
Add flexibility for lexical graph config to SimpleKGPipeline (#209)
Browse files Browse the repository at this point in the history
* Add flexibility for lexical graph config to SimpleKGPipeline

* Update CHANGELOG and update E2E test

* Revert LLMEntityRelationExtractor changes and move lexical_graph_config to pipe_inputs in SimpleKGPipeline
  • Loading branch information
willtai authored Nov 4, 2024
1 parent 508323a commit 18e8e2a
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Next

### Added
- Introduced optional lexical graph configuration for SimpleKGPipeline, enhancing flexibility in customizing node labels and relationship types in the lexical graph.

## 1.2.0

### Added
Expand Down
11 changes: 11 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
FixedSizeSplitter,
)
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult
from neo4j_graphrag.generation.prompts import ERExtractionTemplate
Expand All @@ -59,6 +60,7 @@ class SimpleKGPipelineConfig(BaseModel):
on_error: OnError = OnError.RAISE
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
perform_entity_resolution: bool = True
lexical_graph_config: Optional[LexicalGraphConfig] = None

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand All @@ -84,6 +86,7 @@ class SimpleKGPipeline:
on_error (str): Error handling strategy for the Entity and relation extractor. Defaults to "IGNORE", where chunk will be ignored if extraction fails. Possible values: "RAISE" or "IGNORE".
perform_entity_resolution (bool): Merge entities with same label and name. Default: True
prompt_template (str): A custom prompt template to use for extraction.
lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph.
"""

def __init__(
Expand All @@ -101,6 +104,7 @@ def __init__(
on_error: str = "IGNORE",
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(),
perform_entity_resolution: bool = True,
lexical_graph_config: Optional[LexicalGraphConfig] = None,
):
self.entities = [SchemaEntity(label=label) for label in entities or []]
self.relations = [SchemaRelation(label=label) for label in relations or []]
Expand All @@ -127,6 +131,7 @@ def __init__(
prompt_template=prompt_template,
embedder=embedder,
perform_entity_resolution=perform_entity_resolution,
lexical_graph_config=lexical_graph_config,
)

self.from_pdf = config.from_pdf
Expand All @@ -141,6 +146,7 @@ def __init__(
)
self.prompt_template = config.prompt_template
self.perform_entity_resolution = config.perform_entity_resolution
self.lexical_graph_config = config.lexical_graph_config

self.pipeline = self._build_pipeline()

Expand Down Expand Up @@ -252,4 +258,9 @@ def _prepare_inputs(
else:
pipe_inputs["splitter"] = {"text": text}

if self.lexical_graph_config:
pipe_inputs["extractor"] = {
"lexical_graph_config": self.lexical_graph_config
}

return pipe_inputs
11 changes: 9 additions & 2 deletions tests/e2e/test_simplekgpipeline_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import neo4j
import pytest
from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
from neo4j_graphrag.llm import LLMInterface, LLMResponse

Expand Down Expand Up @@ -111,6 +112,11 @@ async def test_pipeline_builder_happy_path(
("ORGANIZATION", "LED_BY", "PERSON"),
]

# Additional arguments
lexical_graph_config = LexicalGraphConfig(chunk_node_label="chunkNodeLabel")
from_pdf = False
on_error = "RAISE"

# Create an instance of the SimpleKGPipeline
kg_builder_text = SimpleKGPipeline(
llm=llm,
Expand All @@ -119,8 +125,9 @@ async def test_pipeline_builder_happy_path(
entities=entities,
relations=relations,
potential_schema=potential_schema,
from_pdf=False,
on_error="RAISE",
from_pdf=from_pdf,
on_error=on_error,
lexical_graph_config=lexical_graph_config,
)

# Run the knowledge graph building process with text input
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/experimental/pipeline/test_kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from neo4j_graphrag.embeddings import Embedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError
from neo4j_graphrag.experimental.components.schema import SchemaEntity, SchemaRelation
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
Expand Down Expand Up @@ -316,3 +317,65 @@ def test_simple_kg_pipeline_no_entity_resolution(_: Mock) -> None:
)

assert "resolver" not in kg_builder.pipeline


@mock.patch(
    "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
    return_value=(5, 23, 0),
)
def test_simple_kg_pipeline_lexical_graph_config_attribute(_: Mock) -> None:
    """Check that SimpleKGPipeline keeps the lexical graph config it was given.

    The test is fully synchronous (it only constructs the pipeline and reads an
    attribute), so it is intentionally NOT marked with ``pytest.mark.asyncio``.

    Args:
        _: Mock injected by the ``Neo4jWriter._get_version`` patch; unused here.
           NOTE(review): the patch presumably avoids a real database version
           lookup while the pipeline is built -- confirm against Neo4jWriter.
    """
    # Collaborators are spec'd mocks: nothing is called on them in this test.
    llm = MagicMock(spec=LLMInterface)
    driver = MagicMock(spec=neo4j.Driver)
    embedder = MagicMock(spec=Embedder)

    lexical_graph_config = LexicalGraphConfig()
    kg_builder = SimpleKGPipeline(
        llm=llm,
        driver=driver,
        embedder=embedder,
        on_error="IGNORE",
        lexical_graph_config=lexical_graph_config,
    )

    # The constructor argument must be exposed unchanged on the instance.
    assert kg_builder.lexical_graph_config == lexical_graph_config


@mock.patch(
    "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
    return_value=(5, 23, 0),
)
@pytest.mark.asyncio
async def test_knowledge_graph_builder_with_lexical_graph_config(_: Mock) -> None:
    """Check that ``run_async`` forwards the lexical graph config to the extractor.

    The underlying ``pipeline.run`` is patched out, so only the inputs handed
    to it are inspected: they must contain an ``"extractor"`` entry carrying
    the config that was passed to the ``SimpleKGPipeline`` constructor.

    Args:
        _: Mock injected by the ``Neo4jWriter._get_version`` patch; unused here.
    """
    # Custom labels so the config differs from a default-constructed one.
    lexical_graph_config = LexicalGraphConfig(
        chunk_node_label="TestChunk",
        document_node_label="TestDocument",
    )

    kg_builder = SimpleKGPipeline(
        llm=MagicMock(spec=LLMInterface),
        driver=MagicMock(spec=neo4j.Driver),
        embedder=MagicMock(spec=Embedder),
        from_pdf=False,
        lexical_graph_config=lexical_graph_config,
    )

    fake_result = PipelineResult(run_id="test_run", result=None)
    with patch.object(
        kg_builder.pipeline, "run", return_value=fake_result
    ) as run_mock:
        await kg_builder.run_async(text="May thy knife chip and shatter.")

    # First positional argument of the patched run() is the pipeline input dict.
    inputs = run_mock.call_args.args[0]
    expected_extractor_inputs = {"lexical_graph_config": lexical_graph_config}
    assert "extractor" in inputs
    assert inputs["extractor"] == expected_extractor_inputs

0 comments on commit 18e8e2a

Please sign in to comment.