Skip to content

Commit

Permalink
Add logs to all prompt drivers
Browse files — browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 9, 2024
1 parent 9be0a64 commit f0292bc
Show file tree
Hide file tree
Showing 11 changed files with 113 additions and 37 deletions.
2 changes: 1 addition & 1 deletion griptape/configs/defaults_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def logging_config(self) -> LoggingConfig:

@lazy_property()
def drivers_config(self) -> BaseDriversConfig:
from .drivers.openai_drivers_config import OpenAiDriversConfig
from griptape.configs.drivers.openai_drivers_config import OpenAiDriversConfig

return OpenAiDriversConfig()

Expand Down
14 changes: 8 additions & 6 deletions griptape/configs/logging/logging_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
class LoggingConfig:
logger_name: str = field(default="griptape", kw_only=True)
level: int = field(default=logging.INFO, kw_only=True)
handler: logging.Handler = field(
default=Factory(lambda: RichHandler(show_time=True, show_path=False)), kw_only=True
handlers: list[logging.Handler] = field(
default=Factory(lambda: [RichHandler(show_time=True, show_path=False)]), kw_only=True
)
propagate: bool = field(default=False, kw_only=True)
handler_formatter: Optional[logging.Formatter] = field(default=None, kw_only=True)
handlers_formatter: Optional[logging.Formatter] = field(default=None, kw_only=True)

def __attrs_post_init__(self) -> None:
logger = logging.getLogger(self.logger_name)
logger.setLevel(self.level)
logger.propagate = self.propagate
if self.handler_formatter is not None:
self.handler.setFormatter(self.handler_formatter)
logger.addHandler(self.handler)
if self.handlers_formatter:
for handler in self.handlers:
handler.setFormatter(self.handlers_formatter)

logger.handlers = self.handlers
14 changes: 12 additions & 2 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from attrs import Factory, define, field
Expand Down Expand Up @@ -28,6 +29,7 @@
ToolAction,
observable,
)
from griptape.configs import Defaults
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer
from griptape.utils import import_optional_dependency
Expand All @@ -41,6 +43,8 @@
from griptape.common import PromptStack
from griptape.tools import BaseTool

logger = logging.getLogger(Defaults.logging_config.logger_name)


@define
class AmazonBedrockPromptDriver(BasePromptDriver):
Expand All @@ -60,7 +64,10 @@ def client(self) -> Any:

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
response = self.client.converse(**self._base_params(prompt_stack))
params = self._base_params(prompt_stack)
logger.debug(params)
response = self.client.converse(**params)
logger.debug(response)

usage = response["usage"]
output_message = response["output"]["message"]
Expand All @@ -73,11 +80,14 @@ def try_run(self, prompt_stack: PromptStack) -> Message:

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
response = self.client.converse_stream(**self._base_params(prompt_stack))
params = self._base_params(prompt_stack)
logger.debug(params)
response = self.client.converse_stream(**params)

stream = response.get("stream")
if stream is not None:
for event in stream:
logger.debug(event)
if "contentBlockDelta" in event or "contentBlockStart" in event:
yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
elif "metadata" in event:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any, Optional

from attrs import Attribute, Factory, define, field

from griptape.artifacts import TextArtifact
from griptape.common import DeltaMessage, Message, PromptStack, TextMessageContent, observable
from griptape.configs import Defaults
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.utils import import_optional_dependency
Expand All @@ -19,6 +21,8 @@

from griptape.common import PromptStack

logger = logging.getLogger(Defaults.logging_config.logger_name)


@define
class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
Expand Down Expand Up @@ -52,6 +56,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
"inputs": self.prompt_stack_to_string(prompt_stack),
"parameters": {**self._base_params(prompt_stack)},
}
logger.debug(payload)

response = self.client.invoke_endpoint(
EndpointName=self.endpoint,
Expand All @@ -66,6 +71,7 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
)

decoded_body = json.loads(response["Body"].read().decode("utf8"))
logger.debug(decoded_body)

if isinstance(decoded_body, list):
if decoded_body:
Expand Down
21 changes: 18 additions & 3 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional

from attrs import Factory, define, field
Expand Down Expand Up @@ -29,6 +30,7 @@
ToolAction,
observable,
)
from griptape.configs import Defaults
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer
from griptape.utils import import_optional_dependency
Expand All @@ -43,6 +45,9 @@
from griptape.tools.base_tool import BaseTool


logger = logging.getLogger(Defaults.logging_config.logger_name)


