fix an issue with karras unet when no class labels, and make sure it is compatible with GaussianDiffusion wrapper
lucidrains committed Feb 6, 2024
1 parent 3b78964 commit f363854
Showing 3 changed files with 7 additions and 4 deletions.
denoising_diffusion_pytorch/denoising_diffusion_pytorch.py (2 changes: 1 addition & 1 deletion)
@@ -489,7 +489,7 @@ def __init__(
     ):
         super().__init__()
         assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
-        assert not model.random_or_learned_sinusoidal_cond
+        assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond
 
         self.model = model
 
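The old assertion assumed every wrapped model exposes a random_or_learned_sinusoidal_cond attribute, which the Karras U-Net does not define, so wrapping it raised an AttributeError. A minimal sketch of the guarded check, using a hypothetical stand-in model that lacks the attribute:

class ModelWithoutSinusoidalFlag:
    # stand-in for a model that, like the Karras U-Net, does not define
    # `random_or_learned_sinusoidal_cond`
    channels = 3
    out_dim = 3

model = ModelWithoutSinusoidalFlag()

# old form: `model.random_or_learned_sinusoidal_cond` raises AttributeError here
# new form: passes when the attribute is absent, and still rejects a model
# that defines the flag and has it set to True
assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond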
denoising_diffusion_pytorch/karras_unet.py (7 changes: 5 additions & 2 deletions)
@@ -456,8 +456,9 @@ def __init__(
         self.needs_class_labels = exists(num_classes)
         self.num_classes = num_classes
 
-        self.to_class_emb = Linear(num_classes, 4 * dim)
-        self.add_class_emb = MPAdd(t = mp_add_emb_t)
+        if self.needs_class_labels:
+            self.to_class_emb = Linear(num_classes, 4 * dim)
+            self.add_class_emb = MPAdd(t = mp_add_emb_t)
 
         # final embedding activations
 
@@ -537,6 +538,8 @@ def __init__(
             Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
         ])
 
+        self.out_dim = channels
+
     @property
     def downsample_factor(self):
         return 2 ** self.num_downsamples
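Together with the new out_dim attribute (which the wrapper's channels/out_dim assertion above reads), these changes let an unconditional Karras U-Net be dropped into the GaussianDiffusion wrapper. A rough usage sketch; the specific constructor arguments used here (image_size, dim, channels, timesteps) are illustrative assumptions and may need adjusting for this version of the library:

import torch
from denoising_diffusion_pytorch import GaussianDiffusion
from denoising_diffusion_pytorch.karras_unet import KarrasUnet

# Karras U-Net with no class labels: the class-embedding layers are now skipped entirely
unet = KarrasUnet(
    image_size = 64,
    dim = 64,
    channels = 3,
    num_classes = None
)

# wrapping now works: out_dim matches channels, and the sinusoidal-cond
# assertion is skipped because the attribute is absent
diffusion = GaussianDiffusion(
    unet,
    image_size = 64,
    timesteps = 1000
)

images = torch.rand(2, 3, 64, 64)  # dummy batch of images in [0, 1]
loss = diffusion(images)
loss.backward()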
denoising_diffusion_pytorch/version.py (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-__version__ = '1.10.1'
+__version__ = '1.10.2'
