-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add type inference and test-cases for Variables
- Loading branch information
Showing
5 changed files
with
284 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
''' | ||
Chicory ML Workflow Manager | ||
Copyright (C) 2024 Alexis Maya-Isabelle Shuping | ||
This program is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
This program is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
You should have received a copy of the GNU General Public License | ||
along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
''' | ||
|
||
from modules.types.Result import Result, Okay, Warning, Error | ||
|
||
from abc import ABC | ||
from typing import Generic | ||
|
||
class BroadCommonTypeWarning(Exception): | ||
''' Used when looking for common types. When the common type between two | ||
values is extremely broad (e.g. `object`), this warning is given. | ||
''' | ||
def __init__(self, objs: list[any], common: type, note: str = ''): | ||
''' Construct a BroadCommonTypeWarning. | ||
:param objs: List of objects that triggered this warning. | ||
:param common: The common type that triggered this warning. | ||
:param note: Further information about this warning. | ||
''' | ||
obj_str = '<no objects provided>' if len(objs) == 0 \ | ||
else str(objs[0]) if len(objs) == 1 \ | ||
else ', '.join([str(o) for o in objs[:-1]]) + ' and ' + str(objs[-1]) | ||
|
||
super().__init__(f'Objects {obj_str} only share the broad common type {common.__name__}.{(" " + note) if note else ""}') | ||
|
||
self.objs = objs | ||
self.common = common | ||
self.note = note | ||
|
||
|
||
class NoCommonTypeError(Exception): | ||
def __init__(self, objs: list[any], note: str = ''): | ||
''' Construct a NoCommonTypeError | ||
:param objs: List of objects that triggered this error | ||
:param note: Further information about this error | ||
''' | ||
obj_str = '<no objects provided>' if len(objs) == 0 \ | ||
else str(objs[0]) if len(objs) == 1 \ | ||
else ', '.join([str(o) for o in objs[:-1]]) + ' and ' + str(objs[-1]) | ||
|
||
super().__init__(f'Objects {obj_str} do not share a common type.{(" " + note) if note else ""}') | ||
|
||
self.objs = objs | ||
self.note = note | ||
|
||
|
||
def type_check(val: any, t: type) -> bool: | ||
''' Perform a basic, permissive type-check, ensuring that `val` can be | ||
reasonably considered to have type `t`. | ||
This function is used internally for basic verification of, e.g., | ||
`Variable` values. It considers subclasses, but it does not consider any | ||
of the more complex, type-annotation style details. For example, `t` can | ||
be a `list`, but not a `list[str]`. | ||
:param val: The value to check | ||
:param t: The type that `val` should have | ||
:returns: Whether `val` has type `t` | ||
''' | ||
return issubclass(type(val), t) | ||
|
||
def common_t(a: type, b: type) -> Result[Exception, Exception, type]: | ||
''' As `common`, except that `a` and `b` are types rather than values. | ||
:param a: The first type to compare | ||
:param b: The second type to compare | ||
:returns: `Okay(t)` where `t` is the most specific common type, or | ||
`Warning(BroadCommonTypeWarning, t)` if the common type is too | ||
broad, or `Error(NoCommonTypeError)` if there is no common type at | ||
all. | ||
''' | ||
if issubclass(b, a): | ||
return Okay(a) | ||
|
||
if issubclass(a, b): | ||
return Okay(b) | ||
|
||
def _candidate_search(a, b, ignore_broad=True): | ||
# Search `a` to find common ancestors of `b` | ||
BROAD_CLASSES = [object, ABC, type, Generic] | ||
candidate = None | ||
superclasses = a.__bases__ | ||
while candidate is None and len(superclasses) > 0: | ||
next_cycle_superclasses = [] | ||
for scls in superclasses: | ||
if issubclass(b, scls) and (scls not in BROAD_CLASSES or ignore_broad == False): | ||
candidate = scls | ||
break | ||
else: | ||
next_cycle_superclasses += list(scls.__bases__) | ||
superclasses = next_cycle_superclasses | ||
|
||
return candidate | ||
|
||
candidate = _candidate_search(a, b) # First pass - ignore too-broad types | ||
if candidate: | ||
return Okay(candidate) | ||
|
||
candidate = _candidate_search(a, b, False) # Second pass - include too-broad types | ||
if candidate: | ||
return Warning([BroadCommonTypeWarning([a, b], candidate)], candidate) | ||
else: | ||
return Error([NoCommonTypeError([a, b])]) | ||
|
||
def common(a: any, b: any) -> Result[Exception, Exception, type]: | ||
''' Return the most specifc type that `a` and `b` have in common. | ||
If `a` and `b` have an extremely broad common type (e.g. `object`), then | ||
the result will include a `BroadCommonTypeWarning`. If there is no | ||
common type at all, the result will be a `NoCommonTypeError`. | ||
Note that if `a` and `b` have the same type, or if one is a subclass of | ||
the other, the result will always be `Okay()`, even if one or the | ||
other's type is `object`. `BroadCommonTypeWarning` only applies when | ||
this function has to look for 'common ancestors.' | ||
:param a: The first value to compare | ||
:param b: The second value to compare | ||
:returns: `Okay(t)` where `t` is the most specific common type, or | ||
`Warning(BroadCommonTypeWarning, t)` if the common type is too | ||
broad, or `Error(NoCommonTypeError)` if there is no common type at | ||
all. | ||
''' | ||
return common_t(type(a), type(b)) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
''' | ||
Chicory ML Workflow Manager | ||
Copyright (C) 2024 Alexis Maya-Isabelle Shuping | ||
This program is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
This program is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
You should have received a copy of the GNU General Public License | ||
along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
''' | ||
from modules.data.Variable import UnboundVariable, BoundVariable | ||
from modules.types.Option import Option, Nothing, Some | ||
|
||
def test_is_bound(): | ||
assert BoundVariable('x', 2).is_bound | ||
assert BoundVariable('x', None).is_bound | ||
assert not UnboundVariable('x').is_bound | ||
|
||
def test_name(): | ||
assert BoundVariable('x', 2).name == 'x' | ||
assert UnboundVariable('y').name == 'y' | ||
assert UnboundVariable('multi-word name').name == 'multi-word name' | ||
|
||
def test_var_type(): | ||
assert BoundVariable('x', 2, Some(int)).var_type == Some(int) | ||
assert UnboundVariable('y', Some(str)).var_type == Some(str) | ||
|
||
def test_basic_type_inference(): | ||
assert BoundVariable('x', 2).var_type == Some(int) | ||
assert BoundVariable('x', 2).var_type_inferred == True | ||
assert BoundVariable('x', 2, Some(int)).var_type_inferred == False | ||
|
||
assert UnboundVariable('y', default=Some(2)).var_type == Some(int) | ||
assert UnboundVariable('y', default=Some(2)).var_type_inferred == True | ||
assert UnboundVariable('y', Some(int), default=Some(2)).var_type_inferred == False | ||
|
||
assert BoundVariable('x', 2, default=Some(2)).var_type == Some(int) | ||
assert BoundVariable('x', 2, default=Some(2)).var_type_inferred == True | ||
|
||
assert BoundVariable('x', 2, Some(int), default=Some(2)).var_type == Some(int) | ||
assert BoundVariable('x', 2, Some(int), default=Some(2)).var_type_inferred == False | ||
|
||
def test_explicit_type_overrides_inference(): | ||
assert BoundVariable('x', 2, Some(str), default=Some(2)).var_type == Some(str) | ||
assert BoundVariable('x', 2, Some(str), default=Some(2)).var_type_inferred == False | ||
|
||
def test_type_inference_uses_common_ancestors(): | ||
class TestSuperclass(): | ||
pass | ||
|
||
class TestSubclass1(TestSuperclass): | ||
pass | ||
|
||
class TestSubclass2(TestSuperclass): | ||
pass | ||
|
||
# BoundVariables should choose the closest common ancestor of `val` and `default` when performing type inference. | ||
assert BoundVariable('x', TestSubclass1(), default=Some(TestSubclass2())).var_type == Some(TestSuperclass) | ||
|
||
# Sometimes, the closest common ancestor is `object`. | ||
assert BoundVariable('x', TestSubclass1(), default=Some(1)).var_type == Some(object) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters