Skip to content

Commit

Permalink
Add type inference and test-cases for Variables
Browse files Browse the repository at this point in the history
  • Loading branch information
ashuping committed Sep 11, 2024
1 parent 8e98126 commit 9093985
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 4 deletions.
64 changes: 60 additions & 4 deletions modules/data/Variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
from modules.types.Option import Option, Nothing
from modules.types import common
from modules.types.Option import Option, Nothing, Some

from typing import TypeVar
from abc import ABC, abstractmethod
import logging
from typing import TypeVar

A = TypeVar("A")

log = logging.getLogger(__name__)

class Variable[A](ABC):
''' Stores a quantity whose value can vary from workflow to workflow,
controlled either statically (manually set in configuration) or
Expand Down Expand Up @@ -66,6 +70,14 @@ def var_type(self) -> Option[type]:
'''
... # pragma: no cover

@property
@abstractmethod
def var_type_inferred(self) -> bool:
    ''' Whether `var_type` was inferred rather than explicitly provided.

        An inferred type is a best guess and is potentially inaccurate.

        :returns: True if `var_type` was automatically inferred; False if it
            was explicitly provided.
    '''
    ... # pragma: no cover

@property
@abstractmethod
def default(self) -> Option[A]:
Expand Down Expand Up @@ -95,10 +107,35 @@ class BoundVariable[A](Variable):
def __init__(self, name: str, val: A, var_type: Option[type] = Nothing(), desc: Option[str] = Nothing(), default: Option[A] = Nothing()):
    ''' Construct a BoundVariable, inferring its type when none is given.

        When `var_type` is `Nothing()`, the type is inferred from `val`
        alone if there is no `default`, or from the common type of `val` and
        the default's value otherwise. Every inference outcome is logged as
        a warning.

        :param name: The name of this variable
        :param val: The value bound to this variable
        :param var_type: The explicit type, or `Nothing()` to request
            inference
        :param desc: Optional description, stored as given
        :param default: Optional default value, stored as given and also
            used for type inference
    '''
    self.__name = name
    self.__val = val
    self.__desc = desc
    self.__default = default

    if not (var_type == Nothing()):
        # Type was explicitly provided
        self.__type = var_type
        self.__inferred = False
    elif default == Nothing():
        # Infer type from `val`
        self.__type = Some(type(val))
        self.__inferred = True
        log.warning(f'Inferred type {self.__type} from val {val} for BoundVariable {name}.')
    else:
        # Try to find a common type between `val` and `default`
        ctype = common(val, default.val)
        if ctype.is_okay:
            log.warning(f'Inferred type {ctype.val} from val {val} and default {default.val} for BoundVariable {name}')
        elif ctype.is_warning:
            log.warning(f'Tried to infer type for BoundVariable {name} from val {val} and default {default.val}, but resultant type {ctype.val} seems overly broad. Using it anyway.')
        else:
            log.warning(f'Failed to infer type for BoundVariable {name} - val ({val}) and default ({default.val}) have no types in common!')
            self.__type = Nothing()
            self.__inferred = False
            return

        # Both the okay and the warning outcome carry a usable type
        self.__type = Some(ctype.val)
        self.__inferred = True

@property
def is_bound(self) -> bool:
    ''' Always True for a BoundVariable. '''
    return True
Expand All @@ -119,6 +156,10 @@ def val(self) -> A:
def var_type(self) -> Option[type]:
return self.__type

@property
def var_type_inferred(self) -> bool:
    ''' Whether the constructor inferred `var_type`.

        :returns: True if the type was inferred from `val` (and `default`,
            when one was given); False if the type was explicitly provided
            or if inference failed.
    '''
    return self.__inferred

@property
def default(self) -> Option[A]:
    ''' The `default` option exactly as it was passed to the constructor
        (`Nothing()` unless one was supplied).
    '''
    return self.__default
