Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for flag enum #207

Merged
merged 4 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions src/superqt/combobox/_enum_combobox.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from enum import Enum, EnumMeta
from typing import Optional, TypeVar
from enum import Enum, EnumMeta, Flag
from functools import reduce
from itertools import combinations
from operator import or_
from typing import Optional, Tuple, TypeVar

from qtpy.QtCore import Signal
from qtpy.QtWidgets import QComboBox
Expand All @@ -21,6 +24,10 @@ def _get_name(enum_value: Enum):
return name


def _get_name_with_value(enum_value: Enum) -> Tuple[str, Enum]:
return _get_name(enum_value), enum_value


class QEnumComboBox(QComboBox):
"""ComboBox presenting options from a python Enum.

Expand All @@ -47,9 +54,20 @@ def setEnumClass(self, enum: Optional[EnumMeta], allow_none=False):
self._allow_none = allow_none and enum is not None
if allow_none:
super().addItem(NONE_STRING)
names = map(_get_name, self._enum_class.__members__.values())
_names = dict.fromkeys(names) # remove duplicates/aliases, keep order
super().addItems(list(_names))
names_ = self._get_enum_member_list(enum)
super().addItems(list(names_))

@staticmethod
def _get_enum_member_list(enum: Optional[EnumMeta]):
if issubclass(enum, Flag):
members = list(enum.__members__.values())
comb_list = []
for i in range(len(members)):
comb_list.extend(reduce(or_, x) for x in combinations(members, i + 1))

else:
comb_list = list(enum.__members__.values())
return dict(map(_get_name_with_value, comb_list))

tlambert03 marked this conversation as resolved.
Show resolved Hide resolved
def enumClass(self) -> Optional[EnumMeta]:
"""Return current Enum class."""
Expand All @@ -70,11 +88,7 @@ def currentEnum(self) -> Optional[EnumType]:
if self._allow_none:
if self.currentText() == NONE_STRING:
return None
else:
return list(self._enum_class.__members__.values())[
self.currentIndex() - 1
]
return list(self._enum_class.__members__.values())[self.currentIndex()]
return self._get_enum_member_list(self._enum_class)[self.currentText()]
return None

def setCurrentEnum(self, value: Optional[EnumType]) -> None:
Expand Down
63 changes: 62 additions & 1 deletion tests/test_enum_comb_box.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum, IntEnum
import sys
from enum import Enum, Flag, IntEnum, IntFlag

import pytest

Expand Down Expand Up @@ -42,6 +43,24 @@ class IntEnum1(IntEnum):
c = 5


class IntFlag1(IntFlag):
a = 1
b = 2
c = 4


class Flag1(Flag):
a = 1
b = 2
c = 4


class IntFlag2(IntFlag):
a = 1
b = 2
c = 3


def test_simple_create(qtbot):
enum = QEnumComboBox(enum_class=Enum1)
qtbot.addWidget(enum)
Expand Down Expand Up @@ -142,3 +161,45 @@ def test_simple_create_int_enum(qtbot):
qtbot.addWidget(enum)
assert enum.count() == 3
assert [enum.itemText(i) for i in range(enum.count())] == ["a", "b", "c"]


@pytest.mark.parametrize("enum_class", [IntFlag1, Flag1])
def test_enum_flag_create(qtbot, enum_class):
enum = QEnumComboBox(enum_class=enum_class)
qtbot.addWidget(enum)
assert enum.count() == 7
assert [enum.itemText(i) for i in range(enum.count())] == [
"a",
"b",
"c",
"a|b",
"a|c",
"b|c",
"a|b|c",
]
enum.setCurrentText("a|b")
assert enum.currentEnum() == enum_class.a | enum_class.b


def test_enum_flag_create_collision(qtbot):
enum = QEnumComboBox(enum_class=IntFlag2)
qtbot.addWidget(enum)
assert enum.count() == 3
assert [enum.itemText(i) for i in range(enum.count())] == ["a", "b", "c"]


@pytest.mark.skipif(
sys.version_info < (3, 11), reason="StrEnum is introduced in python 3.11"
)
def test_create_str_enum(qtbot):
from enum import StrEnum

class StrEnum1(StrEnum):
a = "a"
b = "b"
c = "c"

enum = QEnumComboBox(enum_class=StrEnum1)
qtbot.addWidget(enum)
assert enum.count() == 3
assert [enum.itemText(i) for i in range(enum.count())] == ["a", "b", "c"]
Loading