Skip to content

Commit

Permalink
Use ABC instead of ABCMeta (readability)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Jan 22, 2024
1 parent 6637c15 commit f20c900
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ Fix bugs related to imports and `default_decoders`.
refers to `DeterministicGlobalWorkspace`.

# 0.4.0
* Use ABCMeta for abstract methods.
* Use ABC for abstract methods.
2 changes: 1 addition & 1 deletion shimmer/modules/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class DomainModule(pl.LightningModule):
"""
Base class for a DomainModule.
We do not use ABCMeta here because some modules could be without encore or decoder.
We do not use ABC here because some modules could be without encore or decoder.
"""

def encode(self, x: Any) -> torch.Tensor:
Expand Down
4 changes: 2 additions & 2 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABCMeta, abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping

import torch
Expand Down Expand Up @@ -79,7 +79,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return self.layers(x), self.uncertainty_level.expand(x.size(0), -1)


class GWModule(nn.Module, metaclass=ABCMeta):
class GWModule(nn.Module, ABC):
domain_descr: Mapping[str, DomainDescription]
latent_dim: int

Expand Down
4 changes: 2 additions & 2 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABCMeta, abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Literal

Expand Down Expand Up @@ -65,7 +65,7 @@ def contrastive_loss_with_uncertainty(
return 0.5 * (ce + ce_t)


class GWLosses(torch.nn.Module, metaclass=ABCMeta):
class GWLosses(torch.nn.Module, ABC):
"""
Base Abstract Class for Global Workspace (GW) losses. This module is used
to compute the different losses of the GW (typically translation, cycle,
Expand Down
6 changes: 3 additions & 3 deletions shimmer/modules/vae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from abc import ABCMeta, abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Sequence

import torch
Expand Down Expand Up @@ -29,7 +29,7 @@ def gaussian_nll(
)


class VAEEncoder(nn.Module, metaclass=ABCMeta):
class VAEEncoder(nn.Module, ABC):
"""
Base class for a VAE encoder.
"""
Expand All @@ -48,7 +48,7 @@ def forward(
...


class VAEDecoder(nn.Module, metaclass=ABCMeta):
class VAEDecoder(nn.Module, ABC):
"""
Base class for a VAE decoder.
"""
Expand Down

0 comments on commit f20c900

Please sign in to comment.