-
Notifications
You must be signed in to change notification settings - Fork 14
/
prompt_manager.py
63 lines (57 loc) · 3.07 KB
/
prompt_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from typing import Optional, List, Dict, Any
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
import constants
from models.neti_clip_text_encoder import NeTICLIPTextModel
from utils.types import NeTIBatch
class PromptManager:
    """Compute the time- and space-dependent text embeddings for a prompt.

    For every configured timestep and every configured U-Net layer, the prompt
    is passed through the text encoder and the resulting hidden states are
    collected into one dictionary per timestep.
    """

    def __init__(self, tokenizer: CLIPTokenizer,
                 text_encoder: NeTICLIPTextModel,
                 timesteps: List[int] = constants.SD_INFERENCE_TIMESTEPS,
                 unet_layers: List[str] = constants.UNET_LAYERS,
                 placeholder_token_id: Optional[List] = None,
                 placeholder_token: Optional[List] = None,
                 torch_dtype: torch.dtype = torch.float32):
        """
        Store the components used by `embed_prompt`.

        Args:
            tokenizer: Tokenizer used to convert the formatted prompt into input ids.
            text_encoder: Encoder called with a `NeTIBatch`; its `device` attribute decides
                where all inputs are placed.
            timesteps: Timesteps to compute embeddings for. Each entry may be a plain int or
                a tensor element (e.g. when iterating over a tensor of timesteps).
            unet_layers: U-Net layer names; only their positions (indices) are forwarded to
                the text encoder.
            placeholder_token_id: Forwarded unchanged to every `NeTIBatch`.
            placeholder_token: Value substituted for `{}` in the prompt text.
            torch_dtype: dtype the returned hidden states are cast to.
        """
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.timesteps = timesteps
        self.unet_layers = unet_layers
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.dtype = torch_dtype

    def embed_prompt(self, text: str,
                     truncation_idx: Optional[int] = None,
                     num_images_per_prompt: int = 1) -> List[Dict[str, Any]]:
        """
        Compute the conditioning vectors for the given prompt. We assume that the prompt is defined using `{}`
        for indicating where to place the placeholder token string. See constants.VALIDATION_PROMPTS for examples.

        Args:
            text: Prompt template containing `{}` where the placeholder token is inserted.
            truncation_idx: Forwarded unchanged to every `NeTIBatch`.
            num_images_per_prompt: Number of times each hidden state is repeated along its
                first dimension.

        Returns:
            One dictionary per timestep, in the order of `self.timesteps`. Each dictionary holds
            `"this_idx": 0`, a `CONTEXT_TENSOR_{i}` entry for every U-Net layer index `i`, and a
            `CONTEXT_TENSOR_BYPASS_{i}` entry whenever the encoder returns a bypass output.
        """
        text = text.format(self.placeholder_token)
        ids = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        # The device and the token ids do not change across iterations, so resolve/move them once.
        device = self.text_encoder.device
        ids = ids.to(device=device)

        # Compute embeddings for each timestep and each U-Net layer
        print(f"Computing embeddings over {len(self.timesteps)} timesteps and {len(self.unet_layers)} U-Net layers.")
        hidden_states_per_timestep = []
        for timestep in tqdm(self.timesteps):
            # `as_tensor` accepts both plain ints (as the `timesteps` annotation promises) and
            # tensor elements; an existing tensor is only moved to the target device.
            timestep_tensor = torch.as_tensor(timestep, device=device).unsqueeze(0)
            _hs = {"this_idx": 0}
            for layer_idx, unet_layer in enumerate(self.unet_layers):
                batch = NeTIBatch(input_ids=ids,
                                  timesteps=timestep_tensor,
                                  unet_layers=torch.tensor(layer_idx, device=device).unsqueeze(0),
                                  placeholder_token_id=self.placeholder_token_id,
                                  truncation_idx=truncation_idx)
                layer_hs, layer_hs_bypass = self.text_encoder(batch=batch)
                # NOTE(review): the three-argument `repeat` below presumes the indexed output is
                # at most 3-D (presumably (batch, seq, dim)) - confirm against the encoder.
                layer_hs = layer_hs[0].to(dtype=self.dtype)
                _hs[f"CONTEXT_TENSOR_{layer_idx}"] = layer_hs.repeat(num_images_per_prompt, 1, 1)
                if layer_hs_bypass is not None:
                    layer_hs_bypass = layer_hs_bypass[0].to(dtype=self.dtype)
                    _hs[f"CONTEXT_TENSOR_BYPASS_{layer_idx}"] = layer_hs_bypass.repeat(num_images_per_prompt, 1, 1)
            hidden_states_per_timestep.append(_hs)
        print("Done.")
        return hidden_states_per_timestep