Skip to content

Commit

Permalink
encode_domains and decode_domains now use encode_domain and `de…
Browse files Browse the repository at this point in the history
…code_domain`
  • Loading branch information
bdvllrs committed Jun 20, 2024
1 parent 9bcbc02 commit 8a68808
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,7 @@ def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT:
"""
return {
domains: {
name: self.domain_mods[name].encode(domain)
for name, domain in data.items()
name: self.encode_domain(domain, name) for name, domain in data.items()
}
for domains, data in batch.items()
}
Expand Down Expand Up @@ -317,7 +316,7 @@ def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroup
"""
return {
domains: {
name: self.domain_mods[name].decode(domain)
name: self.decode_domain(domain, name)
for name, domain in latents.items()
}
for domains, latents in latents_domain.items()
Expand Down

0 comments on commit 8a68808

Please sign in to comment.