Finish adding fusion stuff
bdvllrs committed Feb 26, 2024
1 parent 8132fca commit 4a5ea2f
Showing 3 changed files with 90 additions and 28 deletions.
18 changes: 18 additions & 0 deletions shimmer/modules/domain.py
@@ -75,6 +75,9 @@ def on_before_gw_encode_tr(self, x: torch.Tensor) -> torch.Tensor:
     def on_before_gw_encode_cy(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
+    def on_before_gw_encode_broadcast(self, x: torch.Tensor) -> torch.Tensor:
+        return x
+
     def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
         """
         Computes the loss of the modality. If you implement compute_dcy_loss,
@@ -127,3 +130,18 @@ def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
         used for training. Any other key will be logged, but not trained on.
         """
         return self.compute_loss(pred, target)
+
+    def compute_broadcast_loss(
+        self, pred: torch.Tensor, target: torch.Tensor
+    ) -> LossOutput:
+        """
+        Computes the loss for a broadcast (fusion). Override if the broadcast
+        loss is different from the generic loss.
+        Args:
+            pred: tensor with a predicted latent unimodal representation
+            target: target tensor
+        Returns:
+            Dict of losses. Must contain the "loss" key with the total loss
+            used for training. Any other key will be logged, but not trained on.
+        """
+        return self.compute_loss(pred, target)
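
To show how the new hook is meant to be used, here is a minimal sketch of a domain module overriding compute_broadcast_loss. It assumes the enclosing class in domain.py is DomainModule and that LossOutput takes loss and metrics fields, as their uses in this diff suggest; the class name and the MSE criterion are illustrative only.

import torch
import torch.nn.functional as F

from shimmer.modules.domain import DomainModule, LossOutput


class ImageDomain(DomainModule):  # hypothetical subclass, for illustration
    def compute_broadcast_loss(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> LossOutput:
        # Per-sample MSE between the decoded fused latent and the original
        # unimodal latent. Left unreduced on the batch axis because the new
        # broadcast_loss in losses.py calls loss_output.loss.mean() itself.
        loss = F.mse_loss(pred, target, reduction="none").mean(dim=1)
        return LossOutput(loss=loss, metrics={"mse": loss.mean()})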
2 changes: 1 addition & 1 deletion shimmer/modules/global_workspace.py
@@ -195,7 +195,7 @@ def generic_step(
         domain_latents = self.encode_domains(batch)
         batch_size = self._get_batch_size(domain_latents)
 
-        loss_output = self.loss_mod.step(domain_latents)
+        loss_output = self.loss_mod.step(domain_latents, mode)
 
         for name, metric in loss_output.all.items():
             self.log(
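`mode` is already in scope inside generic_step (its signature is collapsed in the hunk header), so this change only forwards it to the loss module. A sketch of the assumed call path from the Lightning stage hooks, which are not shown in this diff:

# Assumed wiring (standard Lightning hook names; illustrative only): each
# stage hook passes its name down so the fusion loss can branch on it.
def training_step(self, batch, batch_idx):
    return self.generic_step(batch, mode="train")

def validation_step(self, batch, batch_idx):
    return self.generic_step(batch, mode="val")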
98 changes: 71 additions & 27 deletions shimmer/modules/losses.py
Expand Up @@ -22,14 +22,12 @@ class GWLossesBase(torch.nn.Module, ABC):
"""

@abstractmethod
def step(
self,
domain_latents: LatentsT,
) -> LossOutput:
def step(self, domain_latents: LatentsT, mode: str) -> LossOutput:
"""
Computes the losses
Args:
domain_latents: All latent groups
mode: train/val/test
Returns: LossOutput object
"""
...
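
The non-fusion loss classes updated in the next two hunks satisfy this widened signature without consuming the stage: they accept it as a throwaway positional `_`. Schematically (a sketch, not part of the diff):

class PlainGWLosses(GWLossesBase):  # hypothetical name
    def step(self, domain_latents: LatentsT, _) -> LossOutput:
        # Losses are computed exactly as before; the extra argument only
        # matches the abstract signature and is ignored.
        ...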
@@ -249,8 +247,7 @@ def contrastive_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]:
         return _contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
 
     def step(
-        self,
-        domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
+        self, domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]], _
     ) -> LossOutput:
         metrics: dict[str, torch.Tensor] = {}
 
@@ -328,8 +325,7 @@ def kl_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]:
         return losses
 
     def step(
-        self,
-        domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
+        self, domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]], _
     ) -> LossOutput:
         metrics: dict[str, torch.Tensor] = {}
 
@@ -431,50 +427,98 @@ def translation_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]:
     def contrastive_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]:
         return _contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
 
-    def broadcast_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]:
+    def broadcast_loss(
+        self, latent_domains: LatentsT, mode: str
+    ) -> dict[str, torch.Tensor]:
         losses: dict[str, torch.Tensor] = {}
         metrics: dict[str, torch.Tensor] = {}
-        keys: list[set[str]] = []
 
         for latents in latent_domains.values():
             if len(latents) < 2:
                 continue
-            for domain1_name, domain1 in latents.items():
-                z1 = gw_mod.encode(
-                    gw_mod.on_before_gw_encode_cont({domain1_name: domain1})
-                )
-                for domain2_name, domain2 in latents.items():
-                    selected_domains = {domain1_name, domain2_name}
-                    if domain1_name == domain2_name or selected_domains in keys:
-                        continue
+            batch_size = latents[next(iter(latents))].size(0)
+            device = latents[next(iter(latents))].device
+
+            if mode == "val":
+                scaling_factors = sample_scaling_factors(0.5, batch_size, 5.0, device)
+            else:
+                scaling_factors = sample_scaling_factors(0.0, batch_size, 5.0, device)
+
+            for scale_type, (
+                scaling_factor_1,
+                scaling_factor_2,
+                indices,
+            ) in scaling_factors.items():
+                scaled_latents = {}
+
+                for i, (domain_name, latent) in enumerate(latents.items()):
+                    scaling_factor = scaling_factor_1 if i == 0 else scaling_factor_2
+                    scaled_latents_subset = latent[indices] * scaling_factor.unsqueeze(
+                        -1
+                    )
+                    scaled_latents_subset = scaled_latents_subset.to(latent)
 
-                    keys.append(selected_domains)
+                    scaled_latents[domain_name] = scaled_latents_subset
 
-                    loss_name = f"contrastive_{domain1_name}_and_{domain2_name}"
-                    z2 = gw_mod.encode(
-                        gw_mod.on_before_gw_encode_cont({domain2_name: domain2})
-                    )
+                encoded_latents_for_subset = self.gw_mod.encode(scaled_latents)
+                encoded_latents_for_subset = torch.tanh(encoded_latents_for_subset)
+                decoded_latents_for_subset = self.gw_mod.decode(
+                    encoded_latents_for_subset
+                )
+
+                for domain_name, latent in latents.items():
+                    domain_mod = self.domain_mods[domain_name]
+                    decoded_latent_for_domain_subset = decoded_latents_for_subset[
+                        domain_name
+                    ]
+                    original_latent_for_domain_subset = latents[domain_name][indices]
+                    loss_output = domain_mod.compute_broadcast_loss(
+                        decoded_latent_for_domain_subset,
+                        original_latent_for_domain_subset,
+                    )
-                    loss_output = contrastive_fn(z1, z2)
-                    losses[loss_name] = loss_output.loss
+                    loss_key = f"{domain_name}_loss_{scale_type}"
 
                     metrics.update(
-                        {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
+                        {
+                            f"broadcast_{loss_key}_{k}": v
+                            for k, v in loss_output.metrics.items()
+                        }
                     )
+                    losses[loss_key] = loss_output.loss.mean()
+
+            binary_count = scaling_factors["binary"][2].size(0)
+            softmax_count = scaling_factors["softmax"][2].size(0)
+            total_count = binary_count + softmax_count
+
+            for domain_name, latent in latents.items():
+                full_loss_key = f"{domain_name}_full_loss"
 
-        losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean()
+                binary_loss_key = f"{domain_name}_loss_binary"
+                softmax_loss_key = f"{domain_name}_loss_softmax"
+
+                binary_loss = losses[binary_loss_key] * (binary_count / total_count)
+                softmax_loss = losses[softmax_loss_key] * (softmax_count / total_count)
+
+                losses[full_loss_key] = binary_loss + softmax_loss
+
+        losses["broadcast"] = torch.stack(
+            [loss for name, loss in losses.items() if "full_loss" in name], dim=0
+        ).mean()
         losses.update(metrics)
         return losses

     def step(
         self,
         domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
+        mode: str,
     ) -> LossOutput:
         metrics: dict[str, torch.Tensor] = {}
 
         metrics.update(self.demi_cycle_loss(domain_latents))
         metrics.update(self.cycle_loss(domain_latents))
         metrics.update(self.translation_loss(domain_latents))
         metrics.update(self.contrastive_loss(domain_latents))
-        metrics.update(self.broadcast_loss(domain_latents))
+        metrics.update(self.broadcast_loss(domain_latents, mode))
 
         loss = metrics["broadcast"]

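Read end to end, the new broadcast path per domain group is: sample per-sample scaling factors (judging by the 0.0/0.5 first argument, all softmax-weighted in training and half hard-binary at validation), rescale each domain's latents over the sampled sub-batch, fuse them through the global workspace with a tanh squash, decode back to every domain, and score each reconstruction with compute_broadcast_loss; the two scale_type losses are then recombined in proportion to their sub-batch sizes. A shape-level sketch of one scale_type pass, with toy sizes and an assumed complementary pair of factors:

import torch

# Toy sizes (assumed): two domains, batch of 8, latent width 12; `indices`
# is the 4-sample sub-batch covered by one scale_type entry.
latents = {"v": torch.randn(8, 12), "t": torch.randn(8, 12)}
indices = torch.tensor([0, 2, 5, 7])
w = torch.rand(4)  # scaling_factor_1; scaling_factor_2 taken as 1 - w here

scaled_latents = {
    "v": latents["v"][indices] * w.unsqueeze(-1),        # (4, 12)
    "t": latents["t"][indices] * (1 - w).unsqueeze(-1),  # (4, 12)
}
# In the diff, scaled_latents is then encoded by gw_mod (fusing the domains
# into one workspace state), squashed with torch.tanh, decoded back to both
# domains, and each decoded latent is compared against
# latents[domain][indices] via compute_broadcast_loss.

# The final per-domain loss weights the binary and softmax passes by their
# sub-batch sizes, e.g. 32 binary and 96 softmax samples give weights
# 0.25 and 0.75: full = 0.25 * loss_binary + 0.75 * loss_softmax.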
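In sum, the stage string now threads from generic_step through GWLossesBase.step down to broadcast_loss. A hypothetical call against the fusion loss module (the loss_mod and domain_latents names are taken from the diff):

# Only the fusion losses branch on the stage; the other loss classes
# accept it positionally and ignore it.
train_metrics = loss_mod.step(domain_latents, "train")  # softmax factors only
val_metrics = loss_mod.step(domain_latents, "val")      # adds hard binary factors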
