Skip to content

Commit

Permalink
Add BaseChunkEvent
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Oct 17, 2024
1 parent db492c9 commit 310e8e0
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 23 deletions.
6 changes: 3 additions & 3 deletions docs/griptape-framework/misc/src/events_3.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import cast

from griptape.drivers import OpenAiChatPromptDriver
from griptape.events import CompletionChunkEvent, EventBus, EventListener
from griptape.events import BaseChunkEvent, EventBus, EventListener
from griptape.structures import Pipeline
from griptape.tasks import ToolkitTask
from griptape.tools import PromptSummaryTool, WebScraperTool

EventBus.add_event_listeners(
[
EventListener(
lambda e: print(cast(CompletionChunkEvent, e).token, end="", flush=True),
event_types=[CompletionChunkEvent],
lambda e: print(cast(BaseChunkEvent, e).token, end="", flush=True),
event_types=[BaseChunkEvent],
)
]
)
Expand Down
17 changes: 11 additions & 6 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
TextMessageContent,
observable,
)
from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent
from griptape.events import (
ActionChunkEvent,
EventBus,
FinishPromptEvent,
StartPromptEvent,
TextChunkEvent,
)
from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
from griptape.mixins.serializable_mixin import SerializableMixin

Expand Down Expand Up @@ -127,12 +133,11 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message:
else:
delta_contents[content.index] = [content]
if isinstance(content, TextDeltaMessageContent):
EventBus.publish_event(CompletionChunkEvent(token=content.text))
EventBus.publish_event(TextChunkEvent.from_delta_message_content(content))
elif isinstance(content, ActionCallDeltaMessageContent):
if content.tag is not None and content.name is not None and content.path is not None:
EventBus.publish_event(CompletionChunkEvent(token=str(content)))
elif content.partial_input is not None:
EventBus.publish_event(CompletionChunkEvent(token=content.partial_input))
EventBus.publish_event(
ActionChunkEvent.from_delta_message_content(content),
)

# Build a complete content from the content deltas
return self.__build_message(list(delta_contents.values()), usage)
Expand Down
8 changes: 6 additions & 2 deletions griptape/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from .finish_prompt_event import FinishPromptEvent
from .start_structure_run_event import StartStructureRunEvent
from .finish_structure_run_event import FinishStructureRunEvent
from .completion_chunk_event import CompletionChunkEvent
from .base_chunk_event import BaseChunkEvent
from .text_chunk_event import TextChunkEvent
from .action_chunk_event import ActionChunkEvent
from .event_listener import EventListener
from .start_image_generation_event import StartImageGenerationEvent
from .finish_image_generation_event import FinishImageGenerationEvent
Expand All @@ -37,7 +39,9 @@
"FinishPromptEvent",
"StartStructureRunEvent",
"FinishStructureRunEvent",
"CompletionChunkEvent",
"BaseChunkEvent",
"TextChunkEvent",
"ActionChunkEvent",
"EventListener",
"StartImageGenerationEvent",
"FinishImageGenerationEvent",
Expand Down
32 changes: 32 additions & 0 deletions griptape/events/action_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

from attrs import define, field

from griptape.events.base_chunk_event import BaseChunkEvent

if TYPE_CHECKING:
from griptape.common import BaseDeltaMessageContent


@define
class ActionChunkEvent(BaseChunkEvent):
tag: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

@classmethod
def from_delta_message_content(cls, content: BaseDeltaMessageContent) -> ActionChunkEvent:
from griptape.common import ActionCallDeltaMessageContent

if isinstance(content, ActionCallDeltaMessageContent):
return cls(
token=content.partial_input if content.partial_input is not None else "",
index=content.index,
tag=content.tag,
name=content.name,
path=content.path,
)

raise ValueError(f"Content is not an instance of ActionCallDeltaMessageContent: {content.__class__.__name__}")
19 changes: 19 additions & 0 deletions griptape/events/base_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from attrs import define, field

from griptape.events.base_event import BaseEvent

if TYPE_CHECKING:
from griptape.common import BaseDeltaMessageContent


@define
class BaseChunkEvent(BaseEvent):
token: str = field(metadata={"serializable": True})
index: int = field(default=0, metadata={"serializable": True})

@classmethod
def from_delta_message_content(cls, content: BaseDeltaMessageContent) -> BaseChunkEvent: ...
8 changes: 0 additions & 8 deletions griptape/events/completion_chunk_event.py

This file was deleted.

2 changes: 1 addition & 1 deletion griptape/events/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002
def publish_event(self, event: T, *, flush: bool = False) -> None:
event_types = self.event_types

if event_types is None or type(event) in event_types:
if event_types is None or any(isinstance(event, event_type) for event_type in event_types):
handled_event = event
if self.handler is not None:
handled_event = self.handler(event)
Expand Down
22 changes: 22 additions & 0 deletions griptape/events/text_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from attrs import define

from griptape.events.base_chunk_event import BaseChunkEvent

if TYPE_CHECKING:
from griptape.common import BaseDeltaMessageContent


@define
class TextChunkEvent(BaseChunkEvent):
@classmethod
def from_delta_message_content(cls, content: BaseDeltaMessageContent) -> TextChunkEvent:
from griptape.common import TextDeltaMessageContent

if isinstance(content, TextDeltaMessageContent):
return cls(token=content.text, index=content.index)

raise ValueError(f"Content is not an instance of TextDeltaMessageContent: {content.__class__.__name__}")
6 changes: 3 additions & 3 deletions griptape/utils/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from attrs import Attribute, Factory, define, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent
from griptape.events import BaseChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down Expand Up @@ -54,7 +54,7 @@ def run(self, *args) -> Iterator[TextArtifact]:
break
elif isinstance(event, FinishPromptEvent):
yield TextArtifact(value="\n")
elif isinstance(event, CompletionChunkEvent):
elif isinstance(event, BaseChunkEvent):
yield TextArtifact(value=event.token)
t.join()

Expand All @@ -64,7 +64,7 @@ def event_handler(event: BaseEvent) -> None:

stream_event_listener = EventListener(
handler=event_handler,
event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent],
event_types=[BaseChunkEvent, FinishPromptEvent, FinishStructureRunEvent],
)
EventBus.add_event_listener(stream_event_listener)

Expand Down

0 comments on commit 310e8e0

Please sign in to comment.