From 625d149c51d299dc190410ad0f9a101d74c183d8 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Tue, 15 Oct 2024 16:44:08 +0300 Subject: [PATCH] Fix: Throw error when attempting to dispatch on literal --- runtype/dispatch.py | 4 ++++ runtype/pytypes.py | 5 +++++ tests/test_basic.py | 18 ++++++++++++++++++ 3 files changed, 27 insertions(+) diff --git a/runtype/dispatch.py b/runtype/dispatch.py index 27581ce..ef6c6e5 100644 --- a/runtype/dispatch.py +++ b/runtype/dispatch.py @@ -177,6 +177,10 @@ def define_function(self, f): for signature in get_func_signatures(self.typesystem, f): node = self.root for t in signature: + if not isinstance(t, type): + # XXX this is a temporary fix for preventing certain types from being used for dispatch + if not getattr(t, 'ALLOW_DISPATCH', True): + raise ValueError(f"Type {t} cannot be used for dispatch") node = node.follow_type[t] if node.func is not None: diff --git a/runtype/pytypes.py b/runtype/pytypes.py index 822f734..b47e08e 100644 --- a/runtype/pytypes.py +++ b/runtype/pytypes.py @@ -200,6 +200,8 @@ def test_instance(self, obj, sampler=None): class OneOf(PythonType): values: typing.Sequence + ALLOW_DISPATCH = False + def __init__(self, values): self.values = values @@ -218,6 +220,7 @@ def cast_from(self, obj): raise TypeMismatchError(obj, self) + class GenericType(base_types.GenericType, PythonType): base: PythonDataType item: PythonType @@ -448,6 +451,8 @@ def cast_from(self, obj): class _NoneType(OneOf): + ALLOW_DISPATCH = True # Make an exception + def __init__(self): super().__init__([None]) diff --git a/tests/test_basic.py b/tests/test_basic.py index b737227..7c869be 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -728,6 +728,24 @@ def f(t: Tree[int]): f(Tree()) + def test_literal_dispatch(self): + try: + @multidispatch + def f(x: typing.Literal[1]): + return 1 + + @multidispatch + def f(x: typing.Literal[2]): + return 2 + except ValueError: + pass + else: + assert False + + # If it was working.. + # assert f(1) == 1 + # assert f(2) == 2 + class TestDataclass(TestCase): def setUp(self):