Skip to content

Commit

Permalink
Merge pull request #412 from DAGWorks-Inc/fix/haystack-action
Browse files Browse the repository at this point in the history
fix: support `.warm_up()` in `HaystackAction`
  • Loading branch information
zilto authored Oct 30, 2024
2 parents c1e86f5 + 25338a4 commit 6d3d7e9
Show file tree
Hide file tree
Showing 5 changed files with 482 additions and 262 deletions.
2 changes: 1 addition & 1 deletion burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:

@property
def optional_and_required_inputs(self) -> tuple[set[str], set[str]]:
"""Returns a tuple of two lists of strings -- the first list is the required keys, the second is the optional keys.
"""Returns a tuple of two sets of strings -- the first set is the required keys, the second is the optional keys.
This is internal and not meant to override.
:return: Tuple of required keys and optional keys
Expand Down
27 changes: 19 additions & 8 deletions burr/integrations/haystack.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
writes: Union[list[str], dict[str, str]],
name: Optional[str] = None,
bound_params: Optional[dict] = None,
do_warm_up: bool = True,
):
"""Create a Burr ``Action`` from a Haystack ``Component``.
Expand All @@ -39,8 +40,10 @@ def __init__(
Use a mapping {state_field: socket} to rename Haystack output sockets (see example).
:param name: Name of the action. Can be set later via ``.with_name()``
or in ``ApplicationBuilder.with_actions()``.
:param bound_params: Parameters to bind to the `Component.run()` method.
:param bound_params: Parameters to bind to the ``Component.run()`` method.
:param do_warm_up: If True, try to call ``Component.warm_up()`` if it exists.
If False, we assume ``.warm_up()`` was called before creating the ``HaystackAction``.
Read more about ``.warm_up()`` in the Haystack documentation: https://docs.haystack.deepset.ai/reference/pipeline-api#pipelinewarm_up
Pass the mapping ``{"foo": "state_field"}`` to read the value of ``state_field`` on the Burr state
and pass it to ``Component.run()`` as ``foo``.
Expand Down Expand Up @@ -103,6 +106,10 @@ def run(self) -> dict:
self._component = component
self._name = name
self._bound_params = bound_params if bound_params is not None else {}
self._do_warm_up = do_warm_up

if self._do_warm_up is True:
self._try_warm_up()

# NOTE input and output socket mappings are kept separately to avoid naming conflicts.
if isinstance(reads, Mapping):
Expand All @@ -129,6 +136,10 @@ def run(self) -> dict:

self._required_inputs, self._optional_inputs = self._get_required_and_optional_inputs()

def _try_warm_up(self) -> None:
if hasattr(self._component, "warm_up") is True:
self._component.warm_up()

def _validate_input_sockets(self) -> None:
"""Check that input socket names passed by the user match the Component's input sockets"""
# NOTE those are internal attributes, but we expect them be stable.
Expand Down Expand Up @@ -166,12 +177,12 @@ def writes(self) -> list[str]:
"""State fields where results of `Component.run()` are written."""
return self._writes

def _get_required_and_optional_inputs(self) -> tuple[list[str], list[str]]:
def _get_required_and_optional_inputs(self) -> tuple[set[str], set[str]]:
"""Iterate over Haystack Component input sockets and inspect default values.
If we expect the value to come from state or it's a bound parameter, skip this socket.
Otherwise, if it has a default value, it's optional.
"""
required_inputs, optional_inputs = [], []
required_inputs, optional_inputs = set(), set()
# NOTE those are internal attributes, but we expect them be stable.
# reference: https://github.com/deepset-ai/haystack/blob/906177329bcc54f6946af361fcd3d0e334e6ce5f/haystack/core/component/component.py#L371
for socket_name, input_socket in self._component.__haystack_input__._sockets_dict.items():
Expand All @@ -180,9 +191,9 @@ def _get_required_and_optional_inputs(self) -> tuple[list[str], list[str]]:
continue

