diff --git a/big_sleep/big_sleep.py b/big_sleep/big_sleep.py
index b849221..9e7b779 100644
--- a/big_sleep/big_sleep.py
+++ b/big_sleep/big_sleep.py
@@ -1,10 +1,3 @@
-import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.optim import Adam
-
-from torchvision.utils import save_image
-
 import os
 import sys
 import subprocess
@@ -14,6 +7,15 @@
 from datetime import datetime
 from pathlib import Path
 
+import random
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.optim import Adam
+from torchvision.utils import save_image
+import torchvision.transforms as T
+from PIL import Image
 from tqdm import tqdm, trange
 
 from big_sleep.ema import EMA
@@ -63,10 +65,20 @@ def open_folder(path):
         pass
 
 
-def underscorify(value):
-    no_punctuation = str(value.translate(str.maketrans('', '', string.punctuation)))
-    spaces_to_one_underline = re.sub(r'[-\s]+', '_', no_punctuation).strip('-_') # strip gets rid of leading or trailing underscores
-    return spaces_to_one_underline
+def create_text_path(text=None, img=None, encoding=None):
+    input_name = ""
+    if text is not None:
+        input_name += text
+    if img is not None:
+        if isinstance(img, str):
+            img_name = "".join(img.split(".")[:-1]) # remove the file extension
+            img_name = img_name.split("/")[-1] # only take img name, not path
+        else:
+            img_name = "PIL_img"
+        input_name += "_" + img_name
+    if encoding is not None:
+        input_name = "your_encoding"
+    return input_name.replace("-", "_").replace(",", "").replace(" ", "_").strip('-_')[:255]
 
 
 # tensor helpers
@@ -85,6 +97,39 @@ def differentiable_topk(x, k, temperature=1.):
     topks = torch.cat(topk_tensors, dim=-1)
     return topks.reshape(n, k, dim).sum(dim = 1)
 
+
+def create_clip_img_transform(image_width):
+    clip_mean = [0.48145466, 0.4578275, 0.40821073]
+    clip_std = [0.26862954, 0.26130258, 0.27577711]
+    transform = T.Compose([
+            #T.ToPILImage(),
+            T.Resize(image_width),
+            T.CenterCrop((image_width, image_width)),
+            T.ToTensor(),
+            T.Normalize(mean=clip_mean, std=clip_std)
+    ])
+    return transform
+
+
+def rand_cutout(image, size, center_bias=False, center_focus=2):
+    width = image.shape[-1]
+    min_offset = 0
+    max_offset = width - size
+    if center_bias:
+        # sample around image center
+        center = max_offset / 2
+        std = center / center_focus
+        offset_x = int(random.gauss(mu=center, sigma=std))
+        offset_y = int(random.gauss(mu=center, sigma=std))
+        # resample uniformly if over boundaries
+        offset_x = random.randint(min_offset, max_offset) if (offset_x > max_offset or offset_x < min_offset) else offset_x
+        offset_y = random.randint(min_offset, max_offset) if (offset_y > max_offset or offset_y < min_offset) else offset_y
+    else:
+        offset_x = random.randint(min_offset, max_offset)
+        offset_y = random.randint(min_offset, max_offset)
+    cutout = image[:, :, offset_x:offset_x + size, offset_y:offset_y + size]
+    return cutout
+
 # load clip
 
 perceptor, normalize_image = load('ViT-B/32', jit = False)
@@ -150,7 +195,7 @@ def forward(self):
         out = self.biggan(*self.latents(), 1)
         return (out + 1) / 2
 
-# load siren
+
 class BigSleep(nn.Module):
     def __init__(
         self,
@@ -161,13 +206,15 @@ def __init__(
         max_classes = None,
         class_temperature = 2.,
         experimental_resample = False,
-        ema_decay = 0.99
+        ema_decay = 0.99,
+        center_bias = False,
     ):
         super().__init__()
         self.loss_coef = loss_coef
         self.image_size = image_size
         self.num_cutouts = num_cutouts
         self.experimental_resample = experimental_resample
+        self.center_bias = center_bias
 
         self.interpolation_settings = {'mode': 'bilinear', 'align_corners': False} if bilinear else {'mode': 'nearest'}
 
@@ -187,8 +234,6 @@ def sim_txt_to_img(self, text_embed, img_embed, text_type="max"):
             sign = 1
         return sign * self.loss_coef * torch.cosine_similarity(text_embed, img_embed, dim = -1).mean()
 
-
-
     def forward(self, text_embeds, text_min_embeds=[], return_loss = True):
         width, num_cutouts = self.image_size, self.num_cutouts
 
@@ -199,10 +244,10 @@ def forward(self, text_embeds, text_min_embeds=[], return_loss = True):
 
         pieces = []
         for ch in range(num_cutouts):
+            # sample cutout size
             size = int(width * torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .95))
