Skip to content

Commit

Permalink
Add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 25, 2024
1 parent 1e80649 commit 6fcc953
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/unit/tasks/test_base_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,18 @@ def test_negative_rulesets_from_rules(self) -> None:
def test_validate_output_dir(self) -> None:
with pytest.raises(ValueError):
MockImageGenerationTask(TextArtifact("some input"), output_dir="some/dir", output_file="some/file")

def test__get_prompts(self):
task = MockImageGenerationTask(
TextArtifact("some input"), rulesets=[Ruleset(name="Ruleset", rules=[Rule(value="Rule")])]
)

assert task._get_prompts(task.input.to_text()) == ["some input", "Rule"]

def test__get_negative_prompts(self):
task = MockImageGenerationTask(
TextArtifact("some input"),
negative_rulesets=[Ruleset(name="Negative Ruleset", rules=[Rule(value="Negative Rule")])],
)

assert task._get_negative_prompts() == ["Negative Rule"]
36 changes: 36 additions & 0 deletions tests/unit/tools/test_image_query_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

from griptape.artifacts.image_artifact import ImageArtifact
from griptape.tools import ImageQueryTool
from tests.mocks.mock_image_query_driver import MockImageQueryDriver
from tests.utils import defaults


class TestImageQueryTool:
@pytest.fixture()
def tool(self):
task_memory = defaults.text_task_memory("memory_name")
task_memory.store_artifact("namespace", ImageArtifact(b"", format="png", width=1, height=1, name="test"))
return ImageQueryTool(input_memory=[task_memory], image_query_driver=MockImageQueryDriver())

def test_query_image_from_disk(self, tool):
assert tool.query_image_from_disk({"values": {"query": "test", "image_paths": []}}).value == "mock text"

def test_query_images_from_memory(self, tool):
assert (
tool.query_images_from_memory(
{
"values": {
"query": "test",
"memory_name": tool.input_memory[0].name,
"image_artifacts": [
{
"image_artifact_name": "test",
"image_artifact_namespace": "namespace",
}
],
}
}
).value
== "mock text"
)

0 comments on commit 6fcc953

Please sign in to comment.