Expand All @@ -144,10 +185,21 @@ class UnboundVariable[A](Variable):
'''
def __init__(self, name: str, var_type: Option[type] = Nothing(), desc: Option[str] = Nothing(), default: Option[A] = Nothing()):
    ''' Construct an UnboundVariable, inferring its type when possible.

        When `var_type` is `Nothing()` and a `default` is given, the type
        is inferred from the default's value and a warning is logged. With
        neither, the type stays `Nothing()`.

        :param name: The name of this variable
        :param var_type: The explicit type, or `Nothing()` to request
            inference
        :param desc: Optional description, stored as given
        :param default: Optional default value, stored as given and also
            used for type inference
    '''
    self.__name = name
    self.__desc = desc
    self.__default = default

    if not (var_type == Nothing()):
        # Type was explicitly provided
        self.__type = var_type
        self.__inferred = False
    elif default != Nothing():
        # The default is the only value available to infer from
        log.warning(f'Inferred type {type(default.val)} from default {default.val} for UnboundVariable {name}')
        self.__type = Some(type(default.val))
        self.__inferred = True
    else:
        # Nothing to infer from
        self.__type = Nothing()
        self.__inferred = False

@property
def is_bound(self) -> bool:
    ''' Always False for an UnboundVariable. '''
    return False
Expand All @@ -168,6 +220,10 @@ def val(self) -> A:
def var_type(self) -> Option[type]:
return self.__type

@property
def var_type_inferred(self) -> bool:
    ''' Whether the constructor inferred `var_type`.

        :returns: True if the type was inferred from `default`; False if the
            type was explicitly provided or if there was no default to infer
            from.
    '''
    return self.__inferred

@property
def default(self) -> Option[A]:
    ''' The `default` option exactly as it was passed to the constructor
        (`Nothing()` unless one was supplied).
    '''
    return self.__default
Expand Down
140 changes: 140 additions & 0 deletions modules/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
'''
Chicory ML Workflow Manager
Copyright (C) 2024 Alexis Maya-Isabelle Shuping
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''

from modules.types.Result import Result, Okay, Warning, Error

from abc import ABC
from typing import Generic

def _describe_objects(objs: list[object]) -> str:
    ''' Render `objs` as a readable enumeration for exception messages.

        :param objs: The objects to describe
        :returns: A placeholder for an empty list, the object itself for one
            element, otherwise `a, b and c`
    '''
    count = len(objs)
    if count == 0:
        return '<no objects provided>'
    if count == 1:
        return str(objs[0])

    leading = ', '.join(str(o) for o in objs[:-1])
    return leading + ' and ' + str(objs[-1])


class BroadCommonTypeWarning(Exception):
    ''' Used when looking for common types. When the common type between two
        values is extremely broad (e.g. `object`), this warning is given.
    '''
    def __init__(self, objs: list[object], common: type, note: str = ''):
        ''' Construct a BroadCommonTypeWarning.

            :param objs: List of objects that triggered this warning.
            :param common: The common type that triggered this warning.
            :param note: Further information about this warning.
        '''
        suffix = (' ' + note) if note else ''
        super().__init__(f'Objects {_describe_objects(objs)} only share the broad common type {common.__name__}.{suffix}')

        self.objs = objs
        self.common = common
        self.note = note


class NoCommonTypeError(Exception):
    ''' Used when looking for common types. When the values share no common
        type at all, this error is given.
    '''
    def __init__(self, objs: list[object], note: str = ''):
        ''' Construct a NoCommonTypeError.

            :param objs: List of objects that triggered this error
            :param note: Further information about this error
        '''
        suffix = (' ' + note) if note else ''
        super().__init__(f'Objects {_describe_objects(objs)} do not share a common type.{suffix}')

        self.objs = objs
        self.note = note


def type_check(val: object, t: type) -> bool:
    ''' Perform a basic, permissive type-check, ensuring that `val` can be
        reasonably considered to have type `t`.

        This function is used internally for basic verification of, e.g.,
        `Variable` values. It considers subclasses, but it does not consider
        any of the more complex, type-annotation style details. For example,
        `t` can be a `list`, but not a `list[str]`.

        :param val: The value to check
        :param t: The type that `val` should have
        :returns: Whether `val` has type `t`
    '''
    return issubclass(type(val), t)

def common_t(a: type, b: type) -> Result[Exception, Exception, type]:
    ''' As `common`, except that `a` and `b` are types rather than values.

        If one type is a subclass of the other (or they are the same), the
        broader of the two is returned as `Okay`, even if it is `object`.
        Otherwise the ancestors of `a` are walked in method-resolution
        order and the first one that `b` also derives from is chosen.

        :param a: The first type to compare
        :param b: The second type to compare
        :returns: `Okay(t)` where `t` is the most specific common type, or
            `Warning(BroadCommonTypeWarning, t)` if the common type is too
            broad, or `Error(NoCommonTypeError)` if there is no common type
            at all.
    '''
    if issubclass(b, a):
        return Okay(a)

    if issubclass(a, b):
        return Okay(b)

    # Ancestors so generic that sharing them says very little
    broad_classes = (object, ABC, type, Generic)

    # The MRO lists every ancestor of `a` with each class ahead of its own
    # bases, so the first ancestor shared with `b` is never less specific
    # than a later one. A breadth-first walk over `__bases__` lacks that
    # guarantee under multiple inheritance.
    first_broad = None
    for ancestor in a.__mro__[1:]:
        if not issubclass(b, ancestor):
            continue
        if ancestor not in broad_classes:
            return Okay(ancestor)
        if first_broad is None:
            # Remember it, but keep looking for a more meaningful type
            first_broad = ancestor

    if first_broad is not None:
        return Warning([BroadCommonTypeWarning([a, b], first_broad)], first_broad)

    return Error([NoCommonTypeError([a, b])])