-            offsetx = torch.randint(0, width - size, ())
-            offsety = torch.randint(0, width - size, ())
-            apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]
+            # get cutout
+            apper = rand_cutout(out, size, center_bias=self.center_bias)
             if (self.experimental_resample):
                 apper = resample(apper, (224, 224))
             else:
@@ -242,13 +287,16 @@ def forward(self, text_embeds, text_min_embeds=[], return_loss = True):
             for txt_min_embed in text_min_embeds:
                 results.append(self.sim_txt_to_img(txt_min_embed, image_embed, "min"))
             sim_loss = sum(results).mean()
-            return (lat_loss, cls_loss, sim_loss)
+            return out, (lat_loss, cls_loss, sim_loss)
+
 
 class Imagine(nn.Module):
     def __init__(
         self,
-        text,
+        text=None,
+        img=None,
+        encoding=None,
         *,
         text_min = "",
         lr = .07,
         image_size = 512,
@@ -266,7 +314,9 @@ def __init__(
         save_date_time = False,
         save_best = False,
         experimental_resample = False,
-        ema_decay = 0.99
+        ema_decay = 0.99,
+        num_cutouts = 128,
+        center_bias = False,
     ):
         super().__init__()
 
@@ -290,9 +340,9 @@ def __init__(
             max_classes = max_classes,
             class_temperature = class_temperature,
             experimental_resample = experimental_resample,
-            ema_decay
-            = ema_decay
-
+            ema_decay = ema_decay,
+            num_cutouts = num_cutouts,
+            center_bias = center_bias,
         ).cuda()
 
         self.model = model
@@ -314,35 +364,65 @@ def __init__(
             "max": [],
             "min": []
         }
 
-        self.set_text(text, text_min)
-
-    def encode_one_phrase(self, phrase):
-        return perceptor.encode_text(tokenize(f'''{phrase}''').cuda()).detach().clone()
+        # create img transform
+        self.clip_transform = create_clip_img_transform(perceptor.input_resolution.item())
+        # create starting encoding
+        self.set_clip_encoding(text=text, img=img, encoding=encoding, text_min=text_min)
+
+    def create_clip_encoding(self, text=None, img=None, encoding=None):
+        self.text = text
+        self.img = img
+        if encoding is not None:
+            encoding = encoding.cuda()
+        #elif self.create_story:
+        #    encoding = self.update_story_encoding(epoch=0, iteration=1)
+        elif text is not None and img is not None:
+            encoding = (self.create_text_encoding(text) + self.create_img_encoding(img)) / 2
+        elif text is not None:
+            encoding = self.create_text_encoding(text)
+        elif img is not None:
+            encoding = self.create_img_encoding(img)
+        return encoding
+
+    def create_text_encoding(self, text):
+        tokenized_text = tokenize(text).cuda()
+        with torch.no_grad():
+            text_encoding = perceptor.encode_text(tokenized_text).detach()
+        return text_encoding
 
-    def encode_multiple_phrases(self, text, text_type="max"):
-        if len(text) > 0 and "\\" in text:
-            self.encoded_texts[text_type] = [self.encode_one_phrase(prompt_min) for prompt_min in text.split("\\")]
+    def create_img_encoding(self, img):
+        if isinstance(img, str):
+            img = Image.open(img)
+        normed_img = self.clip_transform(img).unsqueeze(0).cuda()
+        with torch.no_grad():
+            img_encoding = perceptor.encode_image(normed_img).detach()
+        return img_encoding
+
+
+    def encode_multiple_phrases(self, text, img=None, encoding=None, text_type="max"):
+        if text is not None and "\\" in text:
+            self.encoded_texts[text_type] = [self.create_clip_encoding(text=prompt_min, img=img, encoding=encoding) for prompt_min in text.split("\\")]
         else:
-            self.encoded_texts[text_type] = [self.encode_one_phrase(text)]
+            self.encoded_texts[text_type] = [self.create_clip_encoding(text=text, img=img, encoding=encoding)]
 
-    def encode_max_and_min(self, text, text_min=""):
-        self.encode_multiple_phrases(text)
-        self.encode_multiple_phrases(text_min, "min")
+    def encode_max_and_min(self, text, img=None, encoding=None, text_min=""):
+        self.encode_multiple_phrases(text, img=img, encoding=encoding)
+        if text_min is not None and text_min != "":
+            self.encode_multiple_phrases(text_min, img=img, encoding=encoding, text_type="min")
 
-    def set_text(self, text, text_min=""):
+    def set_clip_encoding(self, text=None, img=None, encoding=None, text_min=""):
         self.text = text
         self.text_min = text_min
-        textpath = text[:255]
+
         if len(text_min) > 0:
-            textpath = textpath + "_wout_" + text_min[:255]
-        textpath = underscorify(textpath)
+            text = text + "_wout_" + text_min[:255] if text is not None else "wout_" + text_min[:255]
+        text_path = create_text_path(text=text, img=img, encoding=encoding)
         if self.save_date_time:
-            textpath = datetime.now().strftime("%y%m%d-%H%M%S-") + textpath
-
-        self.textpath = textpath
-        self.filename = Path(f'./{textpath}.png')
-        self.encode_max_and_min(text, text_min) # Tokenize and encode each prompt
+            text_path = datetime.now().strftime("%y%m%d-%H%M%S-") + text_path
+        self.text_path = text_path
+        self.filename = Path(f'./{text_path}.png')
+        self.encode_max_and_min(text, img=img, encoding=encoding, text_min=text_min) # Tokenize and encode each prompt
 
     def reset(self):
         self.model.reset()
@@ -353,7 +433,7 @@ def train_step(self, epoch, i, pbar=None):
         total_loss = 0
 
         for _ in range(self.gradient_accumulate_every):
-            losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
+            out, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
             loss = sum(losses) / self.gradient_accumulate_every
             total_loss += loss
             loss.backward()
@@ -365,8 +445,8 @@ def train_step(self, epoch, i, pbar=None):
         if (i + 1) % self.save_every == 0:
             with torch.no_grad():
                 self.model.model.latents.eval()
-                losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
-                top_score, best = torch.topk(losses[2], k = 1, largest = False)
+                out, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
+                top_score, best = torch.topk(losses[2], k=1, largest=False)
                 image = self.model.model()[best].cpu()
                 self.model.model.latents.train()
 
@@ -379,20 +459,22 @@ def train_step(self, epoch, i, pbar=None):
             if self.save_progress:
                 total_iterations = epoch * self.iterations + i
                 num = total_iterations // self.save_every
-                save_image(image, Path(f'./{self.textpath}.{num}.png'))
+                save_image(image, Path(f'./{self.text_path}.{num}.png'))
 
             if self.save_best and top_score.item() < self.current_best_score:
                 self.current_best_score = top_score.item()
-                save_image(image, Path(f'./{self.textpath}.best.png'))
+                save_image(image, Path(f'./{self.text_path}.best.png'))
 
-        return total_loss
+        return out, total_loss
 
     def forward(self):
         penalizing = ""
         if len(self.text_min) > 0:
             penalizing = f'penalizing "{self.text_min}"'
-        print(f'Imagining "{self.text}" {penalizing}...')
-        self.model(self.encoded_texts["max"][0]) # one warmup step due to issue with CLIP and CUDA
"{self.text_path}" {penalizing}...') + + with torch.no_grad(): + self.model(self.encoded_texts["max"][0]) # one warmup step due to issue with CLIP and CUDA if self.open_folder: open_folder('./') @@ -403,7 +485,7 @@ def forward(self): pbar = trange(self.iterations, desc=' iteration', position=1, leave=True) image_pbar.update(0) for i in pbar: - loss = self.train_step(epoch, i, image_pbar) + out, loss = self.train_step(epoch, i, image_pbar) pbar.set_description(f'loss: {loss.item():04.2f}') if terminate: diff --git a/big_sleep/cli.py b/big_sleep/cli.py index 177d08d..b4dc920 100644 --- a/big_sleep/cli.py +++ b/big_sleep/cli.py @@ -4,8 +4,10 @@ from pathlib import Path from .version import __version__; + def train( - text, + text=None, + img=None, text_min="", lr = .07, image_size = 512, @@ -25,7 +27,9 @@ def train( class_temperature = 2., save_best = False, experimental_resample = False, - ema_decay = 0.5 + ema_decay = 0.5, + num_cutouts = 128, + center_bias = False, ): print(f'Starting up... v{__version__}') @@ -33,7 +37,8 @@ def train( seed = rnd.randint(0, 1e6) imagine = Imagine( - text, + text=text, + img=img, text_min=text_min, lr = lr, image_size = image_size, @@ -51,7 +56,9 @@ def train( save_date_time = save_date_time, save_best = save_best, experimental_resample = experimental_resample, - ema_decay = ema_decay + ema_decay = ema_decay, + num_cutouts = num_cutouts, + center_bias = center_bias, ) if not overwrite and imagine.filename.exists():