@define
class AnthropicPromptDriver(BasePromptDriver):
"""Anthropic Prompt Driver.
Expand Down Expand Up @@ -72,7 +77,11 @@ def client(self) -> Client:

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
response = self.client.messages.create(**self._base_params(prompt_stack))
params = self._base_params(prompt_stack)
logger.debug(params)
response = self.client.messages.create(**params)

logger.debug(response.model_dump())

return Message(
content=[self.__to_prompt_stack_message_content(content) for content in response.content],
Expand All @@ -82,9 +91,12 @@ def try_run(self, prompt_stack: PromptStack) -> Message:

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
events = self.client.messages.create(**self._base_params(prompt_stack), stream=True)
params = {**self._base_params(prompt_stack), "stream": True}
logger.debug(params)
events = self.client.messages.create(**params)

for event in events:
logger.debug(event)
if event.type == "content_block_delta" or event.type == "content_block_start":
yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
elif event.type == "message_start":
Expand All @@ -98,7 +110,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = prompt_stack.system_messages
system_message = system_messages[0].to_text() if system_messages else None

return {
params = {
"model": self.model,
"temperature": self.temperature,
"stop_sequences": self.tokenizer.stop_sequences,
Expand All @@ -113,6 +125,9 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
),
**({"system": system_message} if system_message else {}),
}
logger.debug(params)

return params

def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
return [
Expand Down
15 changes: 13 additions & 2 deletions griptape/drivers/prompt/cohere_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from attrs import Factory, define, field
Expand All @@ -20,6 +21,7 @@
observable,
)
from griptape.common.prompt_stack.contents.action_call_delta_message_content import ActionCallDeltaMessageContent
from griptape.configs import Defaults
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, CohereTokenizer
from griptape.utils import import_optional_dependency
Expand All @@ -33,6 +35,8 @@

from griptape.tools import BaseTool

logger = logging.getLogger(Defaults.logging_config.logger_name)


@define(kw_only=True)
class CoherePromptDriver(BasePromptDriver):
Expand All @@ -59,7 +63,11 @@ def client(self) -> Client:

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
result = self.client.chat(**self._base_params(prompt_stack))
params = self._base_params(prompt_stack)
logger.debug(params)

result = self.client.chat(**params)
logger.debug(result.model_dump())
usage = result.meta.tokens

return Message(
Expand All @@ -70,9 +78,12 @@ def try_run(self, prompt_stack: PromptStack) -> Message:

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
result = self.client.chat_stream(**self._base_params(prompt_stack))
params = self._base_params(prompt_stack)
logger.debug(params)
result = self.client.chat_stream(**params)

for event in result:
logger.debug(event.model_dump())
if event.event_type == "stream-end":
usage = event.response.meta.tokens

Expand Down
18 changes: 12 additions & 6 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Optional

from attrs import Factory, define, field
Expand All @@ -23,6 +24,7 @@
ToolAction,
observable,
)
from griptape.configs import Defaults
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, GoogleTokenizer
from griptape.utils import import_optional_dependency, remove_key_in_dict_recursively
Expand All @@ -37,6 +39,8 @@

from griptape.tools import BaseTool

logger = logging.getLogger(Defaults.logging_config.logger_name)


@define
class GooglePromptDriver(BasePromptDriver):
Expand Down Expand Up @@ -72,10 +76,10 @@ def client(self) -> GenerativeModel:
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
messages = self.__to_google_messages(prompt_stack)
response: GenerateContentResponse = self.client.generate_content(
messages,
**self._base_params(prompt_stack),
)
params = self._base_params(prompt_stack)
logging.debug((messages, params))
response: GenerateContentResponse = self.client.generate_content(messages, **params)
logging.debug(response.to_dict())

usage_metadata = response.usage_metadata

Expand All @@ -91,14 +95,16 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
messages = self.__to_google_messages(prompt_stack)
params = {**self._base_params(prompt_stack), "stream": True}
logging.debug((messages, params))
response: GenerateContentResponse = self.client.generate_content(
messages,
**self._base_params(prompt_stack),
stream=True,
**params,
)

prompt_token_count = None
for chunk in response:
logger.debug(chunk.to_dict())
usage_metadata = chunk.usage_metadata

content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None
Expand Down
22 changes: 12 additions & 10 deletions griptape/drivers/prompt/huggingface_hub_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from attrs import Factory, define, field

from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, observable
from griptape.configs import Defaults
from griptape.drivers import BasePromptDriver
from griptape.tokenizers import HuggingFaceTokenizer
from griptape.utils import import_optional_dependency
Expand All @@ -15,6 +17,8 @@

from huggingface_hub import InferenceClient

logger = logging.getLogger(Defaults.logging_config.logger_name)


@define
class HuggingFaceHubPromptDriver(BasePromptDriver):
Expand Down Expand Up @@ -52,13 +56,14 @@ def client(self) -> InferenceClient:
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
prompt = self.prompt_stack_to_string(prompt_stack)
full_params = {"return_full_text": False, "max_new_tokens": self.max_tokens, **self.params}
logger.debug((prompt, full_params))

response = self.client.text_generation(
prompt,
return_full_text=False,
max_new_tokens=self.max_tokens,
**self.params,
**full_params,
)
logger.debug(response)
input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
output_tokens = len(self.tokenizer.tokenizer.encode(response))

Expand All @@ -71,19 +76,16 @@ def try_run(self, prompt_stack: PromptStack) -> Message:
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
prompt = self.prompt_stack_to_string(prompt_stack)
full_params = {"return_full_text": False, "max_new_tokens": self.max_tokens, "stream": True, **self.params}
logger.debug((prompt, full_params))

response = self.client.text_generation(
prompt,
return_full_text=False,
max_new_tokens=self.max_tokens,
stream=True,
**self.params,
)
response = self.client.text_generation(prompt, **full_params)

input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))

full_text = ""
for token in response:
logger.debug(token)
full_text += token
yield DeltaMessage(content=TextDeltaMessageContent(token, index=0))

Expand Down
Loading

0 comments on commit f0292bc

Please sign in to comment.