Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 25, 2024
1 parent 085a05d commit d72af55
Show file tree
Hide file tree
Showing 15 changed files with 94 additions and 89 deletions.
2 changes: 1 addition & 1 deletion docs/griptape-framework/drivers/text-to-speech-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ search:

## Overview

[Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md) are used by [Text To Speech Engines](../engines/audio-engines.md) to build and execute API calls to audio generation models.
[Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md) are used to build and execute API calls to audio generation models.

Provide a Driver when building an [Engine](../engines/audio-engines.md), then pass it to a [Tool](../tools/index.md) for use by an [Agent](../structures/agents.md):

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ This Task is useful for orchestrating multiple specialized Structures in a singl

## Text to Speech Task

This Task enables Structures to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md).
This Task enables Structures to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_driver.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md).

```python
--8<-- "docs/griptape-framework/structures/src/tasks_17.py"
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-tools/official-tools/text-to-speech-tool.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Text To Speech Tool

This Tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md).
This Tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_driver.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md).

```python
--8<-- "docs/griptape-tools/official-tools/src/text_to_speech_tool_1.py"
Expand Down
2 changes: 1 addition & 1 deletion griptape/tasks/variation_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class VariationImageGenerationTask(BaseImageGenerationTask):
"""

image_generation_driver: BaseImageGenerationDriver = field(
default=Factory(lambda: Defaults.drivers_config.base_image_generation_driver),
default=Factory(lambda: Defaults.drivers_config.image_generation_driver),
kw_only=True,
)
_input: Union[tuple[Union[str, TextArtifact], ImageArtifact], Callable[[BaseTask], ListArtifact], ListArtifact] = (
Expand Down
2 changes: 1 addition & 1 deletion griptape/tools/variation_image_generation/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def image_variation_from_memory(self, params: dict[str, dict[str, str]]) -> Imag
def _generate_variation(
self, prompt: str, negative_prompt: str, image_artifact: ImageArtifact
) -> ImageArtifact | ErrorArtifact:
output_artifact = self.image_generation_driver.try_image_variation(
output_artifact = self.image_generation_driver.run_image_variation(
prompts=[prompt], negative_prompts=[negative_prompt], image=image_artifact
)

Expand Down
22 changes: 11 additions & 11 deletions tests/unit/tasks/test_audio_transcription_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.engines import AudioTranscriptionEngine
from griptape.drivers import BaseAudioTranscriptionDriver
from griptape.structures import Agent, Pipeline
from griptape.tasks import AudioTranscriptionTask, BaseTask

Expand All @@ -14,32 +14,32 @@ def audio_artifact(self):
return AudioArtifact(value=b"audio data", format="mp3")

@pytest.fixture()
def audio_transcription_engine(self):
def audio_transcription_driver(self):
return Mock()

def test_audio_input(self, audio_artifact, audio_transcription_engine):
task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine)
def test_audio_input(self, audio_artifact, audio_transcription_driver):
task = AudioTranscriptionTask(audio_artifact, audio_transcription_driver=audio_transcription_driver)

assert task.input.value == audio_artifact.value

def test_callable_input(self, audio_artifact, audio_transcription_engine):
def test_callable_input(self, audio_artifact, audio_transcription_driver):
def callable_input(task: BaseTask) -> AudioArtifact:
return audio_artifact

task = AudioTranscriptionTask(callable_input, audio_transcription_engine=audio_transcription_engine)
task = AudioTranscriptionTask(callable_input, audio_transcription_driver=audio_transcription_driver)

assert task.input == audio_artifact

def test_config_audio_transcription_engine(self, audio_artifact):
def test_config_audio_transcription_driver(self, audio_artifact):
task = AudioTranscriptionTask(audio_artifact)
Agent().add_task(task)

assert isinstance(task.audio_transcription_engine, AudioTranscriptionEngine)
assert isinstance(task.audio_transcription_driver, BaseAudioTranscriptionDriver)

def test_run(self, audio_artifact, audio_transcription_engine):
audio_transcription_engine.run.return_value = TextArtifact("mock transcription")
def test_run(self, audio_artifact, audio_transcription_driver):
audio_transcription_driver.run.return_value = TextArtifact("mock transcription")

task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine)
task = AudioTranscriptionTask(audio_artifact, audio_transcription_driver=audio_transcription_driver)
pipeline = Pipeline()
pipeline.add_task(task)

Expand Down
18 changes: 1 addition & 17 deletions tests/unit/tasks/test_base_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,13 @@

class TestBaseImageGenerationTask:
def test_validate_negative_rulesets(self) -> None:
with pytest.raises(ValueError):
MockImageGenerationTask(
TextArtifact("some input"),
negative_rulesets=[Ruleset(name="Negative Ruleset", rules=[Rule(value="Negative Rule")])],
negative_rules=[Rule(value="Negative Rule")],
output_dir="some/dir",
)

assert MockImageGenerationTask(
TextArtifact("some input"),
negative_rulesets=[Ruleset(name="Negative Ruleset", rules=[Rule(value="Negative Rule")])],
output_dir="some/dir",
)

def test_validate_negative_rules(self) -> None:
with pytest.raises(ValueError):
MockImageGenerationTask(
TextArtifact("some input"),
negative_rulesets=[Ruleset(name="Negative Ruleset", rules=[Rule(value="Negative Rule")])],
negative_rules=[Rule(value="Negative Rule")],
output_dir="some/dir",
)

assert MockImageGenerationTask(
TextArtifact("some input"), negative_rules=[Rule(value="Negative Rule")], output_dir="some/dir"
)
Expand All @@ -46,7 +30,7 @@ def test_negative_rulesets_from_rules(self) -> None:

task = MockImageGenerationTask(TextArtifact("some input"), negative_rules=[rule], output_dir="some/dir")

assert task.negative_rulesets[0].name == task.NEGATIVE_RULESET_NAME
assert task.negative_rulesets[0].name == task.DEFAULT_NEGATIVE_RULESET_NAME
assert task.negative_rulesets[0].rules[0] == rule

def test_validate_output_dir(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tasks/test_image_query_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TestImageQueryTask:
@pytest.fixture()
def image_query_driver(self) -> Mock:
mock = Mock()
mock.run.return_value = TextArtifact("image")
mock.query.return_value = TextArtifact("image")

return mock

Expand Down
22 changes: 11 additions & 11 deletions tests/unit/tasks/test_text_to_speech_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from unittest.mock import Mock

from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.engines import TextToSpeechEngine
from griptape.drivers.text_to_speech.base_text_to_speech_driver import BaseTextToSpeechDriver
from griptape.structures import Agent, Pipeline
from griptape.tasks import BaseTask, TextToSpeechTask


class TestTextToSpeechTask:
def test_string_input(self):
task = TextToSpeechTask("string input", text_to_speech_engine=Mock())
task = TextToSpeechTask("string input", text_to_speech_driver=Mock())

assert task.input.value == "string input"

Expand All @@ -18,27 +18,27 @@ def test_callable_input(self):
def callable_input(task: BaseTask) -> TextArtifact:
return input_artifact

task = TextToSpeechTask(callable_input, text_to_speech_engine=Mock())
task = TextToSpeechTask(callable_input, text_to_speech_driver=Mock())

assert task.input == input_artifact

def test_config_text_to_speech_engine(self):
def test_config_text_to_speech_driver(self):
task = TextToSpeechTask("foo bar")
Agent().add_task(task)

assert isinstance(task.text_to_speech_engine, TextToSpeechEngine)
assert isinstance(task.text_to_speech_driver, BaseTextToSpeechDriver)

def test_calls(self):
text_to_speech_engine = Mock()
text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3")
text_to_speech_driver = Mock()
text_to_speech_driver.run_text_to_audio.return_value = AudioArtifact(b"audio content", format="mp3")

assert TextToSpeechTask("test", text_to_speech_engine=text_to_speech_engine).run().value == b"audio content"
assert TextToSpeechTask("test", text_to_speech_driver=text_to_speech_driver).run().value == b"audio content"

def test_run(self):
text_to_speech_engine = Mock()
text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3")
text_to_speech_driver = Mock()
text_to_speech_driver.run_text_to_audio.return_value = AudioArtifact(b"audio content", format="mp3")

task = TextToSpeechTask("some text", text_to_speech_engine=text_to_speech_engine)
task = TextToSpeechTask("some text", text_to_speech_driver=text_to_speech_driver)
pipeline = Pipeline()
pipeline.add_task(task)

Expand Down
16 changes: 9 additions & 7 deletions tests/unit/tools/test_inpainting_image_generation_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ def image_loader(self) -> Mock:

@pytest.fixture()
def image_generator(self, image_generation_driver, image_loader) -> InpaintingImageGenerationTool:
return InpaintingImageGenerationTool(engine=image_generation_driver, image_loader=image_loader)
return InpaintingImageGenerationTool(image_generation_driver=image_generation_driver, image_loader=image_loader)

def test_validate_output_configs(self, image_generation_driver) -> None:
with pytest.raises(ValueError):
InpaintingImageGenerationTool(engine=image_generation_driver, output_dir="test", output_file="test")
InpaintingImageGenerationTool(
image_generation_driver=image_generation_driver, output_dir="test", output_file="test"
)

def test_image_inpainting(self, image_generator, path_from_resource_path) -> None:
image_generator.engine.run.return_value = Mock(
image_generator.image_generation_driver.run_image_inpainting.return_value = Mock(
value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt"
)

Expand All @@ -56,10 +58,10 @@ def test_image_inpainting_with_outfile(
) -> None:
outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png"
image_generator = InpaintingImageGenerationTool(
engine=image_generation_driver, output_file=outfile, image_loader=image_loader
image_generation_driver=image_generation_driver, output_file=outfile, image_loader=image_loader
)

image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess]
image_generator.image_generation_driver.run_image_inpainting.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess]
value=b"image data", format="png", width=512, height=512
)

Expand All @@ -78,12 +80,12 @@ def test_image_inpainting_with_outfile(
assert os.path.exists(outfile)

def test_image_inpainting_from_memory(self, image_generation_driver, image_artifact):
image_generator = InpaintingImageGenerationTool(engine=image_generation_driver)
image_generator = InpaintingImageGenerationTool(image_generation_driver=image_generation_driver)
memory = Mock()
memory.load_artifacts = Mock(return_value=[image_artifact])
image_generator.find_input_memory = Mock(return_value=memory)

image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess]
image_generator.image_generation_driver.run_image_inpainting.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess]
value=b"image data", format="png", width=512, height=512
)

Expand Down
18 changes: 11 additions & 7 deletions tests/unit/tools/test_outpainting_image_variation_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,18 @@ def image_loader(self, image_artifact) -> Mock:

@pytest.fixture()
def image_generator(self, image_generation_driver, image_loader) -> OutpaintingImageGenerationTool:
return OutpaintingImageGenerationTool(engine=image_generation_driver, image_loader=image_loader)
return OutpaintingImageGenerationTool(
image_generation_driver=image_generation_driver, image_loader=image_loader
)

def test_validate_output_configs(self, image_generation_driver) -> None:
with pytest.raises(ValueError):
OutpaintingImageGenerationTool(engine=image_generation_driver, output_dir="test", output_file="test")
OutpaintingImageGenerationTool(
image_generation_driver=image_generation_driver, output_dir="test", output_file="test"
)

def test_image_outpainting(self, image_generator, path_from_resource_path) -> None:
image_generator.engine.run.return_value = ImageArtifact(
image_generator.image_generation_driver.run_image_variation.return_value = ImageArtifact(
value=b"image data", format="png", width=512, height=512
)

Expand All @@ -56,10 +60,10 @@ def test_image_outpainting_with_outfile(
) -> None:
outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png"
image_generator = OutpaintingImageGenerationTool(
engine=image_generation_driver, output_file=outfile, image_loader=image_loader
image_generation_driver=image_generation_driver, output_file=outfile, image_loader=image_loader
)

image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess]
image_generator.image_generation_driver.run_image_outpainting.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess]
value=b"image data", format="png", width=512, height=512
)

Expand All @@ -78,12 +82,12 @@ def test_image_outpainting_with_outfile(
assert os.path.exists(outfile)

def test_image_outpainting_from_memory(self, image_generation_driver, image_artifact):
image_generator = OutpaintingImageGenerationTool(engine=image_generation_driver)
image_generator = OutpaintingImageGenerationTool(image_generation_driver=image_generation_driver)
memory = Mock()
memory.load_artifacts = Mock(return_value=[image_artifact])
image_generator.find_input_memory = Mock(return_value=memory)

image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess]
image_generator.image_generation_driver.run_image_variation.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess]
value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt"
)

Expand Down
14 changes: 9 additions & 5 deletions tests/unit/tools/test_prompt_image_generation_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ def image_generation_driver(self) -> Mock:

@pytest.fixture()
def image_generator(self, image_generation_driver) -> PromptImageGenerationTool:
return PromptImageGenerationTool(engine=image_generation_driver)
return PromptImageGenerationTool(image_generation_driver=image_generation_driver)

def test_validate_output_configs(self, image_generation_driver) -> None:
with pytest.raises(ValueError):
PromptImageGenerationTool(engine=image_generation_driver, output_dir="test", output_file="test")
PromptImageGenerationTool(
image_generation_driver=image_generation_driver, output_dir="test", output_file="test"
)

def test_generate_image(self, image_generator) -> None:
image_generator.engine.run.return_value = Mock(
image_generator.image_generation_driver.run_text_to_image.return_value = Mock(
value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt"
)

Expand All @@ -35,9 +37,11 @@ def test_generate_image(self, image_generator) -> None:

def test_generate_image_with_outfile(self, image_generation_driver) -> None:
outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png"
image_generator = PromptImageGenerationTool(engine=image_generation_driver, output_file=outfile)
image_generator = PromptImageGenerationTool(
image_generation_driver=image_generation_driver, output_file=outfile
)

image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess]
image_generator.image_generation_driver.run_text_to_image.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess]
value=b"image data", format="png", width=512, height=512
)

Expand Down
22 changes: 13 additions & 9 deletions tests/unit/tools/test_text_to_speech_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,33 @@

class TestTextToSpeechTool:
@pytest.fixture()
def text_to_speech_engine(self) -> Mock:
def text_to_speech_driver(self) -> Mock:
return Mock()

@pytest.fixture()
def text_to_speech_client(self, text_to_speech_engine) -> TextToSpeechTool:
return TextToSpeechTool(engine=text_to_speech_engine)
def text_to_speech_client(self, text_to_speech_driver) -> TextToSpeechTool:
return TextToSpeechTool(text_to_speech_driver=text_to_speech_driver)

def test_validate_output_configs(self, text_to_speech_engine) -> None:
def test_validate_output_configs(self, text_to_speech_driver) -> None:
with pytest.raises(ValueError):
TextToSpeechTool(engine=text_to_speech_engine, output_dir="test", output_file="test")
TextToSpeechTool(text_to_speech_driver=text_to_speech_driver, output_dir="test", output_file="test")

def test_text_to_speech(self, text_to_speech_client) -> None:
text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3")
text_to_speech_client.text_to_speech_driver.run_text_to_audio.return_value = Mock(
value=b"audio data", format="mp3"
)

audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}})

assert audio_artifact

def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None:
def test_text_to_speech_with_outfile(self, text_to_speech_driver) -> None:
outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3"
text_to_speech_client = TextToSpeechTool(engine=text_to_speech_engine, output_file=outfile)
text_to_speech_client = TextToSpeechTool(text_to_speech_driver=text_to_speech_driver, output_file=outfile)

text_to_speech_client.engine.run.return_value = AudioArtifact(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess]
text_to_speech_client.text_to_speech_driver.run_text_to_audio.return_value = AudioArtifact( # pyright: ignore[reportFunctionMemberAccess]
value=b"audio data", format="mp3"
)

audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}})

Expand Down
Loading

0 comments on commit d72af55

Please sign in to comment.