Skip to content

Commit

Permalink
Add tests for type-checking functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ashuping committed Sep 16, 2024
1 parent 9093985 commit f23c8b6
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 17 deletions.
32 changes: 17 additions & 15 deletions modules/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,38 @@ class BroadCommonTypeWarning(Exception):
''' Used when looking for common types. When the common type between two
values is extremely broad (e.g. `object`), this warning is given.
'''
def __init__(self, objs: list[any], common: type, note: str = ''):
def __init__(self, types: list[any], common: type, note: str = ''):
''' Construct a BroadCommonTypeWarning.
:param objs: List of objects that triggered this warning.
:param types: List of types that triggered this warning.
:param common: The common type that triggered this warning.
:param note: Further information about this warning.
'''
obj_str = '<no objects provided>' if len(objs) == 0 \
else str(objs[0]) if len(objs) == 1 \
else ', '.join([str(o) for o in objs[:-1]]) + ' and ' + str(objs[-1])
type_str = '<no types provided>' if len(types) == 0 \
else f'`{types[0].__name__}`' if len(types) == 1 \
else ', '.join([f'`{o.__name__}`' for o in types[:-1]]) + ' and ' + f'`{types[-1].__name__}`'

super().__init__(f'Objects {obj_str} only share the broad common type {common.__name__}.{(" " + note) if note else ""}')
super().__init__(f'Types {type_str} only share the broad common type `{common.__name__}`.{(" " + note) if note else ""}')

self.objs = objs
self.types = types
self.common = common
self.note = note


class NoCommonTypeError(Exception):
def __init__(self, objs: list[any], note: str = ''):
def __init__(self, types: list[any], note: str = ''):
''' Construct a NoCommonTypeError
:param objs: List of objects that triggered this error
:param types: List of types that triggered this error
:param note: Further information about this error
'''
obj_str = '<no objects provided>' if len(objs) == 0 \
else str(objs[0]) if len(objs) == 1 \
else ', '.join([str(o) for o in objs[:-1]]) + ' and ' + str(objs[-1])
type_str = '<no types provided>' if len(types) == 0 \
else f'`{types[0].__name__}`' if len(types) == 1 \
else ', '.join([f'`{o.__name__}`' for o in types[:-1]]) + ' and ' + f'`{types[-1].__name__}`'

super().__init__(f'Objects {obj_str} do not share a common type.{(" " + note) if note else ""}')
super().__init__(f'Types {type_str} do not share a common type.{(" " + note) if note else ""}')

self.objs = objs
self.types = types
self.note = note


Expand Down Expand Up @@ -93,12 +93,14 @@ def common_t(a: type, b: type) -> Result[Exception, Exception, type]:

def _candidate_search(a, b, ignore_broad=True):
# Search `a` to find common ancestors of `b`
BROAD_CLASSES = [object, ABC, type, Generic]
BROAD_CLASSES = [ABC, type, Generic]
candidate = None
superclasses = a.__bases__
while candidate is None and len(superclasses) > 0:
next_cycle_superclasses = []
for scls in superclasses:
if scls is object:
continue
if issubclass(b, scls) and (scls not in BROAD_CLASSES or ignore_broad == False):
candidate = scls
break
Expand Down
4 changes: 2 additions & 2 deletions test/data/test_Variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ class TestSubclass2(TestSuperclass):
# BoundVariables should choose the closest common ancestor of `val` and `default` when performing type inference.
assert BoundVariable('x', TestSubclass1(), default=Some(TestSubclass2())).var_type == Some(TestSuperclass)

# Sometimes, the closest common ancestor is `object`.
assert BoundVariable('x', TestSubclass1(), default=Some(1)).var_type == Some(object)
# If the only common ancestor is `object`, reject the inferred type.
assert BoundVariable('x', TestSubclass1(), default=Some(1)).var_type == Nothing()
119 changes: 119 additions & 0 deletions test/types/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
'''
Chicory ML Workflow Manager
Copyright (C) 2024 Alexis Maya-Isabelle Shuping
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
from abc import ABC

from modules.types.Result import Result, Okay, Warning, Error
from modules.types import type_check, common, common_t, NoCommonTypeError, BroadCommonTypeWarning

class TestSuperclass():
pass

class TestSubclass1(TestSuperclass):
pass

class TestSubclass2(TestSuperclass):
pass

class TestUnrelatedClass():
pass

class TestAbstractClass(ABC):
pass

class TestUnrelatedAbstract(ABC):
pass

def test_BroadCommonTypeWarning():
w = BroadCommonTypeWarning([int, str], object, 'test note')

assert str(w) == 'Types `int` and `str` only share the broad common type `object`. test note'

assert w.types == [int, str]
assert w.common == object
assert w.note == 'test note'

def test_NoCommonTypeError():
e = NoCommonTypeError([int, str], 'test note')

assert str(e) == 'Types `int` and `str` do not share a common type. test note'

assert e.types == [int, str]
assert e.note == 'test note'

def test_type_check_exact_match():
sup = TestSuperclass()
sub1 = TestSubclass1()
sub2 = TestSubclass2()
unr = TestUnrelatedClass()

assert type_check(sup, TestSuperclass)
assert type_check(sub1, TestSubclass1)
assert type_check(sub2, TestSubclass2)

assert not type_check(sup, TestUnrelatedClass)

def test_type_check_subclass_match():
sup = TestSuperclass()
sub1 = TestSubclass1()
sub2 = TestSubclass2()
unr = TestUnrelatedClass()

assert type_check(sub1, TestSuperclass)
assert type_check(sub2, TestSuperclass)

assert not type_check(sub2, TestSubclass1)
assert not type_check(sub1, TestUnrelatedClass)

def test_common_t_sibling():
assert common_t(TestSubclass1, TestSubclass2) == Okay(TestSuperclass)

def test_common_t_parent():
assert common_t(TestSubclass1, TestSuperclass) == Okay(TestSuperclass)

def test_common_t_child():
assert common_t(TestSuperclass, TestSubclass2) == Okay(TestSuperclass)

def test_common_t_broad_related():
assert common_t(TestAbstractClass, TestUnrelatedAbstract) == Warning([
BroadCommonTypeWarning(
[TestAbstractClass, TestUnrelatedAbstract], ABC
)
], ABC)

def test_common_t_unrelated():
assert common_t(TestSuperclass, TestUnrelatedClass) == Error([
NoCommonTypeError(
[TestSuperclass, TestUnrelatedClass]
)])

def test_common():
# `common` is a wrapper around `common_t`, so we don't need to test it as
# extensively.

sup = TestSuperclass()
sub1 = TestSubclass1()
sub2 = TestSubclass2()
unr = TestUnrelatedClass()


assert common(sup, sub1) == Okay(TestSuperclass)
assert common(sub2, sup) == Okay(TestSuperclass)
assert common(sub1, sub2) == Okay(TestSuperclass)
assert common(sub1, unr) == Error([
NoCommonTypeError([TestSubclass1, TestUnrelatedClass])
])

0 comments on commit f23c8b6

Please sign in to comment.