
Commit

feat: add conditioning on multiple classes
Bycob committed Mar 8, 2024
1 parent ab18fe1 commit d91a3c6
Showing 6 changed files with 57 additions and 20 deletions.
6 changes: 1 addition & 5 deletions models/modules/diffusion_generator.py
@@ -63,11 +63,7 @@ def __init__(
self.cond_embed_gammas_in = inner_channel
else:
self.cond_embed_dim = cond_embed_dim

if any(cond in self.denoise_fn.conditioning for cond in ["class", "ref"]):
self.cond_embed_gammas = self.cond_embed_dim // 2
else:
self.cond_embed_gammas = self.cond_embed_dim
self.cond_embed_gammas = self.denoise_fn.cond_embed_gammas

self.cond_embed = nn.Sequential(
nn.Linear(self.cond_embed_gammas, self.cond_embed_gammas),
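
This hunk means the generator no longer halves cond_embed_dim itself when "class" or "ref" conditioning is active; it simply consumes whatever width the denoise function left over for the noise-level (gammas) embedding. A minimal sketch of that hand-off, using a stand-in for PaletteDenoiseFn and hypothetical sizes (cond_embed_dim=256, two class heads):

import torch.nn as nn

class FakeDenoiseFn:
    # Stand-in only: mirrors the "// (len(nclasses) + 1)" split done in palette_denoise_fn.py.
    def __init__(self, cond_embed_dim, n_class_heads):
        slice_size = cond_embed_dim // (n_class_heads + 1)                    # 85 for 256 and 2 heads
        self.cond_embed_gammas = cond_embed_dim - n_class_heads * slice_size  # 86 remain

denoise_fn = FakeDenoiseFn(cond_embed_dim=256, n_class_heads=2)
cond_embed_gammas = denoise_fn.cond_embed_gammas
cond_embed = nn.Sequential(
    nn.Linear(cond_embed_gammas, cond_embed_gammas),
    # the remainder of the real Sequential is cut off in this diff
)
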
38 changes: 31 additions & 7 deletions models/modules/palette_denoise_fn.py
@@ -42,22 +42,38 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses)
self.conditioning = conditioning
self.cond_embed_dim = cond_embed_dim
self.ref_embed_net = ref_embed_net
self.cond_embed_gammas = cond_embed_dim

# Label embedding
if "class" in conditioning:
cond_embed_class = cond_embed_dim // 2
self.netl_embedder_class = LabelEmbedder(
nclasses,
cond_embed_class, # * image_size * image_size
)
nn.init.normal_(self.netl_embedder_class.embedding_table.weight, std=0.02)
if type(nclasses) == list:
# TODO this is arbitrary, half for class & half for detector
cond_embed_class = cond_embed_dim // (len(nclasses) + 1)
self.netl_embedders_class = nn.ModuleList(
[LabelEmbedder(nc, cond_embed_class) for nc in nclasses]
)
for embed in self.netl_embedders_class:
self.cond_embed_gammas -= cond_embed_class
nn.init.normal_(embed.embedding_table.weight, std=0.02)
else:
# TODO this can be included in the general case
cond_embed_class = cond_embed_dim // 2
self.netl_embedder_class = LabelEmbedder(
nclasses,
cond_embed_class, # * image_size * image_size
)
self.cond_embed_gammas -= cond_embed_class
nn.init.normal_(
self.netl_embedder_class.embedding_table.weight, std=0.02
)

if "mask" in conditioning:
cond_embed_mask = cond_embed_dim
self.netl_embedder_mask = LabelEmbedder(
nclasses,
cond_embed_mask, # * image_size * image_size
)
self.cond_embed_gammas -= cond_embed_class
nn.init.normal_(self.netl_embedder_mask.embedding_table.weight, std=0.02)

# Instantiate model
@@ -90,6 +106,7 @@ def __init__(self, model, cond_embed_dim, ref_embed_net, conditioning, nclasses)
self.emb_layers = nn.Sequential(
torch.nn.SiLU(), nn.Linear(ref_embed_dim, cond_embed_class)
)
self.cond_embed_gammas -= cond_embed_class

def forward(self, input, embed_noise_level, cls, mask, ref):
cls_embed, mask_embed, ref_embed = self.compute_cond(input, cls, mask, ref)
@@ -114,7 +131,14 @@ def forward(self, input, embed_noise_level, cls, mask, ref):

def compute_cond(self, input, cls, mask, ref):
if "class" in self.conditioning and cls is not None:
cls_embed = self.netl_embedder_class(cls)
if hasattr(self, "netl_embedders_class"):
cls_embed = []
for i in range(len(self.netl_embedders_class)):
cls_embed.append(self.netl_embedders_class[i](cls[:, i]))
cls_embed = torch.cat(cls_embed, dim=1)
else:
# TODO general case
cls_embed = self.netl_embedder_class(cls)
else:
cls_embed = None

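
Taken together, the __init__ and compute_cond changes carve cond_embed_dim into one slice per class vocabulary (plus a remainder kept for the noise-level embedding) and concatenate the per-head label embeddings along the feature dimension. A rough, self-contained walk-through with hypothetical sizes; nn.Embedding stands in for the project's LabelEmbedder:

import torch
import torch.nn as nn

cond_embed_dim = 256
nclasses = [10, 5]                                         # e.g. semantic classes + detector classes
cond_embed_class = cond_embed_dim // (len(nclasses) + 1)   # 85 per head
embedders = nn.ModuleList(
    [nn.Embedding(nc, cond_embed_class) for nc in nclasses]
)

cls = torch.tensor([[3, 1], [7, 0]])                       # (batch, one label column per head)
cls_embed = torch.cat(
    [emb(cls[:, i]) for i, emb in enumerate(embedders)], dim=1
)
print(cls_embed.shape)   # torch.Size([2, 170]); the remaining 86 dims are left for the gammas embedding
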
18 changes: 14 additions & 4 deletions models/palette_model.py
@@ -100,9 +100,11 @@ def __init__(self, opt, rank):

max_visual_outputs = max(self.opt.train_batch_size, self.opt.num_test_images)

self.num_classes = max(
self.opt.f_s_semantic_nclasses, self.opt.cls_semantic_nclasses
)
# self.num_classes = max(
# self.opt.f_s_semantic_nclasses, self.opt.cls_semantic_nclasses
# )
# TODO decide if we keep cls_semantic_nclasses (not used atm)
self.num_classes = self.opt.f_s_semantic_nclasses

self.use_ref = (
self.opt.alg_diffusion_cond_image_creation == "ref"
@@ -583,10 +585,18 @@ def inference(self, nb_imgs, offset=0):

# task: super resolution, pix2pix
elif self.task in ["super_resolution", "pix2pix"]:
cls = None

if "class" in self.opt.alg_diffusion_cond_embed:
cls = []
for i in self.num_classes:
cls.append(torch.randint_like(self.cls[:, 0], 0, i))
cls = torch.stack(cls, dim=1)

self.output, self.visuals = netG.restoration(
y_cond=self.cond_image[:nb_imgs],
sample_num=self.sample_num,
cls=None,
cls=cls,
)
self.fake_B = self.output

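
For super_resolution and pix2pix inference the model now draws a random label per class head (instead of always passing cls=None) whenever "class" is part of the conditioning. A small sketch of that sampling with hypothetical shapes; dummy_cls plays the role of self.cls:

import torch

num_classes = [10, 5]                                            # plays the role of self.num_classes
dummy_cls = torch.zeros(4, len(num_classes), dtype=torch.int64)  # plays the role of self.cls

cls = torch.stack(
    [torch.randint_like(dummy_cls[:, 0], 0, n) for n in num_classes], dim=1
)
print(cls.shape)   # torch.Size([4, 2]): one random label per head for each of the 4 images
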
5 changes: 3 additions & 2 deletions options/common_options.py
@@ -509,9 +509,10 @@ def initialize(self, parser):
)
parser.add_argument(
"--f_s_semantic_nclasses",
default=2,
default=[2],
nargs="+",
type=int,
help="number of classes of the semantic loss classifier",
help="number of classes of the semantic loss classifiers",
)
parser.add_argument(
"--f_s_class_weights",
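
With nargs="+" the option now takes one class count per semantic head, while a bare invocation keeps the old single-head default of two classes. A standalone argparse check (not the project's option loader) of how it parses:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--f_s_semantic_nclasses",
    default=[2],
    nargs="+",
    type=int,
)

print(parser.parse_args([]).f_s_semantic_nclasses)                                      # [2]
print(parser.parse_args("--f_s_semantic_nclasses 10 5".split()).f_s_semantic_nclasses)  # [10, 5]
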
3 changes: 2 additions & 1 deletion options/inference_diffusion_options.py
@@ -100,7 +100,8 @@ def initialize(self, parser):
parser.add_argument(
"--cls",
type=int,
default=-1,
nargs="+",
default=[-1],
help="override input bbox classe for generation",
)

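
--cls follows the same pattern: one override value per class head (e.g. --cls 3 1), with [-1] kept as the single-element default. A standalone argparse sketch:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cls", type=int, nargs="+", default=[-1])

print(parser.parse_args("--cls 3 1".split()).cls)   # [3, 1]
print(parser.parse_args([]).cls)                    # [-1], the unchanged default
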
7 changes: 6 additions & 1 deletion scripts/gen_single_image_diffusion.py
@@ -606,7 +606,12 @@ def generate(

if opt.model_type == "palette":
if "class" in model.denoise_fn.conditioning:
cls_tensor = torch.ones(1, dtype=torch.int64, device=device) * cls
if len(cls_value) > 1:
cls_tensor = torch.tensor(
cls_value, dtype=torch.int64, device=device
).unsqueeze(0)
else:
cls_tensor = torch.ones(1, dtype=torch.int64, device=device) * cls_value
else:
cls_tensor = None
if ref is not None:
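
The script therefore builds cls_tensor in two shapes: a (1, n_heads) tensor when several class values are supplied, or the previous scalar-style tensor when only one is. A minimal reproduction with hypothetical values (in the real script cls_value comes from the surrounding code):

import torch

device = "cpu"

cls_value = [3, 1]   # multi-head case, e.g. from --cls 3 1
multi = torch.tensor(cls_value, dtype=torch.int64, device=device).unsqueeze(0)
print(multi.shape)   # torch.Size([1, 2])

cls_value = [3]      # single-head case
single = torch.ones(1, dtype=torch.int64, device=device) * cls_value[0]
print(single.shape)  # torch.Size([1])
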
