Skip to content

Commit

Permalink
fix: fix literal with widgeT_type (#586)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 authored Sep 23, 2023
1 parent b82b3c8 commit 2760ae7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
32 changes: 24 additions & 8 deletions src/magicgui/type_map/_type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,8 @@ def match_type(type_: Any, default: Any | None = None) -> WidgetTuple | None:
return widgets.FunctionGui, {"function": default}

origin = get_origin(type_) or type_
if origin is Literal:
choices = []
nullable = False
for choice in get_args(type_):
if choice is None:
nullable = True
else:
choices.append(choice)
choices, nullable = _literal_choices(type_)
if choices is not None: # it's a Literal type
return widgets.ComboBox, {"choices": choices, "nullable": nullable}

# sequence of paths
Expand Down Expand Up @@ -193,6 +187,24 @@ def _type_optional(
return type_, nullable


def _literal_choices(annotation: Any) -> tuple[list | None, bool]:
"""Return choices and nullable for a Literal type.
if annotation is not a Literal type, returns (None, False)
"""
origin = get_origin(annotation) or annotation
choices: list | None = None
nullable = False
if origin is Literal:
choices = []
for choice in get_args(annotation):
if choice is None:
nullable = True
else:
choices.append(choice)
return choices, nullable


def _pick_widget_type(
value: Any = Undefined,
annotation: Any = Undefined,
Expand All @@ -219,6 +231,10 @@ def _pick_widget_type(
_type, optional = _type_optional(value, annotation)
options.setdefault("nullable", optional)
choices = choices or (isinstance(_type, EnumMeta) and _type)
literal_choices, nullable = _literal_choices(annotation)
if literal_choices is not None:
choices = literal_choices
options["nullable"] = nullable

if "widget_type" in options:
widget_type = options.pop("widget_type")
Expand Down
10 changes: 10 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,13 @@ def test_type_registered_warns():
register_type(Path, widget_type=widgets.TextEdit)
assert isinstance(widgets.create_widget(annotation=Path), widgets.TextEdit)
assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)


def test_pick_widget_literal():
from typing import Literal

cls, options = type_map.get_widget_class(
annotation=Annotated[Literal["a", "b"], {"widget_type": "RadioButtons"}]
)
assert cls == widgets.RadioButtons
assert set(options["choices"]) == {"a", "b"}

0 comments on commit 2760ae7

Please sign in to comment.