diff --git a/CHANGELOG.md b/CHANGELOG.md index d053ac1fb..cf0d8ed36 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `AzureOpenAiStructureConfig` for providing Structures with all Azure OpenAI Driver configuration. - `AzureOpenAiVisionImageQueryDriver` to support queries on images using Azure's OpenAI Vision models. +- `AudioLoader` for loading audio content into an `AudioArtifact`. +- `AudioTranscriptionTask` and `AudioTranscriptionClient` for transcribing audio content in Structures. +- `OpenAiAudioTranscriptionDriver` for integration with OpenAI's speech-to-text models, including Whisper. - Parameter `env` to `BaseStructureRunDriver` to set environment variables for a Structure Run. ### Changed diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index d767d0d8f..69ffeda06 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -199,3 +199,24 @@ loader.load(EmailLoader.EmailQuery(label="INBOX")) loader.load_collection([EmailLoader.EmailQuery(label="INBOX"), EmailLoader.EmailQuery(label="SENT")]) ``` + +## Audio Loader + +!!! info + This loader requires the `loaders-audio` [extra](../index.md#extras). + +The [Audio Loader](../../reference/griptape/loaders/audio_loader.md) is used to load audio content as an [AudioArtifact](./artifacts.md#audioartifact). The Loader operates on audio bytes that can be sourced from files on disk, downloaded audio, or audio in memory. + +The Loader loads audio in its native format and populates the resulting Artifact's `format` field with a best-effort guess of the underlying audio format, made using the `filetype` package. 
+ +```python +from griptape.loaders import AudioLoader +from griptape.utils import load_file + +# Load audio from disk +with open("tests/resources/sentences.wav", "rb") as f: + audio_artifact = AudioLoader().load(f.read()) + +# You can also use the load_file utility function +AudioLoader().load(load_file("tests/resources/sentences.wav")) +``` diff --git a/docs/griptape-framework/drivers/audio-transcription-drivers.md b/docs/griptape-framework/drivers/audio-transcription-drivers.md new file mode 100644 index 000000000..ddadcb89a --- /dev/null +++ b/docs/griptape-framework/drivers/audio-transcription-drivers.md @@ -0,0 +1,32 @@ +## Overview + +[Audio Transcription Drivers](../../reference/griptape/drivers/audio_transcription/index.md) extract text from spoken audio. + +Audio Transcription Drivers sit between audio transcription Engines and the underlying models, constructing and executing the API calls that turn speech into editable, searchable text. In applications that accept spoken input, the Driver extracts and interprets speech, producing text that can be integrated into data systems and Workflows. + +This capability improves accessibility and content discoverability, and automates tasks that previously required manual transcription. + +### OpenAI + +The [OpenAI Audio Transcription Driver](../../reference/griptape/drivers/audio_transcription/openai_audio_transcription_driver.md) uses OpenAI's `whisper` model to transcribe spoken audio into text. The model supports transcription in multiple languages. 
+ +```python +from griptape.drivers import OpenAiAudioTranscriptionDriver +from griptape.engines import AudioTranscriptionEngine +from griptape.tools.audio_transcription_client.tool import AudioTranscriptionClient +from griptape.structures import Agent + + +driver = OpenAiAudioTranscriptionDriver( + model="whisper-1" +) + +tool = AudioTranscriptionClient( + off_prompt=False, + engine=AudioTranscriptionEngine( + audio_transcription_driver=driver, + ), +) + +Agent(tools=[tool]).run("Transcribe the following audio file: tests/resources/sentences.wav") +``` diff --git a/docs/griptape-framework/engines/audio-engines.md b/docs/griptape-framework/engines/audio-engines.md index 514b8dd8f..cbef1ef23 100644 --- a/docs/griptape-framework/engines/audio-engines.md +++ b/docs/griptape-framework/engines/audio-engines.md @@ -27,3 +27,26 @@ engine.run( prompts=["Hello, world!"], ) ``` + +### Audio Transcription Engine + +The [Audio Transcription Engine](../../reference/griptape/engines/audio/audio_transcription_engine.md) facilitates transcribing speech from audio inputs. 
+ +```python +from griptape.drivers import OpenAiAudioTranscriptionDriver +from griptape.engines import AudioTranscriptionEngine +from griptape.loaders import AudioLoader +from griptape.utils import load_file + + +driver = OpenAiAudioTranscriptionDriver( + model="whisper-1" +) + +engine = AudioTranscriptionEngine( + audio_transcription_driver=driver, +) + +audio_artifact = AudioLoader().load(load_file("tests/resources/sentences.wav")) +engine.run(audio_artifact) +``` diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index bff78a96c..0f05ea3e2 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -805,3 +805,57 @@ team = Pipeline( team.run() ``` + +## Text to Speech Task + +This Task enables Structures to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md). + +```python +import os + +from griptape.drivers import ElevenLabsTextToSpeechDriver +from griptape.engines import TextToSpeechEngine +from griptape.tasks import TextToSpeechTask +from griptape.structures import Pipeline + + +driver = ElevenLabsTextToSpeechDriver( + api_key=os.getenv("ELEVEN_LABS_API_KEY"), + model="eleven_multilingual_v2", + voice="Matilda", +) + +task = TextToSpeechTask( + text_to_speech_engine=TextToSpeechEngine( + text_to_speech_driver=driver, + ), +) + +Pipeline(tasks=[task]).run("Generate audio from this text: 'Hello, world!'") +``` + +## Audio Transcription Task + +This Task enables Structures to transcribe speech from audio using [Audio Transcription Engines](../../reference/griptape/engines/audio/audio_transcription_engine.md) and [Audio Transcription Drivers](../../reference/griptape/drivers/audio_transcription/index.md). 
+```python +from griptape.drivers import OpenAiAudioTranscriptionDriver +from griptape.engines import AudioTranscriptionEngine +from griptape.loaders import AudioLoader +from griptape.tasks import AudioTranscriptionTask +from griptape.structures import Pipeline +from griptape.utils import load_file + +driver = OpenAiAudioTranscriptionDriver( + model="whisper-1" +) + +task = AudioTranscriptionTask( + input=lambda _: AudioLoader().load(load_file("tests/resources/sentences2.wav")), + audio_transcription_engine=AudioTranscriptionEngine( + audio_transcription_driver=driver, + ), +) + +Pipeline(tasks=[task]).run() +``` diff --git a/docs/griptape-tools/official-tools/audio-transcription-client.md b/docs/griptape-tools/official-tools/audio-transcription-client.md new file mode 100644 index 000000000..5cb458d76 --- /dev/null +++ b/docs/griptape-tools/official-tools/audio-transcription-client.md @@ -0,0 +1,24 @@ +# AudioTranscriptionClient + +This Tool enables [Agents](../../griptape-framework/structures/agents.md) to transcribe speech from audio using [Audio Transcription Engines](../../reference/griptape/engines/audio/audio_transcription_engine.md) and [Audio Transcription Drivers](../../reference/griptape/drivers/audio_transcription/index.md). 
+```python +from griptape.drivers import OpenAiAudioTranscriptionDriver +from griptape.engines import AudioTranscriptionEngine +from griptape.tools.audio_transcription_client.tool import AudioTranscriptionClient +from griptape.structures import Agent + + +driver = OpenAiAudioTranscriptionDriver( + model="whisper-1" +) + +tool = AudioTranscriptionClient( + off_prompt=False, + engine=AudioTranscriptionEngine( + audio_transcription_driver=driver, + ), +) + +Agent(tools=[tool]).run("Transcribe the following audio file: tests/resources/sentences2.wav") +``` \ No newline at end of file diff --git a/docs/griptape-tools/official-tools/text-to-speech-client.md b/docs/griptape-tools/official-tools/text-to-speech-client.md index f016db551..622b5bf3a 100644 --- a/docs/griptape-tools/official-tools/text-to-speech-client.md +++ b/docs/griptape-tools/official-tools/text-to-speech-client.md @@ -1,6 +1,6 @@ # TextToSpeechClient -This tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md). +This Tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md). 
```python import os diff --git a/griptape/config/base_structure_config.py b/griptape/config/base_structure_config.py index a94cf75d9..0e07d29f5 100644 --- a/griptape/config/base_structure_config.py +++ b/griptape/config/base_structure_config.py @@ -14,6 +14,7 @@ BasePromptDriver, BaseVectorStoreDriver, BaseTextToSpeechDriver, + BaseAudioTranscriptionDriver, ) from griptape.utils import dict_merge @@ -29,6 +30,7 @@ class BaseStructureConfig(BaseConfig, ABC): default=None, kw_only=True, metadata={"serializable": True} ) text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True, metadata={"serializable": True}) + audio_transcription_driver: BaseAudioTranscriptionDriver = field(kw_only=True, metadata={"serializable": True}) def merge_config(self, config: dict) -> BaseStructureConfig: base_config = self.to_dict() diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_structure_config.py index 459160b11..416306a98 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_structure_config.py @@ -11,6 +11,10 @@ OpenAiChatPromptDriver, OpenAiEmbeddingDriver, OpenAiImageGenerationDriver, + BaseTextToSpeechDriver, + OpenAiTextToSpeechDriver, + BaseAudioTranscriptionDriver, + OpenAiAudioTranscriptionDriver, OpenAiImageQueryDriver, ) @@ -40,3 +44,11 @@ class OpenAiStructureConfig(StructureConfig): kw_only=True, metadata={"serializable": True}, ) + text_to_speech_driver: BaseTextToSpeechDriver = field( + default=Factory(lambda: OpenAiTextToSpeechDriver(model="tts")), kw_only=True, metadata={"serializable": True} + ) + audio_transcription_driver: BaseAudioTranscriptionDriver = field( + default=Factory(lambda: OpenAiAudioTranscriptionDriver(model="whisper-1")), + kw_only=True, + metadata={"serializable": True}, + ) diff --git a/griptape/config/structure_config.py b/griptape/config/structure_config.py index 63f1ea9f3..ae3ad1e99 100644 --- a/griptape/config/structure_config.py +++ b/griptape/config/structure_config.py @@ 
-17,6 +17,8 @@ BaseImageQueryDriver, BaseTextToSpeechDriver, DummyTextToSpeechDriver, + BaseAudioTranscriptionDriver, + DummyAudioTranscriptionDriver, ) @@ -43,3 +45,6 @@ class StructureConfig(BaseStructureConfig): text_to_speech_driver: BaseTextToSpeechDriver = field( default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True} ) + audio_transcription_driver: BaseAudioTranscriptionDriver = field( + default=Factory(lambda: DummyAudioTranscriptionDriver()), kw_only=True, metadata={"serializable": True} + ) diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 6043b0b43..6c64756e1 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -107,6 +107,10 @@ from .structure_run.griptape_cloud_structure_run_driver import GriptapeCloudStructureRunDriver from .structure_run.local_structure_run_driver import LocalStructureRunDriver +from .audio_transcription.base_audio_transcription_driver import BaseAudioTranscriptionDriver +from .audio_transcription.dummy_audio_transcription_driver import DummyAudioTranscriptionDriver +from .audio_transcription.openai_audio_transcription_driver import OpenAiAudioTranscriptionDriver + __all__ = [ "BasePromptDriver", "OpenAiChatPromptDriver", @@ -199,4 +203,7 @@ "BaseStructureRunDriver", "GriptapeCloudStructureRunDriver", "LocalStructureRunDriver", + "BaseAudioTranscriptionDriver", + "DummyAudioTranscriptionDriver", + "OpenAiAudioTranscriptionDriver", ] diff --git a/griptape/drivers/audio_transcription/__init__.py b/griptape/drivers/audio_transcription/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/drivers/audio_transcription/base_audio_transcription_driver.py b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py new file mode 100644 index 000000000..3cc368c94 --- /dev/null +++ b/griptape/drivers/audio_transcription/base_audio_transcription_driver.py @@ -0,0 +1,42 @@ +from __future__ import annotations + 
+from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional + +from attrs import define, field + +from griptape.artifacts import TextArtifact, AudioArtifact +from griptape.events import StartAudioTranscriptionEvent, FinishAudioTranscriptionEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin + +if TYPE_CHECKING: + from griptape.structures import Structure + + +@define +class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC): + model: str = field(kw_only=True, metadata={"serializable": True}) + structure: Optional[Structure] = field(default=None, kw_only=True) + + def before_run(self) -> None: + if self.structure: + self.structure.publish_event(StartAudioTranscriptionEvent()) + + def after_run(self) -> None: + if self.structure: + self.structure.publish_event(FinishAudioTranscriptionEvent()) + + def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: + for attempt in self.retrying(): + with attempt: + self.before_run() + result = self.try_run(audio, prompts) + self.after_run() + + return result + + else: + raise Exception("Failed to run audio transcription") + + @abstractmethod + def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: ... 
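In `BaseAudioTranscriptionDriver` above, `run()` wraps the abstract `try_run()` with event publication and retries, so concrete drivers only need to implement `try_run()`. A stdlib-only sketch of that template-method contract (the class names and canned result below are hypothetical, not part of this change):

```python
from abc import ABC, abstractmethod


class TranscriptionDriverSketch(ABC):
    """Sketch of the base-driver contract: run() wraps try_run() with hooks."""

    def run(self, audio: bytes) -> str:
        # before_run/after_run stand in for the Start/Finish event
        # publication (and retry wrapping) in the real base driver.
        self.before_run()
        result = self.try_run(audio)
        self.after_run()
        return result

    def before_run(self) -> None: ...

    def after_run(self) -> None: ...

    @abstractmethod
    def try_run(self, audio: bytes) -> str: ...


class CannedTranscriptionDriver(TranscriptionDriverSketch):
    """Hypothetical concrete driver that returns a canned transcription."""

    def try_run(self, audio: bytes) -> str:
        return f"transcribed {len(audio)} bytes"
```

Concrete drivers such as the OpenAI driver slot into the `try_run` position and inherit the event and retry behavior from the base class.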
diff --git a/griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py b/griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py new file mode 100644 index 000000000..1602604e4 --- /dev/null +++ b/griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py @@ -0,0 +1,14 @@ +from typing import Optional + +from attrs import define, field +from griptape.artifacts import AudioArtifact, TextArtifact +from griptape.drivers import BaseAudioTranscriptionDriver +from griptape.exceptions import DummyException + + +@define +class DummyAudioTranscriptionDriver(BaseAudioTranscriptionDriver): + model: str = field(init=False) + + def try_run(self, audio: AudioArtifact, prompts: Optional[list] = None) -> TextArtifact: + raise DummyException(__class__.__name__, "try_transcription") diff --git a/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py b/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py new file mode 100644 index 000000000..14b367521 --- /dev/null +++ b/griptape/drivers/audio_transcription/openai_audio_transcription_driver.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import io +from typing import Optional + +import openai +from attrs import field, Factory, define + +from griptape.artifacts import AudioArtifact, TextArtifact +from griptape.drivers import BaseAudioTranscriptionDriver + + +@define +class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver): + api_type: str = field(default=openai.api_type, kw_only=True) + api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True}) + base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) + organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) + client: openai.OpenAI = field( + 
default=Factory( + lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), + takes_self=True, + ) + ) + + def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: + additional_params = {} + + if prompts is not None: + additional_params["prompt"] = ", ".join(prompts) + + transcription = self.client.audio.transcriptions.create( + # Even though we're not actually providing a file to the client, the API still requires that we send a file + # name. We set the file name to use the same format as the audio file so that the API can reject + # it if the format is unsupported. + model=self.model, + file=(f"a.{audio.format}", io.BytesIO(audio.value)), + response_format="json", + **additional_params, + ) + + return TextArtifact(value=transcription.text) diff --git a/griptape/engines/__init__.py b/griptape/engines/__init__.py index f76e373ef..17adaa53d 100644 --- a/griptape/engines/__init__.py +++ b/griptape/engines/__init__.py @@ -12,6 +12,7 @@ from .image.outpainting_image_generation_engine import OutpaintingImageGenerationEngine from .image_query.image_query_engine import ImageQueryEngine from .audio.text_to_speech_engine import TextToSpeechEngine +from .audio.audio_transcription_engine import AudioTranscriptionEngine __all__ = [ "BaseQueryEngine", @@ -28,4 +29,5 @@ "OutpaintingImageGenerationEngine", "ImageQueryEngine", "TextToSpeechEngine", + "AudioTranscriptionEngine", ] diff --git a/griptape/engines/audio/audio_transcription_engine.py b/griptape/engines/audio/audio_transcription_engine.py new file mode 100644 index 000000000..3631b2d17 --- /dev/null +++ b/griptape/engines/audio/audio_transcription_engine.py @@ -0,0 +1,12 @@ +from attrs import define, field + +from griptape.artifacts import AudioArtifact, TextArtifact +from griptape.drivers import BaseAudioTranscriptionDriver + + +@define +class AudioTranscriptionEngine: + audio_transcription_driver: BaseAudioTranscriptionDriver = 
field(kw_only=True) + + def run(self, audio: AudioArtifact, *args, **kwargs) -> TextArtifact: + return self.audio_transcription_driver.run(audio) diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index 0aa4552dd..944a309eb 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -19,6 +19,9 @@ from .base_text_to_speech_event import BaseTextToSpeechEvent from .start_text_to_speech_event import StartTextToSpeechEvent from .finish_text_to_speech_event import FinishTextToSpeechEvent +from .base_audio_transcription_event import BaseAudioTranscriptionEvent +from .start_audio_transcription_event import StartAudioTranscriptionEvent +from .finish_audio_transcription_event import FinishAudioTranscriptionEvent __all__ = [ "BaseEvent", @@ -42,4 +45,7 @@ "BaseTextToSpeechEvent", "StartTextToSpeechEvent", "FinishTextToSpeechEvent", + "BaseAudioTranscriptionEvent", + "StartAudioTranscriptionEvent", + "FinishAudioTranscriptionEvent", ] diff --git a/griptape/events/base_audio_transcription_event.py b/griptape/events/base_audio_transcription_event.py new file mode 100644 index 000000000..f634adfce --- /dev/null +++ b/griptape/events/base_audio_transcription_event.py @@ -0,0 +1,4 @@ +from griptape.events import BaseEvent + + +class BaseAudioTranscriptionEvent(BaseEvent): ... diff --git a/griptape/events/finish_audio_transcription_event.py b/griptape/events/finish_audio_transcription_event.py new file mode 100644 index 000000000..321de1577 --- /dev/null +++ b/griptape/events/finish_audio_transcription_event.py @@ -0,0 +1,4 @@ +from griptape.events.base_audio_transcription_event import BaseAudioTranscriptionEvent + + +class FinishAudioTranscriptionEvent(BaseAudioTranscriptionEvent): ... 
diff --git a/griptape/events/start_audio_transcription_event.py b/griptape/events/start_audio_transcription_event.py new file mode 100644 index 000000000..25316ac8a --- /dev/null +++ b/griptape/events/start_audio_transcription_event.py @@ -0,0 +1,4 @@ +from griptape.events.base_audio_transcription_event import BaseAudioTranscriptionEvent + + +class StartAudioTranscriptionEvent(BaseAudioTranscriptionEvent): ... diff --git a/griptape/loaders/__init__.py b/griptape/loaders/__init__.py index 38beba01a..b79b0ff44 100644 --- a/griptape/loaders/__init__.py +++ b/griptape/loaders/__init__.py @@ -8,6 +8,7 @@ from .dataframe_loader import DataFrameLoader from .email_loader import EmailLoader from .image_loader import ImageLoader +from .audio_loader import AudioLoader from .blob_loader import BlobLoader @@ -22,5 +23,6 @@ "DataFrameLoader", "EmailLoader", "ImageLoader", + "AudioLoader", "BlobLoader", ] diff --git a/griptape/loaders/audio_loader.py b/griptape/loaders/audio_loader.py new file mode 100644 index 000000000..532662e79 --- /dev/null +++ b/griptape/loaders/audio_loader.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import cast + +from attrs import define + +from griptape.artifacts import AudioArtifact +from griptape.loaders import BaseLoader +from griptape.utils import import_optional_dependency + + +@define +class AudioLoader(BaseLoader): + """Loads audio content into audio artifacts.""" + + def load(self, source: bytes, *args, **kwargs) -> AudioArtifact: + audio_artifact = AudioArtifact(source, format=import_optional_dependency("filetype").guess(source).extension) + + return audio_artifact + + def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, AudioArtifact]: + return cast(dict[str, AudioArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/tasks/__init__.py b/griptape/tasks/__init__.py index e51335241..2c282adff 100644 --- a/griptape/tasks/__init__.py +++ b/griptape/tasks/__init__.py 
@@ -20,6 +20,7 @@ from .base_audio_generation_task import BaseAudioGenerationTask from .text_to_speech_task import TextToSpeechTask from .structure_run_task import StructureRunTask +from .audio_transcription_task import AudioTranscriptionTask __all__ = [ "BaseTask", @@ -44,4 +45,5 @@ "BaseAudioGenerationTask", "TextToSpeechTask", "StructureRunTask", + "AudioTranscriptionTask", ] diff --git a/griptape/tasks/audio_transcription_task.py b/griptape/tasks/audio_transcription_task.py new file mode 100644 index 000000000..c75faa0d4 --- /dev/null +++ b/griptape/tasks/audio_transcription_task.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from abc import ABC +from typing import Callable + +from attrs import define, field + +from griptape.artifacts.audio_artifact import AudioArtifact +from griptape.engines import AudioTranscriptionEngine +from griptape.artifacts import TextArtifact +from griptape.mixins import RuleMixin +from griptape.tasks import BaseTask + + +@define +class AudioTranscriptionTask(RuleMixin, BaseTask, ABC): + _input: AudioArtifact | Callable[[BaseTask], AudioArtifact] = field() + _audio_transcription_engine: AudioTranscriptionEngine = field( + default=None, kw_only=True, alias="audio_transcription_engine" + ) + + @property + def input(self) -> AudioArtifact: + if isinstance(self._input, AudioArtifact): + return self._input + elif isinstance(self._input, Callable): + return self._input(self) + else: + raise ValueError("Input must be an AudioArtifact.") + + @input.setter + def input(self, value: AudioArtifact | Callable[[BaseTask], AudioArtifact]) -> None: + self._input = value + + @property + def audio_transcription_engine(self) -> AudioTranscriptionEngine: + if self._audio_transcription_engine is None: + if self.structure is not None: + self._audio_transcription_engine = AudioTranscriptionEngine( + audio_transcription_driver=self.structure.config.audio_transcription_driver + ) + else: + raise ValueError("Audio Transcription Engine is not set.") + 
return self._audio_transcription_engine + + @audio_transcription_engine.setter + def audio_transcription_engine(self, value: AudioTranscriptionEngine) -> None: + self._audio_transcription_engine = value + + def run(self) -> TextArtifact: + return self.audio_transcription_engine.run(self.input) diff --git a/griptape/tools/audio_transcription_client/__init__.py b/griptape/tools/audio_transcription_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/audio_transcription_client/manifest.yml b/griptape/tools/audio_transcription_client/manifest.yml new file mode 100644 index 000000000..6bbe4a21a --- /dev/null +++ b/griptape/tools/audio_transcription_client/manifest.yml @@ -0,0 +1,5 @@ +version: "v1" +name: Transcription Client +description: A tool for generating transcriptions of audio. +contact_email: hello@griptape.ai +legal_info_url: https://www.griptape.ai/legal diff --git a/griptape/tools/audio_transcription_client/tool.py b/griptape/tools/audio_transcription_client/tool.py new file mode 100644 index 000000000..ad0f0626e --- /dev/null +++ b/griptape/tools/audio_transcription_client/tool.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import Any, cast + +from attrs import define, field, Factory +from schema import Schema, Literal + +from griptape.artifacts import ErrorArtifact, AudioArtifact, TextArtifact +from griptape.engines import AudioTranscriptionEngine +from griptape.loaders.audio_loader import AudioLoader +from griptape.tools import BaseTool +from griptape.utils import load_artifact_from_memory +from griptape.utils.decorators import activity + + +@define +class AudioTranscriptionClient(BaseTool): +    """A tool that can be used to generate transcriptions from input audio.""" + + engine: AudioTranscriptionEngine = field(kw_only=True) + audio_loader: AudioLoader = field(default=Factory(lambda: AudioLoader()), kw_only=True) + + @activity( + config={ + "description": "This tool can be used to generate 
transcriptions of audio files on disk.", + "schema": Schema({Literal("path", description="The path to an audio file on disk."): str}), + } + ) + def transcribe_audio_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact: + audio_path = params["values"]["path"] + + with open(audio_path, "rb") as f: + audio_artifact = self.audio_loader.load(f.read()) + + return self.engine.run(audio_artifact) + + @activity( + config={ + "description": "This tool can be used to generate the transcription of an audio artifact in memory.", + "schema": Schema({"schema": Schema({"memory_name": str, "artifact_namespace": str, "artifact_name": str})}), + } + ) + def transcribe_audio_from_memory(self, params: dict[str, Any]) -> TextArtifact | ErrorArtifact: + memory = self.find_input_memory(params["values"]["memory_name"]) + artifact_namespace = params["values"]["artifact_namespace"] + artifact_name = params["values"]["artifact_name"] + + if memory is None: + return ErrorArtifact("memory not found") + + audio_artifact = cast( + AudioArtifact, load_artifact_from_memory(memory, artifact_namespace, artifact_name, AudioArtifact) + ) + + return self.engine.run(audio_artifact) diff --git a/mkdocs.yml b/mkdocs.yml index b63f36fe5..99231cce7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -110,6 +110,7 @@ nav: - Event Listener Drivers: "griptape-framework/drivers/event-listener-drivers.md" - Structure Run Drivers: "griptape-framework/drivers/structure-run-drivers.md" - Text to Speech Drivers: "griptape-framework/drivers/text-to-speech-drivers.md" + - Audio Transcription Drivers: "griptape-framework/drivers/audio-transcription-drivers.md" - Data: - Overview: "griptape-framework/data/index.md" - Artifacts: "griptape-framework/data/artifacts.md" @@ -146,6 +147,7 @@ nav: - OutpaintingImageGenerationClient: "griptape-tools/official-tools/outpainting-image-generation-client.md" - ImageQueryClient: "griptape-tools/official-tools/image-query-client.md" - TextToSpeechClient: 
"griptape-tools/official-tools/text-to-speech-client.md" + - AudioTranscriptionClient: "griptape-tools/official-tools/audio-transcription-client.md" - Custom Tools: - Building Custom Tools: "griptape-tools/custom-tools/index.md" - Recipes: diff --git a/poetry.lock b/poetry.lock index c78b72391..7e6804fcc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1353,6 +1353,17 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] +[[package]] +name = "filetype" +version = "1.2.0" +description = "Infer file type and MIME type of any file/buffer. No external dependencies." 
+optional = true +python-versions = "*" +files = [ + {file = "filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25"}, + {file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"}, +] + [[package]] name = "frozenlist" version = "1.4.1" @@ -5984,7 +5995,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["anthropic", "beautifulsoup4", "boto3", "cohere", "elevenlabs", "google-generativeai", "mail-parser", "markdownify", "marqo", "opensearch-py", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pymongo", "pypdf", "redis", "snowflake-sqlalchemy", "sqlalchemy-redshift", "torch", "trafilatura", "transformers", "voyageai"] +all = ["anthropic", "beautifulsoup4", "boto3", "cohere", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "opensearch-py", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pymongo", "pypdf", "redis", "snowflake-sqlalchemy", "sqlalchemy-redshift", "torch", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-google = ["google-generativeai"] @@ -6013,6 +6024,7 @@ drivers-vector-postgresql = ["pgvector", "psycopg2-binary"] drivers-vector-redis = ["redis"] drivers-web-scraper-markdownify = ["beautifulsoup4", "markdownify", "playwright"] drivers-web-scraper-trafilatura = ["trafilatura"] +loaders-audio = ["filetype"] loaders-dataframe = ["pandas"] loaders-email = ["mail-parser"] loaders-image = ["pillow"] @@ -6021,4 +6033,4 @@ loaders-pdf = ["pypdf"] [metadata] 
lock-version = "2.0" python-versions = "^3.9" -content-hash = "ce7c88b4d4ea368bd7a6c08c8e0f4310b2c10f1237d77d80f50bda4b35612481" +content-hash = "da89e6bb7aa395fd5badc416ec859e574416ea1fe07ee6f93035fe6a7e314dcd" diff --git a/pyproject.toml b/pyproject.toml index af264a5dd..55ae82ce2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ pandas = {version = "^1.3", optional = true} pypdf = {version = "^3.9", optional = true} pillow = {version = "^10.2.0", optional = true} mail-parser = {version = "^3.15.0", optional = true} +filetype = {version = "^1.2", optional = true} [tool.poetry.extras] drivers-prompt-cohere = ["cohere"] @@ -100,6 +101,7 @@ loaders-dataframe = ["pandas"] loaders-pdf = ["pypdf"] loaders-image = ["pillow"] loaders-email = ["mail-parser"] +loaders-audio = ["filetype"] all = [ # drivers @@ -132,6 +134,7 @@ all = [ "pypdf", "pillow", "mail-parser", + "filetype", ] [tool.poetry.group.test] diff --git a/tests/resources/sentences.wav b/tests/resources/sentences.wav new file mode 100644 index 000000000..796200849 Binary files /dev/null and b/tests/resources/sentences.wav differ diff --git a/tests/resources/sentences2.wav b/tests/resources/sentences2.wav new file mode 100644 index 000000000..185427285 Binary files /dev/null and b/tests/resources/sentences2.wav differ diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index 66ca44bb5..5b8c63a98 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -52,6 +52,7 @@ def test_to_dict(self, config): }, "type": "AmazonBedrockStructureConfig", "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index 
9f014092a..8279fb091 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -45,6 +45,7 @@ def test_to_dict(self, config): }, "conversation_memory_driver": None, "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_structure_config.py index dd5cd56a7..7e06dc0f5 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_structure_config.py @@ -57,7 +57,6 @@ def test_to_dict(self, config): "style": None, "type": "AzureOpenAiImageGenerationDriver", }, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "image_query_driver": { "base_url": None, "image_quality": "auto", @@ -81,6 +80,8 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config: AzureOpenAiStructureConfig): diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_structure_config.py index f089b611b..72e623ff0 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_structure_config.py @@ -42,6 +42,7 @@ def test_to_dict(self, config): }, "conversation_memory_driver": None, "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index efcac008e..3dd1ac85f 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ 
b/tests/unit/config/test_openai_structure_config.py @@ -27,7 +27,6 @@ def test_to_dict(self, config): "user": "", }, "conversation_memory_driver": None, - "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, "embedding_driver": { "base_url": None, "model": "text-embedding-3-small", @@ -63,6 +62,22 @@ def test_to_dict(self, config): }, "type": "LocalVectorStoreDriver", }, + "text_to_speech_driver": { + "type": "OpenAiTextToSpeechDriver", + "api_version": None, + "base_url": None, + "format": "mp3", + "model": "tts", + "organization": None, + "voice": "alloy", + }, + "audio_transcription_driver": { + "type": "OpenAiAudioTranscriptionDriver", + "api_version": None, + "base_url": None, + "model": "whisper-1", + "organization": None, + }, } def test_from_dict(self, config): diff --git a/tests/unit/config/test_structure_config.py b/tests/unit/config/test_structure_config.py index 9e1b00038..a420205f2 100644 --- a/tests/unit/config/test_structure_config.py +++ b/tests/unit/config/test_structure_config.py @@ -20,6 +20,7 @@ def test_to_dict(self, config): "type": "DummyVectorStoreDriver", }, "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, } def test_from_dict(self, config): diff --git a/tests/unit/drivers/transcription/__init__.py b/tests/unit/drivers/transcription/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py b/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py new file mode 100644 index 000000000..57c5a5e2e --- /dev/null +++ b/tests/unit/drivers/transcription/test_openai_audio_transcription_driver.py @@ -0,0 +1,25 @@ +import pytest +from unittest.mock import Mock + +from griptape.artifacts import AudioArtifact +from griptape.drivers import OpenAiAudioTranscriptionDriver + + +class TestOpenAiAudioTranscriptionDriver: + @pytest.fixture + def 
audio_artifact(self): + return AudioArtifact(value=b"audio data", format="mp3") + + @pytest.fixture + def driver(self): + return OpenAiAudioTranscriptionDriver(model="model", client=Mock(), api_key="key") + + def test_init(self, driver): + assert driver + + def test_try_run(self, driver, audio_artifact): + driver.client.audio.transcriptions.create.return_value = Mock(text="text data") + + text_artifact = driver.try_run(audio_artifact) + + assert text_artifact.value == "text data" diff --git a/tests/unit/loaders/test_audio_loader.py b/tests/unit/loaders/test_audio_loader.py new file mode 100644 index 000000000..b7946da03 --- /dev/null +++ b/tests/unit/loaders/test_audio_loader.py @@ -0,0 +1,41 @@ +import pytest + +from griptape.artifacts import AudioArtifact +from griptape.loaders import AudioLoader + + +class TestAudioLoader: + @pytest.fixture + def loader(self): + return AudioLoader() + + @pytest.fixture + def create_source(self, bytes_from_resource_path): + return bytes_from_resource_path + + @pytest.mark.parametrize("resource_path,suffix,mime_type", [("sentences.wav", ".wav", "audio/wav")]) + def test_load(self, resource_path, suffix, mime_type, loader, create_source): + source = create_source(resource_path) + + artifact = loader.load(source) + + assert isinstance(artifact, AudioArtifact) + assert artifact.name.endswith(suffix) + assert artifact.mime_type == mime_type + assert len(artifact.value) > 0 + + def test_load_collection(self, create_source, loader): + resource_paths = ["sentences.wav", "sentences2.wav"] + sources = [create_source(resource_path) for resource_path in resource_paths] + + collection = loader.load_collection(sources) + + assert len(collection) == len(resource_paths) + + keys = {loader.to_key(source) for source in sources} + for key in collection.keys(): + artifact = collection[key] + assert isinstance(artifact, AudioArtifact) + assert artifact.name.endswith(".wav") + assert artifact.mime_type == "audio/wav" + assert 
len(artifact.value) > 0 diff --git a/tests/unit/tasks/test_audio_transcription_task.py b/tests/unit/tasks/test_audio_transcription_task.py new file mode 100644 index 000000000..fdab5f730 --- /dev/null +++ b/tests/unit/tasks/test_audio_transcription_task.py @@ -0,0 +1,38 @@ +from unittest.mock import Mock + +import pytest + +from griptape.artifacts import AudioArtifact +from griptape.engines import AudioTranscriptionEngine +from griptape.structures import Agent +from griptape.tasks import BaseTask, AudioTranscriptionTask +from tests.mocks.mock_structure_config import MockStructureConfig + + +class TestAudioTranscriptionTask: + @pytest.fixture + def audio_artifact(self): + return AudioArtifact(value=b"audio data", format="mp3") + + @pytest.fixture + def audio_transcription_engine(self): + return Mock() + + def test_audio_input(self, audio_artifact, audio_transcription_engine): + task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) + + assert task.input.value == audio_artifact.value + + def test_callable_input(self, audio_artifact, audio_transcription_engine): + def callable(task: BaseTask) -> AudioArtifact: + return audio_artifact + + task = AudioTranscriptionTask(callable, audio_transcription_engine=audio_transcription_engine) + + assert task.input == audio_artifact + + def test_config_audio_transcription_engine(self, audio_artifact): + task = AudioTranscriptionTask(audio_artifact) + Agent(config=MockStructureConfig()).add_task(task) + + assert isinstance(task.audio_transcription_engine, AudioTranscriptionEngine) diff --git a/tests/unit/tools/test_transcription_client.py b/tests/unit/tools/test_transcription_client.py new file mode 100644 index 000000000..ea6bd3453 --- /dev/null +++ b/tests/unit/tools/test_transcription_client.py @@ -0,0 +1,47 @@ +from unittest.mock import Mock, mock_open, patch + +import pytest + +from griptape.artifacts import AudioArtifact +from griptape.tools.audio_transcription_client.tool import 
AudioTranscriptionClient + + +class TestTranscriptionClient: + @pytest.fixture + def transcription_engine(self) -> Mock: + return Mock() + + @pytest.fixture + def audio_loader(self) -> Mock: + loader = Mock() + loader.load.return_value = AudioArtifact(value=b"audio data", format="wav") + + return loader + + def test_init_transcription_client(self, transcription_engine, audio_loader) -> None: + assert AudioTranscriptionClient(engine=transcription_engine, audio_loader=audio_loader) + + @patch("builtins.open", mock_open(read_data=b"audio data")) + def test_transcribe_audio_from_disk(self, transcription_engine, audio_loader) -> None: + client = AudioTranscriptionClient(engine=transcription_engine, audio_loader=audio_loader) + client.engine.run.return_value = Mock(value="transcription") # pyright: ignore + + text_artifact = client.transcribe_audio_from_disk(params={"values": {"path": "audio.wav"}}) + + assert text_artifact + assert text_artifact.value == "transcription" + + def test_transcribe_audio_from_memory(self, transcription_engine, audio_loader) -> None: + client = AudioTranscriptionClient(engine=transcription_engine, audio_loader=audio_loader) + memory = Mock() + memory.load_artifacts = Mock(return_value=[AudioArtifact(value=b"audio data", format="wav", name="name")]) + client.find_input_memory = Mock(return_value=memory) + + client.engine.run.return_value = Mock(value="transcription") # pyright: ignore + + text_artifact = client.transcribe_audio_from_memory( + params={"values": {"memory_name": "memory", "artifact_namespace": "namespace", "artifact_name": "name"}} + ) + + assert text_artifact + assert text_artifact.value == "transcription"