Skip to content

Commit

Permalink
Late-import transformers for magic prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Nov 13, 2023
1 parent 719ab71 commit 646672c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 24 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ exclude = "tests"
module = "transformers"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "torch"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "spacy.*"
ignore_missing_imports = true
Expand Down
35 changes: 20 additions & 15 deletions src/dynamicprompts/generators/magicprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,13 @@

logger = logging.getLogger(__name__)

try:
if TYPE_CHECKING:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Pipeline,
pipeline,
set_seed,
)
except ImportError as ie:
raise ImportError(
"You need to install the transformers library to use the MagicPrompt generator. "
"You can do this by running `pip install -U dynamicprompts[magicprompt]`.",
) from ie

if TYPE_CHECKING:
import torch

DEFAULT_MODEL_NAME = "Gustavosta/MagicPrompt-Stable-Diffusion"
MAX_SEED = 2**32 - 1
Expand Down Expand Up @@ -71,6 +62,18 @@ def clean_up_magic_prompt(orig_prompt: str, prompt: str) -> str:
return prompt


def _import_transformers():
try:
import transformers

return transformers
except ImportError as ie:
raise ImportError(
"You need to install the transformers library to use the MagicPrompt generator. "
"You can do this by running `pip install -U dynamicprompts[magicprompt]`.",
) from ie


class MagicPromptGenerator(PromptGenerator):
generator: Pipeline | None = None
tokenizer: AutoTokenizer | None = None
Expand All @@ -83,13 +86,14 @@ def _load_pipeline(self, model_name: str) -> Pipeline:
logger.warning("First load of MagicPrompt may take a while.")

if MagicPromptGenerator.generator is None:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
transformers = _import_transformers()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer.pad_token_id = model.config.eos_token_id

MagicPromptGenerator.tokenizer = tokenizer
MagicPromptGenerator.model = model
MagicPromptGenerator.generator = pipeline(
MagicPromptGenerator.generator = transformers.pipeline(
task="text-generation",
tokenizer=tokenizer,
model=model,
Expand Down Expand Up @@ -123,6 +127,7 @@ def __init__(
:param blocklist_regex: A regex to use to filter out prompts that match it.
:param batch_size: The batch size to use when generating prompts.
"""
transformers = _import_transformers()
self._device = device
self.set_model(model_name)

Expand All @@ -140,7 +145,7 @@ def __init__(
self._blocklist_regex = None

if seed is not None:
set_seed(int(seed))
transformers.set_seed(int(seed))

self._batch_size = batch_size

Expand Down
20 changes: 11 additions & 9 deletions tests/generators/test_magicprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@

import pytest

pytest.importorskip("dynamicprompts.generators.magicprompt")

@pytest.fixture(autouse=True)
def mock_import_transformers(monkeypatch):
from dynamicprompts.generators import magicprompt

@pytest.mark.slow
class TestMagicPrompt:
def test_default_generator(self):
from dynamicprompts.generators.dummygenerator import DummyGenerator
from dynamicprompts.generators.magicprompt import MagicPromptGenerator
monkeypatch.setattr(magicprompt, "_import_transformers", MagicMock())


def test_default_generator():
from dynamicprompts.generators.dummygenerator import DummyGenerator
from dynamicprompts.generators.magicprompt import MagicPromptGenerator

generator = MagicPromptGenerator()
assert isinstance(generator._prompt_generator, DummyGenerator)
generator = MagicPromptGenerator()
assert isinstance(generator._prompt_generator, DummyGenerator)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -121,7 +124,6 @@ def _generator(
assert not any(artist in magic_prompt for artist in boring_artists)


@pytest.mark.slow
def test_generate_passes_kwargs():
from dynamicprompts.generators.magicprompt import MagicPromptGenerator

Expand Down

0 comments on commit 646672c

Please sign in to comment.