Skip to content

Commit

Permalink
make sure classifier free guidance condition scaling is exposed on DA…
Browse files Browse the repository at this point in the history
…LLE2 forward function
  • Loading branch information
lucidrains committed Apr 14, 2022
1 parent 4c827ba commit 7e93b9d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 20 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,10 @@ dalle2 = DALLE2(
decoder = decoder
)

images = dalle2(['cute puppy chasing after a squirrel'])
images = dalle2(
['cute puppy chasing after a squirrel'],
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)

# save your image
```
Expand Down
39 changes: 21 additions & 18 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,16 @@ def __init__(
def forward_with_cond_scale(
self,
x,
*,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(x, *args, **kwargs)

if cond_scale == 1:
return self.forward(x, **kwargs)
return logits

logits = self.forward(x, **kwargs)
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale

def forward(
Expand Down Expand Up @@ -635,15 +636,16 @@ def __init__(
def forward_with_cond_scale(
self,
x,
*,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(x, *args, **kwargs)

if cond_scale == 1:
return self.forward(x, **kwargs)
return logits

logits = self.forward(x, **kwargs)
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale

def forward(
Expand Down Expand Up @@ -774,8 +776,8 @@ def q_posterior(self, x_start, x_t, t):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped

def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))

if clip_denoised:
x_recon.clamp_(-1., 1.)
Expand All @@ -784,31 +786,31 @@ def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
return model_mean, posterior_variance, posterior_log_variance

@torch.no_grad()
def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

@torch.no_grad()
def p_sample_loop(self, shape, image_embed):
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
device = self.betas.device

b = shape[0]
img = torch.randn(shape, device=device)

for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
return img

@torch.no_grad()
def sample(self, image_embed):
def sample(self, image_embed, cond_scale = 1.):
batch_size = image_embed.shape[0]
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)

def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
Expand Down Expand Up @@ -869,7 +871,8 @@ def __init__(
@torch.no_grad()
def forward(
self,
text
text,
cond_scale = 1.
):
device = next(self.parameters()).device

Expand All @@ -878,5 +881,5 @@ def forward(
text = tokenizer.tokenize(text).to(device)

image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed)
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
return images
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
Expand Down

0 comments on commit 7e93b9d

Please sign in to comment.