diff --git a/big_sleep/big_sleep.py b/big_sleep/big_sleep.py
index b849221..9e7b779 100644
--- a/big_sleep/big_sleep.py
+++ b/big_sleep/big_sleep.py
@@ -1,10 +1,3 @@
-import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.optim import Adam
-
-from torchvision.utils import save_image
-
 import os
 import sys
 import subprocess
@@ -14,6 +7,15 @@
 
 from datetime import datetime
 from pathlib import Path
+import random
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.optim import Adam
+from torchvision.utils import save_image
+import torchvision.transforms as T
+from PIL import Image
 from tqdm import tqdm, trange
 
 from big_sleep.ema import EMA
@@ -63,10 +65,20 @@ def open_folder(path):
         pass
 
 
-def underscorify(value):
-  no_punctuation = str(value.translate(str.maketrans('', '', string.punctuation)))
-  spaces_to_one_underline = re.sub(r'[-\s]+', '_', no_punctuation).strip('-_') # strip gets rid of leading or trailing underscores
-  return spaces_to_one_underline
+def create_text_path(text=None, img=None, encoding=None):
+    input_name = ""
+    if text is not None:
+        input_name += text
+    if img is not None:
+        if isinstance(img, str):
+            img_name = "".join(img.split(".")[:-1]) # replace spaces by underscores, remove img extension
+            img_name = img_name.split("/")[-1]  # only take img name, not path
+        else:
+            img_name = "PIL_img"
+        input_name += "_" + img_name
+    if encoding is not None:
+        input_name = "your_encoding" # a raw tensor encoding has no readable name to use
+    return input_name.replace("-", "_").replace(",", "").replace(" ", "_").strip('-_')[:255]
 
 # tensor helpers
 
@@ -85,6 +97,39 @@ def differentiable_topk(x, k, temperature=1.):
     topks = torch.cat(topk_tensors, dim=-1)
     return topks.reshape(n, k, dim).sum(dim = 1)
 
+
+def create_clip_img_transform(image_width):
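+    # RGB mean/std published with OpenAI's CLIP, used to normalize inputs to the encoder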
+    clip_mean = [0.48145466, 0.4578275, 0.40821073]
+    clip_std = [0.26862954, 0.26130258, 0.27577711]
+    transform = T.Compose([
+        T.Resize(image_width),
+        T.CenterCrop((image_width, image_width)),
+        T.ToTensor(),
+        T.Normalize(mean=clip_mean, std=clip_std)
+    ])
+    return transform
+
+
+def rand_cutout(image, size, center_bias=False, center_focus=2):
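+    # crop a random size-by-size patch from an (N, C, H, W) batch; with center_bias,
+    # offsets are drawn from a Gaussian around the center, tightened by center_focus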
+    width = image.shape[-1]
+    min_offset = 0
+    max_offset = width - size
+    if center_bias:
+        # sample around image center
+        center = max_offset / 2
+        std = center / center_focus
+        offset_x = int(random.gauss(mu=center, sigma=std))
+        offset_y = int(random.gauss(mu=center, sigma=std))
+        # resample uniformly if over boundaries
+        offset_x = random.randint(min_offset, max_offset) if (offset_x > max_offset or offset_x < min_offset) else offset_x
+        offset_y = random.randint(min_offset, max_offset) if (offset_y > max_offset or offset_y < min_offset) else offset_y
+    else:
+        offset_x = random.randint(min_offset, max_offset)
+        offset_y = random.randint(min_offset, max_offset)
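+    # dims are (N, C, H, W): offset_x slides along rows, offset_y along columns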
+    cutout = image[:, :, offset_x:offset_x + size, offset_y:offset_y + size]
+    return cutout
+
 # load clip
 
 perceptor, normalize_image = load('ViT-B/32', jit = False)
@@ -150,7 +195,7 @@ def forward(self):
         out = self.biggan(*self.latents(), 1)
         return (out + 1) / 2
 
-# load siren
+
 class BigSleep(nn.Module):
     def __init__(
         self,
@@ -161,13 +206,15 @@ def __init__(
         max_classes = None,
         class_temperature = 2.,
         experimental_resample = False,
-        ema_decay = 0.99
+        ema_decay = 0.99,
+        center_bias = False,
     ):
         super().__init__()
         self.loss_coef = loss_coef
         self.image_size = image_size
         self.num_cutouts = num_cutouts
         self.experimental_resample = experimental_resample
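+        # when set, rand_cutout biases crop positions toward the image center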
+        self.center_bias = center_bias
 
         self.interpolation_settings = {'mode': 'bilinear', 'align_corners': False} if bilinear else {'mode': 'nearest'}
 
@@ -187,8 +234,6 @@ def sim_txt_to_img(self, text_embed, img_embed, text_type="max"):
             sign = 1
         return sign * self.loss_coef * torch.cosine_similarity(text_embed, img_embed, dim = -1).mean()
 
-
-
     def forward(self, text_embeds, text_min_embeds=[], return_loss = True):
         width, num_cutouts = self.image_size, self.num_cutouts
 
@@ -199,10 +244,10 @@ def forward(self, text_embeds, text_min_embeds=[], return_loss = True):
 
         pieces = []
         for ch in range(num_cutouts):
+            # sample cutout size
             size = int(width * torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .95))
-            offsetx = torch.randint(0, width - size, ())
-            offsety = torch.randint(0, width - size, ())
-            apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]
+            # get cutout
+            apper = rand_cutout(out, size, center_bias=self.center_bias)
             if (self.experimental_resample):
                 apper = resample(apper, (224, 224))
             else:
@@ -242,13 +287,16 @@ def forward(self, text_embeds, text_min_embeds=[], return_loss = True):
         for txt_min_embed in text_min_embeds:
             results.append(self.sim_txt_to_img(txt_min_embed, image_embed, "min"))
         sim_loss = sum(results).mean()
-        return (lat_loss, cls_loss, sim_loss)
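+        # also return the generated batch so callers can reuse it without a second forward pass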
+        return out, (lat_loss, cls_loss, sim_loss)
+
 
 class Imagine(nn.Module):
     def __init__(
         self,
-        text,
         *,
+        text=None,
+        img=None,
+        encoding=None,
         text_min = "",
         lr = .07,
         image_size = 512,
@@ -266,7 +314,9 @@ def __init__(
         save_date_time = False,
         save_best = False,
         experimental_resample = False,
-        ema_decay = 0.99
+        ema_decay = 0.99,
+        num_cutouts = 128,
+        center_bias = False,
     ):
         super().__init__()
 
@@ -290,9 +340,9 @@ def __init__(
             max_classes = max_classes,
             class_temperature = class_temperature,
             experimental_resample = experimental_resample,
-            ema_decay
-            = ema_decay
-
+            ema_decay = ema_decay,
+            num_cutouts = num_cutouts,
+            center_bias = center_bias,
         ).cuda()
 
         self.model = model
@@ -314,35 +364,65 @@ def __init__(
             "max": [],
             "min": []
         }
-        self.set_text(text, text_min)
-
-    def encode_one_phrase(self, phrase):
-        return perceptor.encode_text(tokenize(f'''{phrase}''').cuda()).detach().clone()
+        # create img transform
+        self.clip_transform = create_clip_img_transform(perceptor.input_resolution.item())
+        # create starting encoding
+        self.set_clip_encoding(text=text, img=img, encoding=encoding, text_min=text_min)
+
+    def create_clip_encoding(self, text=None, img=None, encoding=None):
+        self.text = text
+        self.img = img
+        if encoding is not None:
+            encoding = encoding.cuda()
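+        # given both a prompt and an image, target the midpoint of their CLIP embeddings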
+        elif text is not None and img is not None:
+            encoding = (self.create_text_encoding(text) + self.create_img_encoding(img)) / 2
+        elif text is not None:
+            encoding = self.create_text_encoding(text)
+        elif img is not None:
+            encoding = self.create_img_encoding(img)
+        return encoding
+
+    def create_text_encoding(self, text):
+        tokenized_text = tokenize(text).cuda()
+        with torch.no_grad():
+            text_encoding = perceptor.encode_text(tokenized_text).detach()
+        return text_encoding
     
-    def encode_multiple_phrases(self, text, text_type="max"):
-        if len(text) > 0 and "\\" in text:
-            self.encoded_texts[text_type] = [self.encode_one_phrase(prompt_min) for prompt_min in text.split("\\")]
+    def create_img_encoding(self, img):
+        if isinstance(img, str):
+            img = Image.open(img)
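+        # resize, center-crop and CLIP-normalize the image, then add a batch dimension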
+        normed_img = self.clip_transform(img).unsqueeze(0).cuda()
+        with torch.no_grad():
+            img_encoding = perceptor.encode_image(normed_img).detach()
+        return img_encoding
+
+    def encode_multiple_phrases(self, text, img=None, encoding=None, text_type="max"):
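+        # a single prompt may contain several phrases separated by "\"; encode each separately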
+        if text is not None and "\\" in text:
+            self.encoded_texts[text_type] = [self.create_clip_encoding(text=phrase, img=img, encoding=encoding) for phrase in text.split("\\")]
         else:
-            self.encoded_texts[text_type] = [self.encode_one_phrase(text)]
+            self.encoded_texts[text_type] = [self.create_clip_encoding(text=text, img=img, encoding=encoding)]
 
-    def encode_max_and_min(self, text, text_min=""):
-        self.encode_multiple_phrases(text)
-        self.encode_multiple_phrases(text_min, "min")
+    def encode_max_and_min(self, text, img=None, encoding=None, text_min=""):
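+        # "max" phrases are rewarded for similarity to the image, "min" phrases penalized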
+        self.encode_multiple_phrases(text, img=img, encoding=encoding)
+        if text_min is not None and text_min != "":
+            self.encode_multiple_phrases(text_min, img=img, encoding=encoding, text_type="min")
 
-    def set_text(self, text, text_min=""):
+    def set_clip_encoding(self, text=None, img=None, encoding=None, text_min=""):
         self.text = text
         self.text_min = text_min
-        textpath = text[:255]
+
         if len(text_min) > 0:
-            textpath = textpath + "_wout_" + text_min[:255]
-        textpath = underscorify(textpath)
+            text = (text + "_wout_" + text_min[:255]) if text is not None else ("wout_" + text_min[:255])
+        text_path = create_text_path(text=text, img=img, encoding=encoding)
         if self.save_date_time:
-            textpath = datetime.now().strftime("%y%m%d-%H%M%S-") + textpath
-
-        self.textpath = textpath
-        self.filename = Path(f'./{textpath}.png')
-        self.encode_max_and_min(text, text_min) # Tokenize and encode each prompt
+            text_path = datetime.now().strftime("%y%m%d-%H%M%S-") + text_path
 
+        self.text_path = text_path
+        self.filename = Path(f'./{text_path}.png')
+        self.encode_max_and_min(self.text, img=img, encoding=encoding, text_min=text_min) # tokenize and encode each prompt; pass the raw prompt, not the "_wout_" path string
 
     def reset(self):
         self.model.reset()
@@ -353,7 +433,7 @@ def train_step(self, epoch, i, pbar=None):
         total_loss = 0
 
         for _ in range(self.gradient_accumulate_every):
-            losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
+            out, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
             loss = sum(losses) / self.gradient_accumulate_every
             total_loss += loss
             loss.backward()
@@ -365,8 +445,8 @@ def train_step(self, epoch, i, pbar=None):
         if (i + 1) % self.save_every == 0:
             with torch.no_grad():
                 self.model.model.latents.eval()
-                losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
-                top_score, best = torch.topk(losses[2], k = 1, largest = False)
+                out, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
+                top_score, best = torch.topk(losses[2], k=1, largest=False)
                 image = self.model.model()[best].cpu()
                 self.model.model.latents.train()
 
@@ -379,20 +459,22 @@ def train_step(self, epoch, i, pbar=None):
                 if self.save_progress:
                     total_iterations = epoch * self.iterations + i
                     num = total_iterations // self.save_every
-                    save_image(image, Path(f'./{self.textpath}.{num}.png'))
+                    save_image(image, Path(f'./{self.text_path}.{num}.png'))
 
                 if self.save_best and top_score.item() < self.current_best_score:
                     self.current_best_score = top_score.item()
-                    save_image(image, Path(f'./{self.textpath}.best.png'))
+                    save_image(image, Path(f'./{self.text_path}.best.png'))
 
-        return total_loss
+        return out, total_loss
 
     def forward(self):
         penalizing = ""
         if len(self.text_min) > 0:
             penalizing = f'penalizing "{self.text_min}"'
-        print(f'Imagining "{self.text}" {penalizing}...')
-        self.model(self.encoded_texts["max"][0]) # one warmup step due to issue with CLIP and CUDA
+        print(f'Imagining "{self.text_path}" {penalizing}...')
+
+        with torch.no_grad():
+            self.model(self.encoded_texts["max"][0]) # one warmup step due to issue with CLIP and CUDA
 
         if self.open_folder:
             open_folder('./')
@@ -403,7 +485,7 @@ def forward(self):
             pbar = trange(self.iterations, desc='   iteration', position=1, leave=True)
             image_pbar.update(0)
             for i in pbar:
-                loss = self.train_step(epoch, i, image_pbar)
+                out, loss = self.train_step(epoch, i, image_pbar)
                 pbar.set_description(f'loss: {loss.item():04.2f}')
 
                 if terminate:
diff --git a/big_sleep/cli.py b/big_sleep/cli.py
index 177d08d..b4dc920 100644
--- a/big_sleep/cli.py
+++ b/big_sleep/cli.py
@@ -4,8 +4,10 @@
 from pathlib import Path
 from .version import __version__;
 
+
 def train(
-    text,
+    text=None,
+    img=None,
     text_min="",
     lr = .07,
     image_size = 512,
@@ -25,7 +27,9 @@ def train(
     class_temperature = 2.,
     save_best = False,
     experimental_resample = False,
-    ema_decay = 0.5
+    ema_decay = 0.5,
+    num_cutouts = 128,
+    center_bias = False,
 ):
     print(f'Starting up... v{__version__}')
 
@@ -33,7 +37,8 @@ def train(
         seed = rnd.randint(0, 1e6)
 
     imagine = Imagine(
-        text,
+        text=text,
+        img=img,
         text_min=text_min,
         lr = lr,
         image_size = image_size,
@@ -51,7 +56,9 @@ def train(
         save_date_time = save_date_time,
         save_best = save_best,
         experimental_resample = experimental_resample,
-        ema_decay = ema_decay
+        ema_decay = ema_decay,
+        num_cutouts = num_cutouts,
+        center_bias = center_bias,
     )
 
     if not overwrite and imagine.filename.exists():