From 7b85d663d441c0252159d509a6024ff3d30121aa Mon Sep 17 00:00:00 2001 From: Andrew French Date: Tue, 14 May 2024 16:10:38 -0700 Subject: [PATCH] Renamings, tests, docs --- docs/griptape-framework/data/artifacts.md | 6 ++- .../drivers/text-to-speech-drivers.md | 33 +++++++++++++++ .../engines/audio-engines.md | 29 ++++++++++++++ .../official-tools/text-to-speech-client.md | 27 +++++++++++++ .../config/structure_global_drivers_config.py | 8 ++-- griptape/drivers/__init__.py | 12 +++--- .../__init__.py | 0 .../base_text_to_speech_driver.py} | 2 +- .../dummy_text_to_speech_driver.py} | 4 +- .../elevenlabs_text_to_speech_driver.py} | 14 +++++-- griptape/engines/__init__.py | 6 +-- .../audio/base_audio_generation_engine.py | 19 --------- .../audio/text_to_audio_generation_engine.py | 12 ------ .../engines/audio/text_to_speech_engine.py | 14 +++++++ griptape/mixins/__init__.py | 4 +- .../media_artifact_file_output_mixin.py | 6 +-- griptape/tasks/__init__.py | 4 +- griptape/tasks/base_audio_generation_task.py | 4 +- griptape/tasks/base_image_generation_task.py | 4 +- ...eration_task.py => text_to_speech_task.py} | 26 ++++++------ .../tool.py | 4 +- .../tool.py | 4 +- .../prompt_image_generation_client/tool.py | 4 +- griptape/tools/text_to_speech_client/tool.py | 8 ++-- .../variation_image_generation_client/tool.py | 4 +- griptape/utils/__init__.py | 1 - griptape/utils/play_audio.py | 9 ----- pyproject.toml | 3 +- tests/unit/drivers/text_to_speech/__init__.py | 0 ...test_elevenlabs_audio_generation_driver.py | 20 ++++++++++ .../test_image_artifact_file_output_mixin.py | 10 ++--- tests/unit/tasks/test_text_to_speech_task.py | 30 ++++++++++++++ .../unit/tools/test_text_to_speech_client.py | 40 +++++++++++++++++++ 33 files changed, 265 insertions(+), 106 deletions(-) create mode 100644 docs/griptape-framework/drivers/text-to-speech-drivers.md create mode 100644 docs/griptape-framework/engines/audio-engines.md create mode 100644 docs/griptape-tools/official-tools/text-to-speech-client.md rename griptape/drivers/{audio_generation => text_to_speech}/__init__.py (100%) rename griptape/drivers/{audio_generation/base_audio_generation_driver.py => text_to_speech/base_text_to_speech_driver.py} (95%) rename griptape/drivers/{audio_generation/dummy_audio_generation_driver.py => text_to_speech/dummy_text_to_speech_driver.py} (77%) rename griptape/drivers/{audio_generation/elevenlabs_audio_generation_driver.py => text_to_speech/elevenlabs_text_to_speech_driver.py} (69%) delete mode 100644 griptape/engines/audio/base_audio_generation_engine.py delete mode 100644 griptape/engines/audio/text_to_audio_generation_engine.py create mode 100644 griptape/engines/audio/text_to_speech_engine.py rename griptape/tasks/{audio_generation_task.py => text_to_speech_task.py} (58%) delete mode 100644 griptape/utils/play_audio.py create mode 100644 tests/unit/drivers/text_to_speech/__init__.py create mode 100644 tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py create mode 100644 tests/unit/tasks/test_text_to_speech_task.py create mode 100644 tests/unit/tools/test_text_to_speech_client.py diff --git a/docs/griptape-framework/data/artifacts.md b/docs/griptape-framework/data/artifacts.md index b12dc97d3..18d5cb61f 100644 --- a/docs/griptape-framework/data/artifacts.md +++ b/docs/griptape-framework/data/artifacts.md @@ -35,4 +35,8 @@ Each blob has a [name](../../reference/griptape/artifacts/base_artifact.md#gript ## ImageArtifact -An [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md) is used for passing images back to the LLM. In addition to binary image data, an ImageArtifact includes image metadata like MIME type, dimensions, and prompt and model information for images returned by [image generation Drivers](../drivers/image-generation-drivers.md). It inherits from [BlobArtifact](#blobartifact). +An [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md) is used for passing images back to the LLM. In addition to binary image data, an Image Artifact includes image metadata like MIME type, dimensions, and prompt and model information for images returned by [image generation Drivers](../drivers/image-generation-drivers.md). It inherits from [BlobArtifact](#blobartifact). + +## AudioArtifact + +An [AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md) allows the Framework to interact with audio content. An Audio Artifact includes binary audio content as well as metadata like format, duration, and prompt and model information for audio returned generative models. It inherits from [BlobArtifact](#blobartifact). diff --git a/docs/griptape-framework/drivers/text-to-speech-drivers.md b/docs/griptape-framework/drivers/text-to-speech-drivers.md new file mode 100644 index 000000000..bef7ab0e3 --- /dev/null +++ b/docs/griptape-framework/drivers/text-to-speech-drivers.md @@ -0,0 +1,33 @@ +## Overview + +[Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md) are used by [Text To Speech Engines](../engines/audio/text-to-speech-engine.md) to build and execute API calls to audio generation models. + +Provide a Driver when building an [Engine](../engines/audio-generation-engines.md), then pass it to a [Tool](../tools/index.md) for use by an [Agent](../structures/agents.md): + +### Eleven Labs + +The [Eleven Labs Text to Speech Driver](../../reference/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.md) provides support for text-to-speech models hosted by Eleven Labs. This Driver supports configurations specific to Eleven Labs, like voice selection and output format. + +```python +import os + +from griptape.drivers import ElevenLabsTextToSpeechDriver +from griptape.engines import TextToSpeechEngine +from griptape.tools.text_to_speech_client.tool import TextToSpeechClient +from griptape.structures import Agent + + +driver = ElevenLabsTextToSpeechDriver( + api_key=os.getenv("ELEVEN_LABS_API_KEY"), + model="eleven_multilingual_v2", + voice="Matilda", +) + +tool = TextToSpeechClient( + engine=TextToSpeechEngine( + text_to_speech_driver=driver, + ), +) + +Agent(tools=[tool]).run("Generate audio from this text: 'Hello, world!'") +``` diff --git a/docs/griptape-framework/engines/audio-engines.md b/docs/griptape-framework/engines/audio-engines.md new file mode 100644 index 000000000..50265fcbd --- /dev/null +++ b/docs/griptape-framework/engines/audio-engines.md @@ -0,0 +1,29 @@ +## Overview + +[Audio Generation Engines](../../reference/griptape/engines/audio/index.md) facilitate audio generation. Audio Generation Engines provides a `run` method that accepts the necessary inputs for its particular mode and provides the request to the configured [Driver](../drivers/text-to-speech/index.md). + +### Text to Speech Engine + +This Engine facilitates synthesizing speech from text inputs. + +```python +import os + +from griptape.drivers import ElevenLabsTextToSpeechDriver +from griptape.engines import TextToSpeechEngine + + +driver = ElevenLabsTextToSpeechDriver( + api_key=os.getenv("ELEVEN_LABS_API_KEY"), + model="eleven_multilingual_v2", + voice="Rachel", +) + +engine = TextToSpeechEngine( + text_to_speech_driver=driver, +) + +engine.run( + prompts=["Hello, world!"], +) +``` diff --git a/docs/griptape-tools/official-tools/text-to-speech-client.md b/docs/griptape-tools/official-tools/text-to-speech-client.md new file mode 100644 index 000000000..f016db551 --- /dev/null +++ b/docs/griptape-tools/official-tools/text-to-speech-client.md @@ -0,0 +1,27 @@ +# TextToSpeechClient + +This tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md). + +```python +import os + +from griptape.drivers import ElevenLabsTextToSpeechDriver +from griptape.engines import TextToSpeechEngine +from griptape.tools.text_to_speech_client.tool import TextToSpeechClient +from griptape.structures import Agent + + +driver = ElevenLabsTextToSpeechDriver( + api_key=os.getenv("ELEVEN_LABS_API_KEY"), + model="eleven_multilingual_v2", + voice="Matilda", +) + +tool = TextToSpeechClient( + engine=TextToSpeechEngine( + text_to_speech_driver=driver, + ), +) + +Agent(tools=[tool]).run("Generate audio from this text: 'Hello, world!'") +``` \ No newline at end of file diff --git a/griptape/config/structure_global_drivers_config.py b/griptape/config/structure_global_drivers_config.py index c78c82fe6..29985c3e5 100644 --- a/griptape/config/structure_global_drivers_config.py +++ b/griptape/config/structure_global_drivers_config.py @@ -14,9 +14,9 @@ DummyPromptDriver, DummyImageQueryDriver, BaseImageQueryDriver, - BaseAudioGenerationDriver, + BaseTextToSpeechDriver, ) -from griptape.drivers.audio_generation.dummy_audio_generation_driver import DummyAudioGenerationDriver +from griptape.drivers.text_to_speech.dummy_text_to_speech_driver import DummyTextToSpeechDriver from griptape.mixins.serializable_mixin import SerializableMixin @@ -40,6 +40,6 @@ class StructureGlobalDriversConfig(SerializableMixin): conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field( default=None, kw_only=True, metadata={"serializable": True} ) - audio_generation_driver: BaseAudioGenerationDriver = field( - default=Factory(lambda: DummyAudioGenerationDriver()), kw_only=True, metadata={"serializable": True} + audio_generation_driver: BaseTextToSpeechDriver = field( + default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True} ) diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 013cb487f..462da47c4 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -97,9 +97,9 @@ from .file_manager.local_file_manager_driver import LocalFileManagerDriver from .file_manager.amazon_s3_file_manager_driver import AmazonS3FileManagerDriver -from .audio_generation.base_audio_generation_driver import BaseAudioGenerationDriver -from .audio_generation.dummy_audio_generation_driver import DummyAudioGenerationDriver -from .audio_generation.elevenlabs_audio_generation_driver import ElevenLabsAudioGenerationDriver +from .text_to_speech.base_text_to_speech_driver import BaseTextToSpeechDriver +from .text_to_speech.dummy_text_to_speech_driver import DummyTextToSpeechDriver +from .text_to_speech.elevenlabs_text_to_speech_driver import ElevenLabsTextToSpeechDriver __all__ = [ "BasePromptDriver", @@ -185,7 +185,7 @@ "BaseFileManagerDriver", "LocalFileManagerDriver", "AmazonS3FileManagerDriver", - "BaseAudioGenerationDriver", - "DummyAudioGenerationDriver", - "ElevenLabsAudioGenerationDriver", + "BaseTextToSpeechDriver", + "DummyTextToSpeechDriver", + "ElevenLabsTextToSpeechDriver", ] diff --git a/griptape/drivers/audio_generation/__init__.py b/griptape/drivers/text_to_speech/__init__.py similarity index 100% rename from griptape/drivers/audio_generation/__init__.py rename to griptape/drivers/text_to_speech/__init__.py diff --git a/griptape/drivers/audio_generation/base_audio_generation_driver.py b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py similarity index 95% rename from griptape/drivers/audio_generation/base_audio_generation_driver.py rename to griptape/drivers/text_to_speech/base_text_to_speech_driver.py index ba7489c9e..a235e80ff 100644 --- a/griptape/drivers/audio_generation/base_audio_generation_driver.py +++ b/griptape/drivers/text_to_speech/base_text_to_speech_driver.py @@ -15,7 +15,7 @@ @define -class BaseAudioGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC): +class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) structure: Optional[Structure] = field(default=None, kw_only=True) diff --git a/griptape/drivers/audio_generation/dummy_audio_generation_driver.py b/griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py similarity index 77% rename from griptape/drivers/audio_generation/dummy_audio_generation_driver.py rename to griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py index 2a89220f2..701948acf 100644 --- a/griptape/drivers/audio_generation/dummy_audio_generation_driver.py +++ b/griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py @@ -1,12 +1,12 @@ from typing import Optional from attrs import define, field from griptape.artifacts.audio_artifact import AudioArtifact -from griptape.drivers import BaseAudioGenerationDriver +from griptape.drivers import BaseTextToSpeechDriver from griptape.exceptions import DummyException @define -class DummyAudioGenerationDriver(BaseAudioGenerationDriver): +class DummyTextToSpeechDriver(BaseTextToSpeechDriver): model: str = field(init=False) def try_text_to_audio(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> AudioArtifact: diff --git a/griptape/drivers/audio_generation/elevenlabs_audio_generation_driver.py b/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py similarity index 69% rename from griptape/drivers/audio_generation/elevenlabs_audio_generation_driver.py rename to griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py index 2d808e96e..5a6573f8e 100644 --- a/griptape/drivers/audio_generation/elevenlabs_audio_generation_driver.py +++ b/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py @@ -5,12 +5,14 @@ from attr import define, field, Factory from griptape.artifacts.audio_artifact import AudioArtifact -from griptape.drivers.audio_generation.base_audio_generation_driver import BaseAudioGenerationDriver -from elevenlabs.client import ElevenLabs +from griptape.drivers import BaseTextToSpeechDriver + +if TYPE_CHECKING: + from elevenlabs.client import ElevenLabs @define -class ElevenLabsAudioGenerationDriver(BaseAudioGenerationDriver): +class ElevenLabsTextToSpeechDriver(BaseTextToSpeechDriver): api_key: str = field(kw_only=True, metadata={"serializable": True}) client: Any = field( default=Factory(lambda self: ElevenLabs(api_key=self.api_key), takes_self=True), @@ -29,4 +31,8 @@ def try_text_to_audio(self, prompts: list[str], negative_prompts: Optional[list[ for chunk in audio: content += chunk - return AudioArtifact(value=content, format="mpeg") + # All ElevenLabs audio format strings have the following structure: + # {format}_{sample_rate}_{bitrate} + artifact_format = self.output_format.split("_")[0] + + return AudioArtifact(value=content, format=artifact_format) diff --git a/griptape/engines/__init__.py b/griptape/engines/__init__.py index 35262110e..f76e373ef 100644 --- a/griptape/engines/__init__.py +++ b/griptape/engines/__init__.py @@ -11,8 +11,7 @@ from .image.inpainting_image_generation_engine import InpaintingImageGenerationEngine from .image.outpainting_image_generation_engine import OutpaintingImageGenerationEngine from .image_query.image_query_engine import ImageQueryEngine -from .audio.base_audio_generation_engine import BaseAudioGenerationEngine -from .audio.text_to_audio_generation_engine import TextToAudioGenerationEngine +from .audio.text_to_speech_engine import TextToSpeechEngine __all__ = [ "BaseQueryEngine", @@ -28,6 +27,5 @@ "InpaintingImageGenerationEngine", "OutpaintingImageGenerationEngine", "ImageQueryEngine", - "BaseAudioGenerationEngine", - "TextToAudioGenerationEngine", + "TextToSpeechEngine", ] diff --git a/griptape/engines/audio/base_audio_generation_engine.py b/griptape/engines/audio/base_audio_generation_engine.py deleted file mode 100644 index 82f99eef3..000000000 --- a/griptape/engines/audio/base_audio_generation_engine.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations -from abc import ABC, abstractmethod - -from attr import field, define -from typing import Optional - -from griptape.artifacts import MediaArtifact -from griptape.artifacts.audio_artifact import AudioArtifact -from griptape.drivers import BaseImageGenerationDriver, BaseAudioGenerationDriver -from griptape.rules import Ruleset - - -@define -class BaseAudioGenerationEngine(ABC): - audio_generation_driver: BaseAudioGenerationDriver = field(kw_only=True) - - @abstractmethod - def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: - ... diff --git a/griptape/engines/audio/text_to_audio_generation_engine.py b/griptape/engines/audio/text_to_audio_generation_engine.py deleted file mode 100644 index c2b238387..000000000 --- a/griptape/engines/audio/text_to_audio_generation_engine.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - -from attr import define - -from griptape.artifacts.audio_artifact import AudioArtifact -from griptape.engines.audio.base_audio_generation_engine import BaseAudioGenerationEngine - - -@define -class TextToAudioGenerationEngine(BaseAudioGenerationEngine): - def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: - return self.audio_generation_driver.try_text_to_audio(prompts=prompts) diff --git a/griptape/engines/audio/text_to_speech_engine.py b/griptape/engines/audio/text_to_speech_engine.py new file mode 100644 index 000000000..29118848e --- /dev/null +++ b/griptape/engines/audio/text_to_speech_engine.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from attr import define, field + +from griptape.artifacts.audio_artifact import AudioArtifact +from griptape.drivers import BaseTextToSpeechDriver + + +@define +class TextToSpeechEngine: + text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True) + + def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact: + return self.text_to_speech_driver.try_text_to_audio(prompts=prompts) diff --git a/griptape/mixins/__init__.py b/griptape/mixins/__init__.py index 78bee7440..d9eea53c2 100644 --- a/griptape/mixins/__init__.py +++ b/griptape/mixins/__init__.py @@ -3,13 +3,13 @@ from .actions_subtask_origin_mixin import ActionsSubtaskOriginMixin from .rule_mixin import RuleMixin from .serializable_mixin import SerializableMixin -from .media_artifact_file_output_mixin import MediaArtifactFileOutputMixin +from .media_artifact_file_output_mixin import BlobArtifactFileOutputMixin __all__ = [ "ActivityMixin", "ExponentialBackoffMixin", "ActionsSubtaskOriginMixin", "RuleMixin", - "MediaArtifactFileOutputMixin", + "BlobArtifactFileOutputMixin", "SerializableMixin", ] diff --git a/griptape/mixins/media_artifact_file_output_mixin.py b/griptape/mixins/media_artifact_file_output_mixin.py index a718a9b6d..d7d6f584c 100644 --- a/griptape/mixins/media_artifact_file_output_mixin.py +++ b/griptape/mixins/media_artifact_file_output_mixin.py @@ -7,11 +7,11 @@ from typing import Optional if TYPE_CHECKING: - from griptape.artifacts import MediaArtifact + from griptape.artifacts import BlobArtifact @define(slots=False) -class MediaArtifactFileOutputMixin: +class BlobArtifactFileOutputMixin: output_dir: Optional[str] = field(default=None, kw_only=True) output_file: Optional[str] = field(default=None, kw_only=True) @@ -31,7 +31,7 @@ def validate_output_file(self, _, output_file: str) -> None: if self.output_dir: raise ValueError("Can't have both output_dir and output_file specified.") - def _write_to_file(self, artifact: MediaArtifact) -> None: + def _write_to_file(self, artifact: BlobArtifact) -> None: if self.output_file: outfile = self.output_file elif self.output_dir: diff --git a/griptape/tasks/__init__.py b/griptape/tasks/__init__.py index ced6c74b7..110d7dbe2 100644 --- a/griptape/tasks/__init__.py +++ b/griptape/tasks/__init__.py @@ -17,7 +17,7 @@ from .variation_image_generation_task import VariationImageGenerationTask from .image_query_task import ImageQueryTask from .base_audio_generation_task import BaseAudioGenerationTask -from .audio_generation_task import AudioGenerationTask +from .text_to_speech_task import TextToSpeechTask __all__ = [ "BaseTask", @@ -39,5 +39,5 @@ "OutpaintingImageGenerationTask", "ImageQueryTask", "BaseAudioGenerationTask", - "AudioGenerationTask", + "TextToSpeechTask", ] diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index 656253ed1..f9d691768 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -7,11 +7,11 @@ from griptape.artifacts import MediaArtifact from griptape.loaders import ImageLoader -from griptape.mixins import RuleMixin, MediaArtifactFileOutputMixin +from griptape.mixins import RuleMixin, BlobArtifactFileOutputMixin from griptape.rules import Ruleset, Rule from griptape.tasks import BaseTask @define -class BaseAudioGenerationTask(MediaArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): +class BaseAudioGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): ... diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py index 76882e0f7..2dbab4ce9 100644 --- a/griptape/tasks/base_image_generation_task.py +++ b/griptape/tasks/base_image_generation_task.py @@ -7,13 +7,13 @@ from griptape.artifacts import MediaArtifact from griptape.loaders import ImageLoader -from griptape.mixins import RuleMixin, MediaArtifactFileOutputMixin +from griptape.mixins import RuleMixin, BlobArtifactFileOutputMixin from griptape.rules import Ruleset, Rule from griptape.tasks import BaseTask @define -class BaseImageGenerationTask(MediaArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): +class BaseImageGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): """Provides a base class for image generation-related tasks. Attributes: diff --git a/griptape/tasks/audio_generation_task.py b/griptape/tasks/text_to_speech_task.py similarity index 58% rename from griptape/tasks/audio_generation_task.py rename to griptape/tasks/text_to_speech_task.py index 639b99963..777e26917 100644 --- a/griptape/tasks/audio_generation_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -5,7 +5,7 @@ from attr import define, field from griptape.artifacts.audio_artifact import AudioArtifact -from griptape.engines import TextToAudioGenerationEngine +from griptape.engines import TextToSpeechEngine from griptape.artifacts import TextArtifact from griptape.tasks import BaseTask from griptape.tasks.base_audio_generation_task import BaseAudioGenerationTask @@ -13,13 +13,11 @@ @define -class AudioGenerationTask(BaseAudioGenerationTask): +class TextToSpeechTask(BaseAudioGenerationTask): DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field(default=DEFAULT_INPUT_TEMPLATE) - _audio_generation_engine: TextToAudioGenerationEngine = field( - default=None, kw_only=True, alias="audio_generation_engine" - ) + _text_to_speech_engine: TextToSpeechEngine = field(default=None, kw_only=True, alias="text_to_speech_engine") @property def input(self) -> TextArtifact: @@ -35,22 +33,22 @@ def input(self, value: TextArtifact) -> None: self._input = value @property - def audio_generation_engine(self) -> TextToAudioGenerationEngine: - if self._audio_generation_engine is None: + def text_to_speech_engine(self) -> TextToSpeechEngine: + if self._text_to_speech_engine is None: if self.structure is not None: - self._audio_generation_engine = TextToAudioGenerationEngine( - audio_generation_driver=self.structure.config.global_drivers.audio_generation_driver + self._text_to_speech_engine = TextToSpeechEngine( + text_to_speech_driver=self.structure.config.global_drivers.audio_generation_driver ) else: raise ValueError("Audio Generation Engine is not set.") - return self._audio_generation_engine + return self._text_to_speech_engine - @audio_generation_engine.setter - def audio_generation_engine(self, value: TextToAudioGenerationEngine) -> None: - self._audio_generation_engine = value + @text_to_speech_engine.setter + def text_to_speech_engine(self, value: TextToSpeechEngine) -> None: + self._text_to_speech_engine = value def run(self) -> AudioArtifact: - audio_artifact = self.audio_generation_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets) + audio_artifact = self.text_to_speech_engine.run(prompts=[self.input.to_text()], rulesets=self.all_rulesets) if self.output_dir or self.output_file: self._write_to_file(audio_artifact) diff --git a/griptape/tools/inpainting_image_generation_client/tool.py b/griptape/tools/inpainting_image_generation_client/tool.py index 23a162b9a..02799ad29 100644 --- a/griptape/tools/inpainting_image_generation_client/tool.py +++ b/griptape/tools/inpainting_image_generation_client/tool.py @@ -8,14 +8,14 @@ from griptape.artifacts import ErrorArtifact, ImageArtifact from griptape.engines import InpaintingImageGenerationEngine from griptape.loaders import ImageLoader -from griptape.mixins import MediaArtifactFileOutputMixin +from griptape.mixins import BlobArtifactFileOutputMixin from griptape.tools import BaseTool from griptape.utils.decorators import activity from griptape.utils.load_artifact_from_memory import load_artifact_from_memory @define -class InpaintingImageGenerationClient(MediaArtifactFileOutputMixin, BaseTool): +class InpaintingImageGenerationClient(BlobArtifactFileOutputMixin, BaseTool): """A tool that can be used to generate prompted inpaintings of an image. Attributes: diff --git a/griptape/tools/outpainting_image_generation_client/tool.py b/griptape/tools/outpainting_image_generation_client/tool.py index a848ea6a3..bd9a2125c 100644 --- a/griptape/tools/outpainting_image_generation_client/tool.py +++ b/griptape/tools/outpainting_image_generation_client/tool.py @@ -10,12 +10,12 @@ from griptape.loaders import ImageLoader from griptape.tools import BaseTool from griptape.utils.decorators import activity -from griptape.mixins import MediaArtifactFileOutputMixin +from griptape.mixins import BlobArtifactFileOutputMixin from griptape.utils.load_artifact_from_memory import load_artifact_from_memory @define -class OutpaintingImageGenerationClient(MediaArtifactFileOutputMixin, BaseTool): +class OutpaintingImageGenerationClient(BlobArtifactFileOutputMixin, BaseTool): """A tool that can be used to generate prompted outpaintings of an image. Attributes: diff --git a/griptape/tools/prompt_image_generation_client/tool.py b/griptape/tools/prompt_image_generation_client/tool.py index 6b36794f5..50020a1ea 100644 --- a/griptape/tools/prompt_image_generation_client/tool.py +++ b/griptape/tools/prompt_image_generation_client/tool.py @@ -7,11 +7,11 @@ from griptape.engines import PromptImageGenerationEngine from griptape.tools import BaseTool from griptape.utils.decorators import activity -from griptape.mixins import MediaArtifactFileOutputMixin +from griptape.mixins import BlobArtifactFileOutputMixin @define -class PromptImageGenerationClient(MediaArtifactFileOutputMixin, BaseTool): +class PromptImageGenerationClient(BlobArtifactFileOutputMixin, BaseTool): """A tool that can be used to generate an image from a text prompt. Attributes: diff --git a/griptape/tools/text_to_speech_client/tool.py b/griptape/tools/text_to_speech_client/tool.py index 923c5a5d2..9f8c36a81 100644 --- a/griptape/tools/text_to_speech_client/tool.py +++ b/griptape/tools/text_to_speech_client/tool.py @@ -6,14 +6,14 @@ from schema import Schema, Literal from griptape.artifacts import ErrorArtifact, AudioArtifact -from griptape.engines import TextToAudioGenerationEngine +from griptape.engines import TextToSpeechEngine from griptape.tools import BaseTool from griptape.utils.decorators import activity -from griptape.mixins import MediaArtifactFileOutputMixin +from griptape.mixins import BlobArtifactFileOutputMixin @define -class TextToSpeechClient(MediaArtifactFileOutputMixin, BaseTool): +class TextToSpeechClient(BlobArtifactFileOutputMixin, BaseTool): """A tool that can be used to generate speech from input text. Attributes: @@ -22,7 +22,7 @@ class TextToSpeechClient(MediaArtifactFileOutputMixin, BaseTool): output_file: If provided, the generated audio will be written to disk as output_file. """ - engine: TextToAudioGenerationEngine = field(kw_only=True) + engine: TextToSpeechEngine = field(kw_only=True) @activity( config={ diff --git a/griptape/tools/variation_image_generation_client/tool.py b/griptape/tools/variation_image_generation_client/tool.py index 2dbcd4298..4d0d0537e 100644 --- a/griptape/tools/variation_image_generation_client/tool.py +++ b/griptape/tools/variation_image_generation_client/tool.py @@ -10,12 +10,12 @@ from griptape.loaders import ImageLoader from griptape.tools import BaseTool from griptape.utils.decorators import activity -from griptape.mixins import MediaArtifactFileOutputMixin +from griptape.mixins import BlobArtifactFileOutputMixin from griptape.utils.load_artifact_from_memory import load_artifact_from_memory @define -class VariationImageGenerationClient(MediaArtifactFileOutputMixin, BaseTool): +class VariationImageGenerationClient(BlobArtifactFileOutputMixin, BaseTool): """A tool that can be used to generate prompted variations of an image. Attributes: diff --git a/griptape/utils/__init__.py b/griptape/utils/__init__.py index 77a47b0a4..1aad72db9 100644 --- a/griptape/utils/__init__.py +++ b/griptape/utils/__init__.py @@ -17,7 +17,6 @@ from .constants import Constants as constants from .load_artifact_from_memory import load_artifact_from_memory from .deprecation import deprecation_warn -from .play_audio import play_audio def minify_json(value: str) -> str: diff --git a/griptape/utils/play_audio.py b/griptape/utils/play_audio.py deleted file mode 100644 index 73afc58de..000000000 --- a/griptape/utils/play_audio.py +++ /dev/null @@ -1,9 +0,0 @@ -import elevenlabs - -from griptape.artifacts import AudioArtifact - - -def play_audio(artifact: AudioArtifact) -> AudioArtifact: - elevenlabs.play(artifact.value) - - return artifact diff --git a/pyproject.toml b/pyproject.toml index 1c1c0bc2e..caea403f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,13 +51,13 @@ playwright = {version = "^1.42", optional = true} beautifulsoup4 = {version = "^4.12.3", optional = true} markdownify = {version = "^0.11.6", optional = true} voyageai = {version = "^0.2.1", optional = true} +elevenlabs = {version = "^1.1.2", optional = true} # loaders pandas = {version = "^1.3", optional = true} pypdf = {version = "^3.9", optional = true} pillow = {version = "^10.2.0", optional = true} mail-parser = {version = "^3.15.0", optional = true} -elevenlabs = "^1.1.2" [tool.poetry.extras] drivers-prompt-cohere = ["cohere"] @@ -121,6 +121,7 @@ all = [ "beautifulsoup4", "markdownify", "voyageai", + "elevenlabs", # loaders "pandas", diff --git a/tests/unit/drivers/text_to_speech/__init__.py b/tests/unit/drivers/text_to_speech/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py b/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py new file mode 100644 index 000000000..9ae45a06b --- /dev/null +++ b/tests/unit/drivers/text_to_speech/test_elevenlabs_audio_generation_driver.py @@ -0,0 +1,20 @@ +import pytest +from unittest.mock import Mock +from griptape.drivers import ElevenLabsTextToSpeechDriver + + +class TestElevenLabsTextToSpeechDriver: + @pytest.fixture + def driver(self): + return ElevenLabsTextToSpeechDriver(model="model", client=Mock(), voice="voice", api_key="key") + + def test_init(self, driver): + assert driver + + def test_try_text_to_audio(self, driver): + driver.client.generate.return_value = [b"audio data"] + + audio_artifact = driver.try_text_to_audio(prompts=["test prompt"]) + + assert audio_artifact.value == b"audio data" + assert audio_artifact.format == "mpeg" diff --git a/tests/unit/mixins/test_image_artifact_file_output_mixin.py b/tests/unit/mixins/test_image_artifact_file_output_mixin.py index 20e644191..69a2f1d71 100644 --- a/tests/unit/mixins/test_image_artifact_file_output_mixin.py +++ b/tests/unit/mixins/test_image_artifact_file_output_mixin.py @@ -4,12 +4,12 @@ import pytest from griptape.artifacts import ImageArtifact -from griptape.mixins import MediaArtifactFileOutputMixin +from griptape.mixins import BlobArtifactFileOutputMixin class TestMediaArtifactFileOutputMixin: def test_no_output(self): - class Test(MediaArtifactFileOutputMixin): + class Test(BlobArtifactFileOutputMixin): pass assert Test().output_file is None @@ -18,7 +18,7 @@ class Test(MediaArtifactFileOutputMixin): def test_output_file(self): artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png") - class Test(MediaArtifactFileOutputMixin): + class Test(BlobArtifactFileOutputMixin): def run(self): self._write_to_file(artifact) @@ -33,7 +33,7 @@ def run(self): def test_output_dir(self): artifact = ImageArtifact(name="test.png", value=b"test", height=1, width=1, format="png") - class Test(MediaArtifactFileOutputMixin): + class Test(BlobArtifactFileOutputMixin): def run(self): self._write_to_file(artifact) @@ -46,7 +46,7 @@ def run(self): assert os.path.exists(os.path.join(outdir, artifact.name)) def test_output_file_and_dir(self): - class Test(MediaArtifactFileOutputMixin): + class Test(BlobArtifactFileOutputMixin): pass outfile = "test.txt" diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py new file mode 100644 index 000000000..f30893e8a --- /dev/null +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -0,0 +1,30 @@ +from unittest.mock import Mock + +from griptape.artifacts import TextArtifact +from griptape.engines import TextToSpeechEngine +from griptape.structures import Agent +from griptape.tasks import BaseTask, TextToSpeechTask +from tests.mocks.mock_structure_config import MockStructureConfig + + +class TestTextToSpeechTask: + def test_string_input(self): + task = TextToSpeechTask("string input", text_to_speech_engine=Mock()) + + assert task.input.value == "string input" + + def test_callable_input(self): + input_artifact = TextArtifact("some text input") + + def callable(task: BaseTask) -> TextArtifact: + return input_artifact + + task = TextToSpeechTask(callable, text_to_speech_engine=Mock()) + + assert task.input == input_artifact + + def test_config_text_to_speech_engine(self): + task = TextToSpeechTask("foo bar") + Agent(config=MockStructureConfig()).add_task(task) + + assert isinstance(task.text_to_speech_engine, TextToSpeechEngine) diff --git a/tests/unit/tools/test_text_to_speech_client.py b/tests/unit/tools/test_text_to_speech_client.py new file mode 100644 index 000000000..881b1234d --- /dev/null +++ b/tests/unit/tools/test_text_to_speech_client.py @@ -0,0 +1,40 @@ +import os +import tempfile +import uuid +from unittest.mock import Mock + +import pytest + +from griptape.tools.text_to_speech_client.tool import TextToSpeechClient + + +class TestTextToSpeechClient: + @pytest.fixture + def text_to_speech_engine(self) -> Mock: + return Mock() + + @pytest.fixture + def text_to_speech_client(self, text_to_speech_engine) -> TextToSpeechClient: + return TextToSpeechClient(engine=text_to_speech_engine) + + def test_validate_output_configs(self, text_to_speech_engine) -> None: + with pytest.raises(ValueError): + TextToSpeechClient(engine=text_to_speech_engine, output_dir="test", output_file="test") + + def test_text_to_speech(self, text_to_speech_client) -> None: + text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") + + audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) + + assert audio_artifact + + def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None: + outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3" + text_to_speech_client = TextToSpeechClient(engine=text_to_speech_engine, output_file=outfile) + + text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") # pyright: ignore + + audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) + + assert audio_artifact + assert os.path.exists(outfile)