Skip to content

Commit

Permalink
Add coeffs broadcast (#57)
Browse files Browse the repository at this point in the history
* adding coeffs (yet to test)

* ruff fixes :/

* fused loss should be called fused loss

* logging the broadcast loss too

* fixed broadcast loss computation

* broadcast sum takes no coeffs

* fix comment -- fused, not broadcast

* resolve test conflict

* FINALLY got a proper test.

* Add missing comma

* fix dumb bug

---------

Co-authored-by: bdvllrs <[email protected]>
  • Loading branch information
RolandBERTINJOHANNET and bdvllrs authored Apr 16, 2024
1 parent e73de91 commit c6f0bfe
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 14 deletions.
35 changes: 30 additions & 5 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,8 +664,17 @@ class BroadcastLossCoefs(TypedDict, total=False):
contrastives: float
"""Contrastive loss coefficient."""

broadcast: float
"""Broadcast loss coefficient."""
fused: float
"""fused loss coefficient (encode multiple domains and decode to one of them)."""

demi_cycles: float
"""demi_cycles loss coefficient. Demi-cycles are always one-to-one"""

cycles: float
"""cycles loss coefficient. Cycles can be many-to-one"""

translations: float
"""translation loss coefficient. Translation, like cycles, can be many-to-one."""


class GWLossesFusion(GWLossesBase):
Expand Down Expand Up @@ -732,6 +741,7 @@ def broadcast_loss(
demi_cycle_losses: list[str] = []
cycle_losses: list[str] = []
translation_losses: list[str] = []
fused_losses: list[str] = []

for group_domains, latents in latent_domains.items():
encoded_latents = self.gw_mod.encode(latents)
Expand Down Expand Up @@ -775,8 +785,10 @@ def broadcast_loss(

if num_active_domains == 1 and domain in selected_latents:
demi_cycle_losses.append(loss_label + "_loss")
if num_active_domains == 1 and domain not in selected_latents:
elif domain not in selected_latents:
translation_losses.append(loss_label + "_loss")
else: # fused loss
fused_losses.append(loss_label + "_loss")

if num_active_domains < num_total_domains:
inverse_selected_latents = {
Expand Down Expand Up @@ -830,9 +842,13 @@ def broadcast_loss(
metrics["translations"] = torch.mean(
torch.stack([losses[loss_name] for loss_name in translation_losses])
)
if fused_losses:
metrics["fused"] = torch.mean(
torch.stack([losses[loss_name] for loss_name in fused_losses])
)

total_loss = torch.mean(torch.stack(list(losses.values())))
return {"broadcast": total_loss, **metrics}
metrics.update(losses)
return metrics

def step(
self,
Expand Down Expand Up @@ -864,4 +880,13 @@ def step(
dim=0,
).mean()

metrics["broadcast_loss"] = torch.stack(
[
metrics[name] * 1.0 # broadcast loss is all-encompassing
for name, coef in self.loss_coefs.items()
if isinstance(coef, float) and coef > 0 and name != "contrastives"
],
dim=0,
).mean()

return LossOutput(loss, metrics)
21 changes: 12 additions & 9 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ def test_broadcast_loss():
gw_encoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)}
gw_decoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)}
workspace_dim = 10
loss_coefs: BroadcastLossCoefs = {"broadcast": 1.0, "contrastives": 0.1}
loss_coefs: BroadcastLossCoefs = {
"fused": 1.0,
"cycles": 1.0,
"demi_cycles": 1.0,
"translations": 1.0,
"contrastives": 0.1,
}

gw_fusion = GlobalWorkspaceFusion(
domain_mods,
Expand All @@ -58,18 +64,15 @@ def test_broadcast_loss():
# Test broadcast_loss with the corrected structure
output = gw_fusion.loss_mod.broadcast_loss(latent_domains, "train")

# Ensure the total broadcast loss is returned and is a single value
assert "broadcast" in output
assert output["broadcast"].dim() == 0, "broadcast loss should be a single value."

er_msg = "Demi-cycle, cycle, and translation metrics should be in the output."
er_msg = "Demi-cycle, cycle, fused and translation metrics should be in the output."
assert all(
metric in output for metric in ["demi_cycles", "cycles", "translations"]
metric in output
for metric in ["demi_cycles", "cycles", "translations", "fused"]
), er_msg

er_msg = "Losses should be a 1D tensor with size equal to the batch size."
er_msg = "Losses should be scalar tensors or 1D tensor with size equal to one."
assert all(
loss.dim() == 1 and loss.size(0) == 5
(loss.dim() == 0 or (loss.dim() == 1 and loss.size(0) == 1))
for key, loss in output.items()
if key.endswith("_loss")
), er_msg

0 comments on commit c6f0bfe

Please sign in to comment.