Skip to content

Commit

Permalink
Merge pull request #59 from scipp/two-arg-generic-param
Browse files Browse the repository at this point in the history
Fix check preventing use of `ScopeTwoParam` subclass as provider param
  • Loading branch information
SimonHeybrock authored Sep 15, 2023
2 parents 4902f05 + 70fc8a1 commit 3cd8c86
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from sciline.task_graph import TaskGraph

from .domain import Scope
from .domain import Scope, ScopeTwoParams
from .param_table import ParamTable
from .scheduler import Scheduler
from .series import Series
Expand Down Expand Up @@ -353,11 +353,14 @@ def __setitem__(self, key: Type[T], param: T) -> None:
expected = np_origin
else:
expected = underlying
elif issubclass(origin, Scope):
elif issubclass(origin, (Scope, ScopeTwoParams)):
scope = origin.__orig_bases__[0]
while (orig := get_origin(scope)) is not None and orig is not Scope:
while (orig := get_origin(scope)) is not None and orig not in (
Scope,
ScopeTwoParams,
):
scope = orig.__orig_bases__[0]
expected = get_args(scope)[1]
expected = get_args(scope)[-1]
else:
expected = origin

Expand Down
50 changes: 50 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,56 @@ def func2(x: int) -> str:
assert ncall == 1


def test_Scope_subclass_can_be_set_as_param() -> None:
Param = TypeVar('Param')

class Str(sl.Scope[Param, str], str):
...

pipeline = sl.Pipeline(params={Str[int]: Str[int]('1')})
pipeline[Str[float]] = Str[float]('2.0')
assert pipeline.compute(Str[int]) == Str[int]('1')
assert pipeline.compute(Str[float]) == Str[float]('2.0')


def test_Scope_subclass_can_be_set_as_param_with_unbound_typevar() -> None:
Param = TypeVar('Param')

class Str(sl.Scope[Param, str], str):
...

pipeline = sl.Pipeline()
pipeline[Str[Param]] = Str[Param]('1') # type: ignore[valid-type]
assert pipeline.compute(Str[int]) == Str[int]('1')
assert pipeline.compute(Str[float]) == Str[float]('1')


def test_ScopeTwoParam_subclass_can_be_set_as_param() -> None:
Param1 = TypeVar('Param1')
Param2 = TypeVar('Param2')

class Str(sl.ScopeTwoParams[Param1, Param2, str], str):
...

pipeline = sl.Pipeline(params={Str[int, float]: Str[int, float]('1')})
pipeline[Str[float, int]] = Str[float, int]('2.0')
assert pipeline.compute(Str[int, float]) == Str[int, float]('1')
assert pipeline.compute(Str[float, int]) == Str[float, int]('2.0')


def test_ScopeTwoParam_subclass_can_be_set_as_param_with_unbound_typevar() -> None:
Param1 = TypeVar('Param1')
Param2 = TypeVar('Param2')

class Str(sl.ScopeTwoParams[Param1, Param2, str], str):
...

pipeline = sl.Pipeline()
pipeline[Str[Param1, Param2]] = Str[Param1, Param2]('1') # type: ignore[valid-type]
assert pipeline.compute(Str[int, float]) == Str[int, float]('1')
assert pipeline.compute(Str[float, int]) == Str[float, int]('1')


def test_generic_providers_produce_use_dependencies_based_on_bound_typevar() -> None:
Param = TypeVar('Param')

Expand Down

0 comments on commit 3cd8c86

Please sign in to comment.