Skip to content

Commit

Permalink
Add ComparableError to enable comparing exceptions by value; implemen…
Browse files Browse the repository at this point in the history
…t implicit support in Option/Either monads.
  • Loading branch information
ashuping committed Aug 12, 2024
1 parent 1d16414 commit 23d333e
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 3 deletions.
70 changes: 70 additions & 0 deletions modules/types/ComparableError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
'''
Chicory ML Workflow Manager
Copyright (C) 2024 Alexis Maya-Isabelle Shuping
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''

from typing import List

class ComparableError:
''' Encapsulates an Exception, providing a sensible `__eq__` operation.
'''
def __init__(self, exc: Exception):
self.__exc = exc

@property
def exc(self):
return self.__exc

@classmethod
def encapsulate(cls, exc: any):
''' Encapsulate an Exception in a ComparableError.
If this method is called on anything other than an Exception, the
parameter is returned unchanged.
'''
if issubclass(type(exc), Exception):
return cls(exc)
else:
return exc

@classmethod
def array_encapsulate(cls, arr: List[any]):
''' Encapsulate an entire array of Exceptions
This method is useful for performing array comparisons.
'''
return [ComparableError.encapsulate(e) for e in arr]

def __eq__(self, other):
''' Check whether the encapsulated Exceptions are equal.
Two Exceptions are defined as equal if:
- Both are of the same type
- If one has an `args` attribute:
- The other has an `args` attribute
- The two `args` attributes evaluate to equal.
'''
other = ComparableError.encapsulate(other) # ensure `other` is comparable
if type(self.exc) != type(other.exc):
return False

if hasattr(self.exc, 'args'):
if not hasattr(other.exc, 'args'):
return False
else:
return self.exc.args == other.exc.args
else:
return True
6 changes: 4 additions & 2 deletions modules/types/Either.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''

from modules.types.ComparableError import ComparableError

from typing import TypeVar, Callable
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -81,12 +83,12 @@ def __eq__(self, other) -> bool:
if not other.is_right:
return False

return self.val == other.val
return ComparableError.encapsulate(self.val) == ComparableError.encapsulate(other.val)
else:
if other.is_right:
return False

return self.lval == other.lval
return ComparableError.encapsulate(self.lval) == ComparableError.encapsulate(other.lval)



Expand Down
4 changes: 3 additions & 1 deletion modules/types/Option.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''

from modules.types.ComparableError import ComparableError

from typing import TypeVar, Callable
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -65,7 +67,7 @@ def flat_map(self, f: Callable[[A], 'Option[B]']) -> 'Option[B]':
def __eq__(self, other):
if self.has_val:
if other.has_val:
return self.val == other.val
return ComparableError.encapsulate(self.val) == ComparableError.encapsulate(other.val)
else:
return False
else:
Expand Down
27 changes: 27 additions & 0 deletions test/types/test_ComparableError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from modules.types.ComparableError import ComparableError

def test_get_exc():
x = TypeError("test error")
assert ComparableError(x).exc == x

def test_encapsulate():
class CustomException(Exception):
pass

class NotAnException():
pass

assert type(ComparableError.encapsulate(TypeError())) == ComparableError
assert type(ComparableError.encapsulate(CustomException())) == ComparableError
assert type(ComparableError.encapsulate(12)) == int
assert type(ComparableError.encapsulate("test")) == str
assert type(ComparableError.encapsulate(NotAnException())) == NotAnException

def test_eq():
assert ComparableError(TypeError("test")) == ComparableError(TypeError("test"))
assert ComparableError(TypeError("test")) != ComparableError(TypeError("test2"))
assert ComparableError(TypeError("test")) != ComparableError(ArithmeticError("test"))

def test_auto_encapsulate():
assert ComparableError(TypeError("test")) == TypeError("test")
assert ComparableError(TypeError("test")) != TypeError("test2")
14 changes: 14 additions & 0 deletions test/types/test_Either.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ def test_flat_map_left_to_right():
assert Left("str").flat_map(transform) == Left("str")
assert Left(12).flat_map(transform) == Left(12)

def test_exception_comparison():
# Exceptions use ComparableError's __eq__ operator to allow comparison by
# value.
assert Left(TypeError("test")) == Left(TypeError("test"))
assert Left(TypeError("test")) != Left(TypeError("test2"))
assert Left(TypeError("test")) != Left(ArithmeticError("test"))

assert Right(TypeError("test")) == Right(TypeError("test"))
assert Right(TypeError("test")) != Right(TypeError("test2"))
assert Right(TypeError("test")) != Right(ArithmeticError("test"))

assert Right(TypeError("test")) != Left(TypeError("test"))
assert Left(TypeError("test")) != Right(TypeError("test"))

def test_flat_map_long_chain():
assert Right("str").flat_map(
lambda x: Right(f"precious {x}")
Expand Down
10 changes: 10 additions & 0 deletions test/types/test_Option.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def test_eq():

assert Nothing() == Nothing()

def test_exc_eq():
# ExSeOS-H monads use `ComparableError` to enable comparison of Exceptions
# by value. Exceptions of the same type with the same args are treated as
# equal for the purpose of comparing monads.
assert Some(TypeError("test")) == Some(TypeError("test"))
assert Some(TypeError("test")) != Some(TypeError("test2"))
assert Some(TypeError("test")) != Some(ArithmeticError("test"))
assert Some(TypeError("test")) != Nothing()
assert Nothing() != Some(TypeError("test"))

def test_map_some():
transform = lambda x: f"my {x}"

Expand Down

0 comments on commit 23d333e

Please sign in to comment.