From d72af554c1b918cfa3a56d6667e96f47bd73480c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 24 Oct 2024 16:07:03 -0700 Subject: [PATCH] Fix tests --- .../drivers/text-to-speech-drivers.md | 2 +- docs/griptape-framework/structures/tasks.md | 2 +- .../official-tools/text-to-speech-tool.md | 2 +- .../tasks/variation_image_generation_task.py | 2 +- .../tools/variation_image_generation/tool.py | 2 +- .../tasks/test_audio_transcription_task.py | 22 +++++++++---------- .../tasks/test_base_image_generation_task.py | 18 +-------------- tests/unit/tasks/test_image_query_task.py | 2 +- tests/unit/tasks/test_text_to_speech_task.py | 22 +++++++++---------- .../test_inpainting_image_generation_tool.py | 16 ++++++++------ .../test_outpainting_image_variation_tool.py | 18 +++++++++------ .../test_prompt_image_generation_tool.py | 14 +++++++----- tests/unit/tools/test_text_to_speech_tool.py | 22 +++++++++++-------- tests/unit/tools/test_transcription_tool.py | 22 +++++++++++-------- .../test_variation_image_generation_tool.py | 17 ++++++++------ 15 files changed, 94 insertions(+), 89 deletions(-) diff --git a/docs/griptape-framework/drivers/text-to-speech-drivers.md b/docs/griptape-framework/drivers/text-to-speech-drivers.md index 4ea1c574f2..55e3437527 100644 --- a/docs/griptape-framework/drivers/text-to-speech-drivers.md +++ b/docs/griptape-framework/drivers/text-to-speech-drivers.md @@ -5,7 +5,7 @@ search: ## Overview -[Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md) are used by [Text To Speech Engines](../engines/audio-engines.md) to build and execute API calls to audio generation models. +[Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md) are used to build and execute API calls to audio generation models. Provide a Driver when building an [Engine](../engines/audio-engines.md), then pass it to a [Tool](../tools/index.md) for use by an [Agent](../structures/agents.md): diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index 7472010205..85ed0a4324 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -361,7 +361,7 @@ This Task is useful for orchestrating multiple specialized Structures in a singl ## Text to Speech Task -This Task enables Structures to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md). +This Task enables Structures to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_driver.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md). ```python --8<-- "docs/griptape-framework/structures/src/tasks_17.py" diff --git a/docs/griptape-tools/official-tools/text-to-speech-tool.md b/docs/griptape-tools/official-tools/text-to-speech-tool.md index ac3f54f8e8..07221a61ee 100644 --- a/docs/griptape-tools/official-tools/text-to-speech-tool.md +++ b/docs/griptape-tools/official-tools/text-to-speech-tool.md @@ -1,6 +1,6 @@ # Text To Speech Tool -This Tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md). +This Tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_driver.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md). ```python --8<-- "docs/griptape-tools/official-tools/src/text_to_speech_tool_1.py" diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index 35119992b6..4a9b417857 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -32,7 +32,7 @@ class VariationImageGenerationTask(BaseImageGenerationTask): """ image_generation_driver: BaseImageGenerationDriver = field( - default=Factory(lambda: Defaults.drivers_config.base_image_generation_driver), + default=Factory(lambda: Defaults.drivers_config.image_generation_driver), kw_only=True, ) _input: Union[tuple[Union[str, TextArtifact], ImageArtifact], Callable[[BaseTask], ListArtifact], ListArtifact] = ( diff --git a/griptape/tools/variation_image_generation/tool.py b/griptape/tools/variation_image_generation/tool.py index fa3a3d5a15..8194258646 100644 --- a/griptape/tools/variation_image_generation/tool.py +++ b/griptape/tools/variation_image_generation/tool.py @@ -86,7 +86,7 @@ def image_variation_from_memory(self, params: dict[str, dict[str, str]]) -> Imag def _generate_variation( self, prompt: str, negative_prompt: str, image_artifact: ImageArtifact ) -> ImageArtifact | ErrorArtifact: - output_artifact = self.image_generation_driver.try_image_variation( + output_artifact = self.image_generation_driver.run_image_variation( prompts=[prompt], negative_prompts=[negative_prompt], image=image_artifact ) diff --git a/tests/unit/tasks/test_audio_transcription_task.py b/tests/unit/tasks/test_audio_transcription_task.py index 33405ad104..7f8f04f810 100644 --- a/tests/unit/tasks/test_audio_transcription_task.py +++ b/tests/unit/tasks/test_audio_transcription_task.py @@ -3,7 +3,7 @@ import pytest from griptape.artifacts import AudioArtifact, TextArtifact -from griptape.engines import AudioTranscriptionEngine +from griptape.drivers import BaseAudioTranscriptionDriver from griptape.structures import Agent, Pipeline from griptape.tasks import AudioTranscriptionTask, BaseTask @@ -14,32 +14,32 @@ def audio_artifact(self): return AudioArtifact(value=b"audio data", format="mp3") @pytest.fixture() - def audio_transcription_engine(self): + def audio_transcription_driver(self): return Mock() - def test_audio_input(self, audio_artifact, audio_transcription_engine): - task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) + def test_audio_input(self, audio_artifact, audio_transcription_driver): + task = AudioTranscriptionTask(audio_artifact, audio_transcription_driver=audio_transcription_driver) assert task.input.value == audio_artifact.value - def test_callable_input(self, audio_artifact, audio_transcription_engine): + def test_callable_input(self, audio_artifact, audio_transcription_driver): def callable_input(task: BaseTask) -> AudioArtifact: return audio_artifact - task = AudioTranscriptionTask(callable_input, audio_transcription_engine=audio_transcription_engine) + task = AudioTranscriptionTask(callable_input, audio_transcription_driver=audio_transcription_driver) assert task.input == audio_artifact - def test_config_audio_transcription_engine(self, audio_artifact): + def test_config_audio_transcription_driver(self, audio_artifact): task = AudioTranscriptionTask(audio_artifact) Agent().add_task(task) - assert isinstance(task.audio_transcription_engine, AudioTranscriptionEngine) + assert isinstance(task.audio_transcription_driver, BaseAudioTranscriptionDriver) - def test_run(self, audio_artifact, audio_transcription_engine): - audio_transcription_engine.run.return_value = TextArtifact("mock transcription") + def test_run(self, audio_artifact, audio_transcription_driver): + audio_transcription_driver.run.return_value = TextArtifact("mock transcription") - task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) + task = AudioTranscriptionTask(audio_artifact, audio_transcription_driver=audio_transcription_driver) pipeline = Pipeline() pipeline.add_task(task) diff --git a/tests/unit/tasks/test_base_image_generation_task.py b/tests/unit/tasks/test_base_image_generation_task.py index a01f506f14..c4272d78c0 100644 --- a/tests/unit/tasks/test_base_image_generation_task.py +++ b/tests/unit/tasks/test_base_image_generation_task.py @@ -7,14 +7,6 @@ class TestBaseImageGenerationTask: def test_validate_negative_rulesets(self) -> None: - with pytest.raises(ValueError): - MockImageGenerationTask( - TextArtifact("some input"), - negative_rulesets=[Ruleset(name="Negative Ruleset", rules=[Rule(value="Negative Rule")])], - negative_rules=[Rule(value="Negative Rule")], - output_dir="some/dir", - ) - assert MockImageGenerationTask( TextArtifact("some input"), negative_rulesets=[Ruleset(name="Negative Ruleset", rules=[Rule(value="Negative Rule")])], @@ -22,14 +14,6 @@ def test_validate_negative_rulesets(self) -> None: ) def test_validate_negative_rules(self) -> None: - with pytest.raises(ValueError): - MockImageGenerationTask( - TextArtifact("some input"), - negative_rulesets=[Ruleset(name="Negative Ruleset", rules=[Rule(value="Negative Rule")])], - negative_rules=[Rule(value="Negative Rule")], - output_dir="some/dir", - ) - assert MockImageGenerationTask( TextArtifact("some input"), negative_rules=[Rule(value="Negative Rule")], output_dir="some/dir" ) @@ -46,7 +30,7 @@ def test_negative_rulesets_from_rules(self) -> None: task = MockImageGenerationTask(TextArtifact("some input"), negative_rules=[rule], output_dir="some/dir") - assert task.negative_rulesets[0].name == task.NEGATIVE_RULESET_NAME + assert task.negative_rulesets[0].name == task.DEFAULT_NEGATIVE_RULESET_NAME assert task.negative_rulesets[0].rules[0] == rule def test_validate_output_dir(self) -> None: diff --git a/tests/unit/tasks/test_image_query_task.py b/tests/unit/tasks/test_image_query_task.py index ea7ba100fc..e9321f2fba 100644 --- a/tests/unit/tasks/test_image_query_task.py +++ b/tests/unit/tasks/test_image_query_task.py @@ -13,7 +13,7 @@ class TestImageQueryTask: @pytest.fixture() def image_query_driver(self) -> Mock: mock = Mock() - mock.run.return_value = TextArtifact("image") + mock.query.return_value = TextArtifact("image") return mock diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py index 44348fef00..55a2abc968 100644 --- a/tests/unit/tasks/test_text_to_speech_task.py +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -1,14 +1,14 @@ from unittest.mock import Mock from griptape.artifacts import AudioArtifact, TextArtifact -from griptape.engines import TextToSpeechEngine +from griptape.drivers.text_to_speech.base_text_to_speech_driver import BaseTextToSpeechDriver from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, TextToSpeechTask class TestTextToSpeechTask: def test_string_input(self): - task = TextToSpeechTask("string input", text_to_speech_engine=Mock()) + task = TextToSpeechTask("string input", text_to_speech_driver=Mock()) assert task.input.value == "string input" @@ -18,27 +18,27 @@ def test_callable_input(self): def callable_input(task: BaseTask) -> TextArtifact: return input_artifact - task = TextToSpeechTask(callable_input, text_to_speech_engine=Mock()) + task = TextToSpeechTask(callable_input, text_to_speech_driver=Mock()) assert task.input == input_artifact - def test_config_text_to_speech_engine(self): + def test_config_text_to_speech_driver(self): task = TextToSpeechTask("foo bar") Agent().add_task(task) - assert isinstance(task.text_to_speech_engine, TextToSpeechEngine) + assert isinstance(task.text_to_speech_driver, BaseTextToSpeechDriver) def test_calls(self): - text_to_speech_engine = Mock() - text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3") + text_to_speech_driver = Mock() + text_to_speech_driver.run_text_to_audio.return_value = AudioArtifact(b"audio content", format="mp3") - assert TextToSpeechTask("test", text_to_speech_engine=text_to_speech_engine).run().value == b"audio content" + assert TextToSpeechTask("test", text_to_speech_driver=text_to_speech_driver).run().value == b"audio content" def test_run(self): - text_to_speech_engine = Mock() - text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3") + text_to_speech_driver = Mock() + text_to_speech_driver.run_text_to_audio.return_value = AudioArtifact(b"audio content", format="mp3") - task = TextToSpeechTask("some text", text_to_speech_engine=text_to_speech_engine) + task = TextToSpeechTask("some text", text_to_speech_driver=text_to_speech_driver) pipeline = Pipeline() pipeline.add_task(task) diff --git a/tests/unit/tools/test_inpainting_image_generation_tool.py b/tests/unit/tools/test_inpainting_image_generation_tool.py index d085643b43..16b9fb8df6 100644 --- a/tests/unit/tools/test_inpainting_image_generation_tool.py +++ b/tests/unit/tools/test_inpainting_image_generation_tool.py @@ -27,14 +27,16 @@ def image_loader(self) -> Mock: @pytest.fixture() def image_generator(self, image_generation_driver, image_loader) -> InpaintingImageGenerationTool: - return InpaintingImageGenerationTool(engine=image_generation_driver, image_loader=image_loader) + return InpaintingImageGenerationTool(image_generation_driver=image_generation_driver, image_loader=image_loader) def test_validate_output_configs(self, image_generation_driver) -> None: with pytest.raises(ValueError): - InpaintingImageGenerationTool(engine=image_generation_driver, output_dir="test", output_file="test") + InpaintingImageGenerationTool( + image_generation_driver=image_generation_driver, output_dir="test", output_file="test" + ) def test_image_inpainting(self, image_generator, path_from_resource_path) -> None: - image_generator.engine.run.return_value = Mock( + image_generator.image_generation_driver.run_image_inpainting.return_value = Mock( value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) @@ -56,10 +58,10 @@ def test_image_inpainting_with_outfile( ) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" image_generator = InpaintingImageGenerationTool( - engine=image_generation_driver, output_file=outfile, image_loader=image_loader + image_generation_driver=image_generation_driver, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + image_generator.image_generation_driver.run_image_inpainting.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512 ) @@ -78,12 +80,12 @@ def test_image_inpainting_with_outfile( assert os.path.exists(outfile) def test_image_inpainting_from_memory(self, image_generation_driver, image_artifact): - image_generator = InpaintingImageGenerationTool(engine=image_generation_driver) + image_generator = InpaintingImageGenerationTool(image_generation_driver=image_generation_driver) memory = Mock() memory.load_artifacts = Mock(return_value=[image_artifact]) image_generator.find_input_memory = Mock(return_value=memory) - image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + image_generator.image_generation_driver.run_image_inpainting.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512 ) diff --git a/tests/unit/tools/test_outpainting_image_variation_tool.py b/tests/unit/tools/test_outpainting_image_variation_tool.py index 9902d49715..2467395607 100644 --- a/tests/unit/tools/test_outpainting_image_variation_tool.py +++ b/tests/unit/tools/test_outpainting_image_variation_tool.py @@ -27,14 +27,18 @@ def image_loader(self, image_artifact) -> Mock: @pytest.fixture() def image_generator(self, image_generation_driver, image_loader) -> OutpaintingImageGenerationTool: - return OutpaintingImageGenerationTool(engine=image_generation_driver, image_loader=image_loader) + return OutpaintingImageGenerationTool( + image_generation_driver=image_generation_driver, image_loader=image_loader + ) def test_validate_output_configs(self, image_generation_driver) -> None: with pytest.raises(ValueError): - OutpaintingImageGenerationTool(engine=image_generation_driver, output_dir="test", output_file="test") + OutpaintingImageGenerationTool( + image_generation_driver=image_generation_driver, output_dir="test", output_file="test" + ) def test_image_outpainting(self, image_generator, path_from_resource_path) -> None: - image_generator.engine.run.return_value = ImageArtifact( + image_generator.image_generation_driver.run_image_variation.return_value = ImageArtifact( value=b"image data", format="png", width=512, height=512 ) @@ -56,10 +60,10 @@ def test_image_outpainting_with_outfile( ) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" image_generator = OutpaintingImageGenerationTool( - engine=image_generation_driver, output_file=outfile, image_loader=image_loader + image_generation_driver=image_generation_driver, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + image_generator.image_generation_driver.run_image_outpainting.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512 ) @@ -78,12 +82,12 @@ def test_image_outpainting_with_outfile( assert os.path.exists(outfile) def test_image_outpainting_from_memory(self, image_generation_driver, image_artifact): - image_generator = OutpaintingImageGenerationTool(engine=image_generation_driver) + image_generator = OutpaintingImageGenerationTool(image_generation_driver=image_generation_driver) memory = Mock() memory.load_artifacts = Mock(return_value=[image_artifact]) image_generator.find_input_memory = Mock(return_value=memory) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] + image_generator.image_generation_driver.run_image_variation.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) diff --git a/tests/unit/tools/test_prompt_image_generation_tool.py b/tests/unit/tools/test_prompt_image_generation_tool.py index bb5c7fc976..5c5057d6c6 100644 --- a/tests/unit/tools/test_prompt_image_generation_tool.py +++ b/tests/unit/tools/test_prompt_image_generation_tool.py @@ -16,14 +16,16 @@ def image_generation_driver(self) -> Mock: @pytest.fixture() def image_generator(self, image_generation_driver) -> PromptImageGenerationTool: - return PromptImageGenerationTool(engine=image_generation_driver) + return PromptImageGenerationTool(image_generation_driver=image_generation_driver) def test_validate_output_configs(self, image_generation_driver) -> None: with pytest.raises(ValueError): - PromptImageGenerationTool(engine=image_generation_driver, output_dir="test", output_file="test") + PromptImageGenerationTool( + image_generation_driver=image_generation_driver, output_dir="test", output_file="test" + ) def test_generate_image(self, image_generator) -> None: - image_generator.engine.run.return_value = Mock( + image_generator.image_generation_driver.run_text_to_image.return_value = Mock( value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) @@ -35,9 +37,11 @@ def test_generate_image(self, image_generator) -> None: def test_generate_image_with_outfile(self, image_generation_driver) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" - image_generator = PromptImageGenerationTool(engine=image_generation_driver, output_file=outfile) + image_generator = PromptImageGenerationTool( + image_generation_driver=image_generation_driver, output_file=outfile + ) - image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + image_generator.image_generation_driver.run_text_to_image.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512 ) diff --git a/tests/unit/tools/test_text_to_speech_tool.py b/tests/unit/tools/test_text_to_speech_tool.py index 6f2c43bd39..25f98aaf8e 100644 --- a/tests/unit/tools/test_text_to_speech_tool.py +++ b/tests/unit/tools/test_text_to_speech_tool.py @@ -11,29 +11,33 @@ class TestTextToSpeechTool: @pytest.fixture() - def text_to_speech_engine(self) -> Mock: + def text_to_speech_driver(self) -> Mock: return Mock() @pytest.fixture() - def text_to_speech_client(self, text_to_speech_engine) -> TextToSpeechTool: - return TextToSpeechTool(engine=text_to_speech_engine) + def text_to_speech_client(self, text_to_speech_driver) -> TextToSpeechTool: + return TextToSpeechTool(text_to_speech_driver=text_to_speech_driver) - def test_validate_output_configs(self, text_to_speech_engine) -> None: + def test_validate_output_configs(self, text_to_speech_driver) -> None: with pytest.raises(ValueError): - TextToSpeechTool(engine=text_to_speech_engine, output_dir="test", output_file="test") + TextToSpeechTool(text_to_speech_driver=text_to_speech_driver, output_dir="test", output_file="test") def test_text_to_speech(self, text_to_speech_client) -> None: - text_to_speech_client.engine.run.return_value = Mock(value=b"audio data", format="mp3") + text_to_speech_client.text_to_speech_driver.run_text_to_audio.return_value = Mock( + value=b"audio data", format="mp3" + ) audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) assert audio_artifact - def test_text_to_speech_with_outfile(self, text_to_speech_engine) -> None: + def test_text_to_speech_with_outfile(self, text_to_speech_driver) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.mp3" - text_to_speech_client = TextToSpeechTool(engine=text_to_speech_engine, output_file=outfile) + text_to_speech_client = TextToSpeechTool(text_to_speech_driver=text_to_speech_driver, output_file=outfile) - text_to_speech_client.engine.run.return_value = AudioArtifact(value=b"audio data", format="mp3") # pyright: ignore[reportFunctionMemberAccess] + text_to_speech_client.text_to_speech_driver.run_text_to_audio.return_value = AudioArtifact( # pyright: ignore[reportFunctionMemberAccess] + value=b"audio data", format="mp3" + ) audio_artifact = text_to_speech_client.text_to_speech(params={"values": {"text": "say this!"}}) diff --git a/tests/unit/tools/test_transcription_tool.py b/tests/unit/tools/test_transcription_tool.py index 07368495f0..c2f175dab7 100644 --- a/tests/unit/tools/test_transcription_tool.py +++ b/tests/unit/tools/test_transcription_tool.py @@ -8,7 +8,7 @@ class TestTranscriptionTool: @pytest.fixture() - def transcription_engine(self) -> Mock: + def audio_transcription_driver(self) -> Mock: return Mock() @pytest.fixture() @@ -26,26 +26,30 @@ def mock_path(self, mocker) -> Mock: return mocker - def test_init_transcription_client(self, transcription_engine, audio_loader) -> None: - assert AudioTranscriptionTool(engine=transcription_engine, audio_loader=audio_loader) + def test_init_transcription_client(self, audio_transcription_driver, audio_loader) -> None: + assert AudioTranscriptionTool(audio_transcription_driver=audio_transcription_driver, audio_loader=audio_loader) @patch("builtins.open", mock_open(read_data=b"audio data")) - def test_transcribe_audio_from_disk(self, transcription_engine, audio_loader) -> None: - client = AudioTranscriptionTool(engine=transcription_engine, audio_loader=audio_loader) - client.engine.run.return_value = Mock(value="transcription") # pyright: ignore[reportFunctionMemberAccess] + def test_transcribe_audio_from_disk(self, audio_transcription_driver, audio_loader) -> None: + client = AudioTranscriptionTool( + audio_transcription_driver=audio_transcription_driver, audio_loader=audio_loader + ) + client.audio_transcription_driver.run.return_value = Mock(value="transcription") # pyright: ignore[reportFunctionMemberAccess] text_artifact = client.transcribe_audio_from_disk(params={"values": {"path": "audio.wav"}}) assert text_artifact assert text_artifact.value == "transcription" - def test_transcribe_audio_from_memory(self, transcription_engine, audio_loader) -> None: - client = AudioTranscriptionTool(engine=transcription_engine, audio_loader=audio_loader) + def test_transcribe_audio_from_memory(self, audio_transcription_driver, audio_loader) -> None: + client = AudioTranscriptionTool( + audio_transcription_driver=audio_transcription_driver, audio_loader=audio_loader + ) memory = Mock() memory.load_artifacts = Mock(return_value=[AudioArtifact(value=b"audio data", format="wav", name="name")]) client.find_input_memory = Mock(return_value=memory) - client.engine.run.return_value = Mock(value="transcription") # pyright: ignore[reportFunctionMemberAccess] + client.audio_transcription_driver.run.return_value = Mock(value="transcription") # pyright: ignore[reportFunctionMemberAccess] text_artifact = client.transcribe_audio_from_memory( params={"values": {"memory_name": "memory", "artifact_namespace": "namespace", "artifact_name": "name"}} diff --git a/tests/unit/tools/test_variation_image_generation_tool.py b/tests/unit/tools/test_variation_image_generation_tool.py index c796f2c881..46eb47707a 100644 --- a/tests/unit/tools/test_variation_image_generation_tool.py +++ b/tests/unit/tools/test_variation_image_generation_tool.py @@ -27,16 +27,19 @@ def image_loader(self) -> Mock: @pytest.fixture() def image_generator(self, image_generation_driver, image_loader) -> VariationImageGenerationTool: - return VariationImageGenerationTool(engine=image_generation_driver, image_loader=image_loader) + return VariationImageGenerationTool(image_generation_driver=image_generation_driver, image_loader=image_loader) def test_validate_output_configs(self, image_generation_driver, image_loader) -> None: with pytest.raises(ValueError): VariationImageGenerationTool( - engine=image_generation_driver, output_dir="test", output_file="test", image_loader=image_loader + image_generation_driver=image_generation_driver, + output_dir="test", + output_file="test", + image_loader=image_loader, ) def test_image_variation(self, image_generator, path_from_resource_path) -> None: - image_generator.engine.run.return_value = Mock( + image_generator.image_generation_driver.run_image_variation.return_value = Mock( value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) @@ -55,10 +58,10 @@ def test_image_variation(self, image_generator, path_from_resource_path) -> None def test_image_variation_with_outfile(self, image_generation_driver, image_loader, path_from_resource_path) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" image_generator = VariationImageGenerationTool( - engine=image_generation_driver, output_file=outfile, image_loader=image_loader + image_generation_driver=image_generation_driver, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + image_generator.image_generation_driver.run_image_variation.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512 ) @@ -76,12 +79,12 @@ def test_image_variation_with_outfile(self, image_generation_driver, image_loade assert os.path.exists(outfile) def test_image_variation_from_memory(self, image_generation_driver, image_artifact): - image_generator = VariationImageGenerationTool(engine=image_generation_driver) + image_generator = VariationImageGenerationTool(image_generation_driver=image_generation_driver) memory = Mock() memory.load_artifacts = Mock(return_value=[image_artifact]) image_generator.find_input_memory = Mock(return_value=memory) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] + image_generator.image_generation_driver.run_image_variation.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" )