diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 852bc1cc..6ee666b2 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -278,8 +278,7 @@ def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT: """ return { domains: { - name: self.domain_mods[name].encode(domain) - for name, domain in data.items() + name: self.encode_domain(domain, name) for name, domain in data.items() } for domains, data in batch.items() } @@ -317,7 +316,7 @@ def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroup """ return { domains: { - name: self.domain_mods[name].decode(domain) + name: self.decode_domain(domain, name) for name, domain in latents.items() } for domains, latents in latents_domain.items()