diff --git a/big_sleep/big_sleep.py b/big_sleep/big_sleep.py index 4a50a02..84f4692 100644 --- a/big_sleep/big_sleep.py +++ b/big_sleep/big_sleep.py @@ -166,7 +166,7 @@ def __init__( def reset(self): self.model.init_latents() - def forward(self, text, return_loss = True): + def forward(self, text_embed, return_loss = True): width, num_cutouts = self.image_size, self.num_cutouts out = self.model() @@ -190,7 +190,6 @@ def forward(self, text, return_loss = True): into = normalize_image(into) image_embed = perceptor.encode_image(into) - text_embed = perceptor.encode_text(text) latents, soft_one_hot_classes = self.model.latents() num_latents = latents.shape[0] @@ -287,7 +286,8 @@ def set_text(self, text): self.textpath = textpath self.filename = Path(f'./{textpath}.png') - self.encoded_text = tokenize(text).cuda() + encoded_text = tokenize(text).cuda() + self.encoded_text = perceptor.encode_text(encoded_text).detach() def reset(self): self.model.reset() diff --git a/big_sleep/version.py b/big_sleep/version.py index d4f3346..2b8877c 100644 --- a/big_sleep/version.py +++ b/big_sleep/version.py @@ -1 +1 @@ -__version__ = '0.4.11' +__version__ = '0.5.0'