if input_socket.default_value == haystack_empty:
required_inputs.append(state_field_name)
required_inputs.add(state_field_name)
else:
optional_inputs.append(state_field_name)
optional_inputs.add(state_field_name)

return required_inputs, optional_inputs

Expand All @@ -191,10 +202,10 @@ def inputs(self) -> list[str]:
"""Return a list of required inputs for ``Component.run()``
This corresponds to the Component's required input socket names.
"""
return self._required_inputs
return list(self._required_inputs)

@property
def optional_and_required_inputs(self) -> tuple[list[str], list[str]]:
def optional_and_required_inputs(self) -> tuple[set[str], set[str]]:
"""Return a tuple of required and optional inputs for ``Component.run()``
This corresponds to the Component's required and optional input socket names.
"""
Expand Down
79 changes: 41 additions & 38 deletions examples/haystack-integration/application.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,12 @@
import os

from haystack.components.builders import PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.embedders import OpenAITextEmbedder
from haystack.components.generators import OpenAIGenerator
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore

from burr.core import ApplicationBuilder, State, action
from burr.integrations.haystack import HaystackAction

# dummy OpenAI key to avoid raising an error
os.environ["OPENAI_API_KEY"] = "sk-..."


embed_text = HaystackAction(
component=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="embed_text",
reads=[],
writes={"embedding": "query_embedding"},
)


retrieve_documents = HaystackAction(
component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
name="retrieve_documents",
reads=["query_embedding"],
writes=["documents"],
)


build_prompt = HaystackAction(
component=PromptBuilder(template="Document: {{documents}} Question: {{question}}"),
name="build_prompt",
reads=["documents"],
writes={"prompt": "question_prompt"},
)


generate_answer = HaystackAction(
component=OpenAIGenerator(model="gpt-4o-mini"),
name="generate_answer",
reads={"question_prompt": "prompt"},
writes={"text": "answer"},
)


@action(reads=["answer"], writes=[])
def display_answer(state: State) -> State:
Expand All @@ -52,6 +15,34 @@ def display_answer(state: State) -> State:


def build_application():
embed_text = HaystackAction(
component=OpenAITextEmbedder(model="text-embedding-3-small"),
name="embed_text",
reads=[],
writes={"query_embedding": "embedding"},
)

retrieve_documents = HaystackAction(
component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
name="retrieve_documents",
reads=["query_embedding"],
writes=["documents"],
)

build_prompt = HaystackAction(
component=PromptBuilder(template="Document: {{documents}} Question: {{question}}"),
name="build_prompt",
reads=["documents"],
writes={"question_prompt": "prompt"},
)

generate_answer = HaystackAction(
component=OpenAIGenerator(model="gpt-4o-mini"),
name="generate_answer",
reads={"prompt": "question_prompt"},
writes={"answer": "replies"},
)