def common(a: object, b: object) -> Result[Exception, Exception, type]:
    ''' Return the most specific type that `a` and `b` have in common.

        If `a` and `b` have an extremely broad common type (e.g. `object`),
        then the result will include a `BroadCommonTypeWarning`. If there is
        no common type at all, the result will be a `NoCommonTypeError`.

        Note that if `a` and `b` have the same type, or if one is a subclass
        of the other, the result will always be `Okay()`, even if one or the
        other's type is `object`. `BroadCommonTypeWarning` only applies when
        this function has to look for 'common ancestors.'

        :param a: The first value to compare
        :param b: The second value to compare
        :returns: `Okay(t)` where `t` is the most specific common type, or
            `Warning(BroadCommonTypeWarning, t)` if the common type is too
            broad, or `Error(NoCommonTypeError)` if there is no common type
            at all.
    '''
    return common_t(type(a), type(b))
Empty file added test/data/__init__.py
Empty file.
68 changes: 68 additions & 0 deletions test/data/test_Variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
'''
Chicory ML Workflow Manager
Copyright (C) 2024 Alexis Maya-Isabelle Shuping
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
from modules.data.Variable import UnboundVariable, BoundVariable
from modules.types.Option import Option, Nothing, Some

def test_is_bound():
    ''' `is_bound` is truthy for bound variables (even with a `None` value)
        and falsy for unbound ones.
    '''
    with_value = BoundVariable('x', 2)
    assert with_value.is_bound

    with_none = BoundVariable('x', None)
    assert with_none.is_bound

    unbound = UnboundVariable('x')
    assert not unbound.is_bound

def test_name():
    ''' `name` returns the string given at construction, spaces included. '''
    bound = BoundVariable('x', 2)
    assert bound.name == 'x'

    short_named = UnboundVariable('y')
    assert short_named.name == 'y'

    long_named = UnboundVariable('multi-word name')
    assert long_named.name == 'multi-word name'

def test_var_type():
    ''' An explicitly supplied type is reported back unchanged. '''
    bound = BoundVariable('x', 2, Some(int))
    assert bound.var_type == Some(int)

    unbound = UnboundVariable('y', Some(str))
    assert unbound.var_type == Some(str)

def test_basic_type_inference():
    ''' Types are inferred from `val` and/or `default` only when no explicit
        type is given, and `var_type_inferred` reports which case applied.
    '''
    # Bound, no default: inferred from `val`
    bound = BoundVariable('x', 2)
    assert bound.var_type == Some(int)
    assert bound.var_type_inferred is True
    assert BoundVariable('x', 2, Some(int)).var_type_inferred is False

    # Unbound: inferred from `default`
    unbound = UnboundVariable('y', default=Some(2))
    assert unbound.var_type == Some(int)
    assert unbound.var_type_inferred is True
    assert UnboundVariable('y', Some(int), default=Some(2)).var_type_inferred is False

    # Bound with default: inferred from both
    bound_with_default = BoundVariable('x', 2, default=Some(2))
    assert bound_with_default.var_type == Some(int)
    assert bound_with_default.var_type_inferred is True

    # Explicit type alongside a default: no inference
    explicit = BoundVariable('x', 2, Some(int), default=Some(2))
    assert explicit.var_type == Some(int)
    assert explicit.var_type_inferred is False

def test_explicit_type_overrides_inference():
    ''' An explicit type wins even when it disagrees with `val` and
        `default`, and is not reported as inferred.
    '''
    mismatched = BoundVariable('x', 2, Some(str), default=Some(2))
    assert mismatched.var_type == Some(str)
    assert mismatched.var_type_inferred is False

def test_type_inference_uses_common_ancestors():
    ''' Inference across `val` and `default` resolves to their closest
        common ancestor.
    '''
    class TestSuperclass():
        pass

    class TestSubclass1(TestSuperclass):
        pass

    class TestSubclass2(TestSuperclass):
        pass

    # BoundVariables should choose the closest common ancestor of `val` and `default` when performing type inference.
    siblings = BoundVariable('x', TestSubclass1(), default=Some(TestSubclass2()))
    assert siblings.var_type == Some(TestSuperclass)

    # Sometimes, the closest common ancestor is `object`.
    unrelated = BoundVariable('x', TestSubclass1(), default=Some(1))
    assert unrelated.var_type == Some(object)
16 changes: 16 additions & 0 deletions test/workflow/stage/test_StageFromFunction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
'''
Chicory ML Workflow Manager
Copyright (C) 2024 Alexis Maya-Isabelle Shuping
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
from modules.data.Variable import BoundVariable, UnboundVariable
from modules.workflow.stage.StageFromFunction import StageFromFunction, make_StageFromFunction
from modules.types.Option import Some
Expand Down

0 comments on commit 9093985

Please sign in to comment.