Fix broadcast test (#55)
* took out unnecessary contrastive function

* added asserts
RolandBERTINJOHANNET authored Apr 12, 2024
1 parent 93c4f1b commit 5143a54
Showing 1 changed file with 16 additions and 51 deletions.
67 changes: 16 additions & 51 deletions tests/test_broadcast.py
@@ -1,46 +1,11 @@
import torch
from torch import nn
from torch.nn.functional import cross_entropy, normalize

from shimmer.modules.contrastive_loss import ContrastiveLossType
from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.global_workspace import GlobalWorkspaceFusion
from shimmer.modules.losses import BroadcastLossCoefs


def contrastive_loss(x: torch.Tensor, y: torch.Tensor) -> LossOutput:
"""
Simplified CLIP-like contrastive loss that matches the expected signature.
Args:
x (torch.Tensor): Predictions.
y (torch.Tensor): Targets.
Returns:
LossOutput: A dataclass containing the computed loss and optionally
additional metrics.
"""
# Assuming logit_scale is a pre-defined tensor if needed for the calculation
# For the sake of matching the function signature,
# we'll remove it from the parameters
# Similarly, we assume a fixed reduction mode for simplicity
logit_scale = torch.tensor(
1.0
) # Placeholder for an actual logit scale if necessary
reduction = "mean" # Fixed reduction mode

xn = normalize(x, dim=-1)
yn = normalize(y, dim=-1)
logits = torch.matmul(xn, yn.t())
labels = torch.arange(xn.size(0), device=xn.device)
ce_loss = 0.5 * (
cross_entropy(logits * logit_scale.exp(), labels, reduction=reduction)
+ cross_entropy(logits.t() * logit_scale.exp(), labels, reduction=reduction)
)

return LossOutput(loss=ce_loss)


class DummyDomainModule(DomainModule):
def __init__(self, latent_dim: int):
super().__init__(latent_dim)
@@ -58,18 +23,14 @@ def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
return LossOutput(loss=loss) # Constructing LossOutput with the loss


def setup_global_workspace_fusion() -> GlobalWorkspaceFusion:
"""
Setting up the test environment for GlobalWorkspaceFusion
"""
def test_broadcast_loss():
domain_mods: dict[str, DomainModule] = {
"domain1": DummyDomainModule(latent_dim=10),
"domain2": DummyDomainModule(latent_dim=10),
}
gw_encoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)}
gw_decoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)}
workspace_dim = 10
contrastive_fn: ContrastiveLossType = contrastive_loss
loss_coefs: BroadcastLossCoefs = {"broadcast": 1.0, "contrastives": 0.1}

gw_fusion = GlobalWorkspaceFusion(
@@ -83,15 +44,8 @@ def setup_global_workspace_fusion() -> GlobalWorkspaceFusion:
optim_weight_decay=0.0,
scheduler_args=None, # Simplified for testing
learn_logit_scale=False,
contrastive_loss=contrastive_fn,
)

return gw_fusion


def test_broadcast_loss():
gw_fusion = setup_global_workspace_fusion()

# Adjusting the dummy data to fit the expected input structure for broadcast_loss
# Now using a frozenset for the keys to match LatentsDomainGroupsT
latent_domains = {
@@ -103,8 +57,19 @@ def test_broadcast_loss():

# Test broadcast_loss with the corrected structure
output = gw_fusion.loss_mod.broadcast_loss(latent_domains, "train")
print(output)


# Call the test function to execute the test
test_broadcast_loss()
# Ensure the total broadcast loss is returned and is a single value
assert "broadcast" in output
assert output["broadcast"].dim() == 0, "broadcast loss should be a single value."

er_msg = "Demi-cycle, cycle, and translation metrics should be in the output."
assert all(
metric in output for metric in ["demi_cycles", "cycles", "translations"]
), er_msg

er_msg = "Losses should be a 1D tensor with size equal to the batch size."
assert all(
loss.dim() == 1 and loss.size(0) == 5
for key, loss in output.items()
if key.endswith("_loss")
), er_msg
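
For reference, here is a minimal sketch of the dictionary shape these asserts expect from `broadcast_loss`. The keys `broadcast`, `demi_cycles`, `cycles`, and `translations` and the batch size of 5 come from the test above; the `*_loss` key names and the concrete values are hypothetical placeholders, not shimmer's actual output.

import torch

# Hypothetical stand-in for the broadcast_loss output; only the keys checked
# by the test are taken from the diff, the rest is illustrative.
output = {
    "broadcast": torch.tensor(0.4),        # scalar total broadcast loss
    "demi_cycles": torch.tensor(0.1),
    "cycles": torch.tensor(0.2),
    "translations": torch.tensor(0.3),
    "demi_cycles_loss": torch.rand(5),     # per-sample losses, batch size 5
    "cycles_loss": torch.rand(5),
    "translations_loss": torch.rand(5),
}

# The same three checks performed by the test pass on this structure.
assert output["broadcast"].dim() == 0
assert all(m in output for m in ["demi_cycles", "cycles", "translations"])
assert all(
    v.dim() == 1 and v.size(0) == 5 for k, v in output.items() if k.endswith("_loss")
)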

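With the module-level call to `test_broadcast_loss()` removed, the test now only runs through a test collector. A minimal sketch, assuming pytest is the project's test runner (the commit does not say), of invoking just this test programmatically:

import pytest

# Run only the updated broadcast test; returns a non-zero exit code if it fails.
pytest.main(["tests/test_broadcast.py::test_broadcast_loss", "-q"])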