Skip to content

Commit

Permalink
"extras" must be resolved (#248)
Browse files Browse the repository at this point in the history
* "extras" must be resolved

* Add a test case

* Ruff

* Move test to e2e tests

* mypy + comments

* Fix test
  • Loading branch information
stellasia authored Jan 21, 2025
1 parent 5340073 commit 839be27
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
LLMType,
Neo4jDriverType,
)
from neo4j_graphrag.experimental.pipeline.config.param_resolver import (
ParamConfig,
)
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.types import (
ComponentDefinition,
Expand All @@ -48,7 +51,7 @@ class AbstractPipelineConfig(AbstractConfig):
llm_config: dict[str, LLMType] = {}
embedder_config: dict[str, EmbedderType] = {}
# extra parameters values that can be used in different places of the config file
extras: dict[str, Any] = {}
extras: dict[str, ParamConfig] = {}

DEFAULT_NAME: ClassVar[str] = "default"
"""Name of the default item in dict
Expand Down
10 changes: 8 additions & 2 deletions tests/e2e/data/config_files/simple_kg_pipeline_config.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
{
"version_": "1",
"template_": "SimpleKGPipeline",
"extras": {
"openai_api_key": {
"resolver_": "ENV",
"var_": "MY_OPENAI_KEY"
}
},
"neo4j_config": {
"params_": {
"uri": {
Expand Down Expand Up @@ -36,8 +42,8 @@
"class_": "OpenAIEmbeddings",
"params_": {
"api_key": {
"resolver_": "ENV",
"var_": "OPENAI_API_KEY"
"resolver_": "CONFIG_KEY",
"key_": "extras.openai_api_key"
}
}
},
Expand Down
18 changes: 16 additions & 2 deletions tests/e2e/experimental/pipeline/config/test_pipeline_runner_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,33 @@ async def test_simple_kg_pipeline_from_json_config(
os.environ["NEO4J_USER"] = "neo4j"
os.environ["NEO4J_PASSWORD"] = "password"
os.environ["OPENAI_API_KEY"] = "sk-my-secret-key"
os.environ["MY_OPENAI_KEY"] = "my-openai-key"

runner = PipelineRunner.from_config_file(
"tests/e2e/data/config_files/simple_kg_pipeline_config.json"
)

# check extras and API keys are handled as expected
config = runner.config
assert config is not None
# extras must be resolved:
assert config._global_data["extras"] == {"openai_api_key": "my-openai-key"}
# API key for LLM is read from env vars (see config file)
default_llm = config._global_data["llm_config"]["default"]
assert default_llm.client.api_key == "sk-my-secret-key"
# API key for embedder is read from extras (see config file)
default_embedder = config._global_data["embedder_config"]["default"]
assert default_embedder.client.api_key == "my-openai-key"

# then run pipeline and check results
res = await runner.run({"file_path": "tests/e2e/data/documents/harry_potter.pdf"})
assert isinstance(res, PipelineResult)
# print(await runner.pipeline.store.get_result_for_component(res.run_id, "splitter"))
assert res.result["resolver"] == {
"number_of_nodes_to_resolve": 3,
"number_of_created_nodes": 3,
}
nodes = driver.execute_query("MATCH (n) RETURN n")
# 1 chunk + 1 document + 3 nodes
# 1 chunk + 1 document + 3 __Entity__ nodes
assert len(nodes.records) == 5


Expand Down

0 comments on commit 839be27

Please sign in to comment.