Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Support Annotated types in list/tuple #588

Merged
merged 1 commit into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions src/magicgui/widgets/_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
TYPE_CHECKING,
Any,
Callable,
ForwardRef,
Generic,
Iterable,
Iterator,
List,
Literal,
Sequence,
Tuple,
Expand All @@ -29,7 +31,7 @@
)
from weakref import ref

from typing_extensions import get_args, get_origin
from typing_extensions import Annotated, get_args, get_origin

from magicgui._type_resolution import resolve_single_type
from magicgui._util import merge_super_sigs, safe_issubclass
Expand Down Expand Up @@ -690,17 +692,27 @@ def annotation(self, value: Any) -> None:
self._args_type = None
return

value = resolve_single_type(value)
value_resolved = resolve_single_type(value)
if isinstance(value, (str, ForwardRef)):
value = value_resolved
# unwrap annotated (options are not needed to normalize `annotation`)
while get_origin(value) is Annotated:
value = get_args(value)[0]
arg: type | None = None

if value and value is not inspect.Parameter.empty:
orig = get_origin(value) or value
if value_resolved and value_resolved is not inspect.Parameter.empty:
orig = get_origin(value_resolved) or value_resolved
if not (safe_issubclass(orig, list) or isinstance(orig, list)):
raise TypeError(
f"cannot set annotation {value} to {type(self).__name__}."
)
args = get_args(value)
arg = args[0] if len(args) > 0 else None
args_resolved = get_args(value_resolved)
if len(args_resolved) > 0:
value = List[args_resolved[0]] # type: ignore
else:
value = list

self._annotation = value
self._args_type = arg
Expand Down Expand Up @@ -929,17 +941,23 @@ def annotation(self, value: Any) -> None:
self._args_types = None
return

value = resolve_single_type(value)
value_resolved = resolve_single_type(value)
if isinstance(value, (str, ForwardRef)):
value = value_resolved
# unwrap annotated (options are not needed to normalize `annotation`)
while get_origin(value) is Annotated:
value = get_args(value)[0]
args: tuple[type, ...] | None = None

if value and value is not inspect.Parameter.empty:
orig = get_origin(value)
if value_resolved and value_resolved is not inspect.Parameter.empty:
orig = get_origin(value_resolved)
if not (safe_issubclass(orig, tuple) or isinstance(orig, tuple)):
raise TypeError(
f"cannot set annotation {value} to {type(self).__name__}."
)
args = get_args(value)
value = Tuple[args]
args_resolved = get_args(value_resolved)
value = Tuple[args_resolved]

self._annotation = value
self._args_types = args
Expand Down
18 changes: 18 additions & 0 deletions tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from unittest.mock import MagicMock, patch

import pytest
from typing_extensions import Annotated

from magicgui import magicgui, types, use_app, widgets
from magicgui.widgets import Container, request_values
Expand Down Expand Up @@ -906,6 +907,15 @@ def f4(x: List[int] = ()): # type: ignore
assert type(f4.x[0]) is widgets.SpinBox
assert f4.x.value == [0]

@magicgui
def f5(x: List[Annotated[int, {"max": 3}]]):
pass

assert type(f5.x) is widgets.ListEdit
assert f5.x.annotation == List[int]
f5.x.btn_plus.changed()
assert f5.x[0].max == 3


def test_tuple_edit():
"""Test TupleEdit."""
Expand Down Expand Up @@ -946,6 +956,14 @@ def f2(x: Tuple[int, str]):
assert f2.x.annotation == Tuple[int, str]
assert f2.x.value == (0, "")

@magicgui
def f3(x: Tuple[Annotated[int, {"max": 3}], str]):
pass

assert type(f3.x) is widgets.TupleEdit
assert f2.x.annotation == Tuple[int, str]
assert f3.x[0].max == 3


def test_request_values(monkeypatch):
from unittest.mock import Mock
Expand Down