Skip to content

Commit

Permalink
refactor: remove fill_with_flags (#660)
Browse files Browse the repository at this point in the history
  • Loading branch information
onerandomusername authored Jul 28, 2022
1 parent cd8acef commit 6769f80
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 33 deletions.
1 change: 1 addition & 0 deletions changelog/660.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove the internal ``fill_with_flags`` decorator for flags classes and use the built in :meth:`object.__init_subclass__` method.
54 changes: 25 additions & 29 deletions disnake/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,25 +117,6 @@ def all_flags_value(flags: Dict[str, int]) -> int:
return functools.reduce(operator.or_, flags.values())


def fill_with_flags(*, inverted: bool = False):
def decorator(cls: Type[BF]) -> Type[BF]:
cls.VALID_FLAGS = {}
for name, value in cls.__dict__.items():
if isinstance(value, flag_value):
value._parent = cls
cls.VALID_FLAGS[name] = value.flag

if inverted:
cls.DEFAULT_VALUE = all_flags_value(cls.VALID_FLAGS)
else:
cls.DEFAULT_VALUE = 0

return cls

return decorator


# n.b. flags must inherit from this and use the decorator above
class BaseFlags:
VALID_FLAGS: ClassVar[Dict[str, int]]
DEFAULT_VALUE: ClassVar[int]
Expand All @@ -151,6 +132,29 @@ def __init__(self, **kwargs: bool):
raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value)

@classmethod
def __init_subclass__(cls, inverted: bool = False, no_fill_flags: bool = False):
# add a way to bypass filling flags, eg for ListBaseFlags.
if no_fill_flags:
return cls

# use the parent's current flags as a base if they exist
cls.VALID_FLAGS = getattr(cls, "VALID_FLAGS", {}).copy()

for name, value in cls.__dict__.items():
if isinstance(value, flag_value):
value._parent = cls
cls.VALID_FLAGS[name] = value.flag

if not cls.VALID_FLAGS:
raise RuntimeError(
"At least one flag must be defined in a BaseFlags subclass, or 'no_fill_flags' must be set to True"
)

cls.DEFAULT_VALUE = all_flags_value(cls.VALID_FLAGS) if inverted else 0

return cls

@classmethod
def _from_value(cls, value: int) -> Self:
self = cls.__new__(cls)
Expand Down Expand Up @@ -295,7 +299,7 @@ def _set_flag(self, o: int, toggle: bool) -> None:
raise TypeError(f"Value to set for {self.__class__.__name__} must be a bool.")


class ListBaseFlags(BaseFlags):
class ListBaseFlags(BaseFlags, no_fill_flags=True):
"""
A base class for flags that aren't powers of 2.
Instead, values are used as exponents to map to powers of 2 to avoid collisions,
Expand Down Expand Up @@ -330,8 +334,7 @@ def __repr__(self) -> str:
return f"<{self.__class__.__name__} values={self.values}>"


@fill_with_flags(inverted=True)
class SystemChannelFlags(BaseFlags):
class SystemChannelFlags(BaseFlags, inverted=True):
"""
Wraps up a Discord system channel flag value.
Expand Down Expand Up @@ -466,7 +469,6 @@ def join_notification_replies(self):
return 8


@fill_with_flags()
class MessageFlags(BaseFlags):
"""
Wraps up a Discord Message flag value.
Expand Down Expand Up @@ -621,7 +623,6 @@ def failed_to_mention_roles_in_thread(self):
return 1 << 8


@fill_with_flags()
class PublicUserFlags(BaseFlags):
"""
Wraps up the Discord User Public flags.
Expand Down Expand Up @@ -814,7 +815,6 @@ def all(self) -> List[UserFlags]:
return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)]


@fill_with_flags()
class Intents(BaseFlags):
"""
Wraps up a Discord gateway intent flag.
Expand Down Expand Up @@ -1449,7 +1449,6 @@ def automod(self):
return (1 << 20) | (1 << 21)


@fill_with_flags()
class MemberCacheFlags(BaseFlags):
"""Controls the library's cache policy when it comes to members.
Expand Down Expand Up @@ -1632,7 +1631,6 @@ def _voice_only(self):
return self.value == 1


@fill_with_flags()
class ApplicationFlags(BaseFlags):
"""
Wraps up the Discord Application flags.
Expand Down Expand Up @@ -1777,7 +1775,6 @@ def gateway_message_content_limited(self):
return 1 << 19


@fill_with_flags()
class ChannelFlags(BaseFlags):
"""Wraps up the Discord Channel flags.
Expand Down Expand Up @@ -1872,7 +1869,6 @@ def pinned(self):
return 1 << 1


@fill_with_flags()
class AutoModKeywordPresets(ListBaseFlags):
"""
Wraps up the pre-defined auto moderation keyword lists, provided by Discord.
Expand Down
3 changes: 1 addition & 2 deletions disnake/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Iterator, Optional, Set, Tuple

from .flags import BaseFlags, alias_flag_value, fill_with_flags, flag_value
from .flags import BaseFlags, alias_flag_value, flag_value

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -68,7 +68,6 @@ def wrapped(cls):
return wrapped


@fill_with_flags()
class Permissions(BaseFlags):
"""Wraps up the Discord permission value.
Expand Down
3 changes: 1 addition & 2 deletions tests/test_flags.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest

from disnake.flags import ListBaseFlags, fill_with_flags, flag_value
from disnake.flags import ListBaseFlags, flag_value


@fill_with_flags()
class _ListFlags(ListBaseFlags):
@flag_value
def flag1(self):
Expand Down

0 comments on commit 6769f80

Please sign in to comment.