return (
ApplicationBuilder()
.with_actions(
Expand All @@ -73,5 +64,17 @@ def build_application():


if __name__ == "__main__":
import os

os.environ["OPENAI_API_KEY"] = "sk-..."

app = build_application()

_, _, state = app.run(
halt_after=["display_answer"],
inputs={
"text": "What is the capital of France?",
"question": "What is the capital of France?",
},
)
app.visualize(include_state=True)
512 changes: 307 additions & 205 deletions examples/haystack-integration/notebook.ipynb

Large diffs are not rendered by default.

124 changes: 114 additions & 10 deletions tests/integrations/test_burr_haystack.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pytest
from haystack import Pipeline, component
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.embedders import OpenAITextEmbedder
from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.utils.auth import Secret

from burr.core import State, action
from burr.core.application import ApplicationBuilder
from burr.core.graph import GraphBuilder
from burr.integrations.haystack import HaystackAction, haystack_pipeline_to_burr_graph

Expand All @@ -22,6 +25,27 @@ def run(self, required_input: str, optional_input: str = "default") -> dict:
}


@component
class MockComponentWithWarmup:
def __init__(self, required_init: str, optional_init: str = "default"):
self.required_init = required_init
self.optional_init = optional_init
self.is_warm = False

def warm_up(self):
self.is_warm = True

@component.output_types(output_1=str, output_2=str)
def run(self, required_input: str, optional_input: str = "default") -> dict:
if self.is_warm is False:
raise RuntimeError("You must call ``warm_up()`` before running.")

return {
"output_1": required_input,
"output_2": optional_input,
}


@action(reads=["query_embedding"], writes=["documents"])
def retrieve_documents(state: State) -> State:
query_embedding = state["query_embedding"]
Expand All @@ -33,14 +57,6 @@ def retrieve_documents(state: State) -> State:
return state.update(documents=results["documents"])


haystack_retrieve_documents = HaystackAction(
component=InMemoryEmbeddingRetriever(InMemoryDocumentStore()),
name="retrieve_documents",
reads=["query_embedding"],
writes=["documents"],
)


def test_input_socket_mapping():
# {input_socket_name: state_field}
reads = {"required_input": "foo"}
Expand Down Expand Up @@ -207,10 +223,38 @@ def test_update_with_writes_sequence():
assert new_state["output_1"] == 1


def test_component_is_warmed_up():
state = State(initial_values={})
haction = HaystackAction(
component=MockComponentWithWarmup(required_init="init"),
name="mock",
reads=[],
writes=[],
do_warm_up=True,
)
results = haction.run(state=state, required_input="as_input")
assert results == {"output_1": "as_input", "output_2": "default"}


def test_component_is_not_warmed_up():
state = State(initial_values={})
haction = HaystackAction(
component=MockComponentWithWarmup(required_init="init"),
name="mock",
reads=[],
writes=[],
do_warm_up=False,
)
with pytest.raises(RuntimeError):
haction.run(state=state, required_input="as_input")


def test_pipeline_converter():
# create haystack Pipeline
retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore())
text_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
text_embedder = OpenAITextEmbedder(
model="text-embedding-3-small", api_key=Secret.from_token("mock-key")
)

basic_rag_pipeline = Pipeline()
basic_rag_pipeline.add_component("text_embedder", text_embedder)
Expand Down Expand Up @@ -251,3 +295,63 @@ def test_pipeline_converter():
burr_t.from_.name == haystack_t.from_.name and burr_t.to.name == haystack_t.to.name
for haystack_t in haystack_graph.transitions
)


def test_run_application():
app = (
ApplicationBuilder()
.with_actions(
HaystackAction(
component=MockComponent(required_init="init"),
name="mock",
reads=[],
writes=["output_1"],
)
)
.with_transitions()
.with_entrypoint("mock")
.build()
)

_, _, state = app.run(halt_after=["mock"], inputs={"required_input": "runtime"})
assert state["output_1"] == "runtime"


def test_run_application_is_warm_up():
app = (
ApplicationBuilder()
.with_actions(
HaystackAction(
component=MockComponentWithWarmup(required_init="init"),
name="mock",
reads=[],
writes=["output_1"],
)
)
.with_transitions()
.with_entrypoint("mock")
.build()
)

_, _, state = app.run(halt_after=["mock"], inputs={"required_input": "runtime"})
assert state["output_1"] == "runtime"


def test_run_application_is_not_warmed_up():
app = (
ApplicationBuilder()
.with_actions(
HaystackAction(
component=MockComponentWithWarmup(required_init="init"),
name="mock",
reads=[],
writes=["output_1"],
do_warm_up=False,
)
)
.with_transitions()
.with_entrypoint("mock")
.build()
)
with pytest.raises(RuntimeError):
app.run(halt_after=["mock"], inputs={"required_input": "runtime"})

0 comments on commit 6d3d7e9

Please sign in to comment.