diff --git a/modules/types/ComparableError.py b/modules/types/ComparableError.py
new file mode 100644
index 0000000..6e036ac
--- /dev/null
+++ b/modules/types/ComparableError.py
@@ -0,0 +1,70 @@
+'''
+ Chicory ML Workflow Manager
+ Copyright (C) 2024 Alexis Maya-Isabelle Shuping
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+'''
+
+from typing import List
+
+class ComparableError:
+ ''' Encapsulates an Exception, providing a sensible `__eq__` operation.
+ '''
+ def __init__(self, exc: Exception):
+ self.__exc = exc
+
+ @property
+ def exc(self):
+ return self.__exc
+
+ @classmethod
+ def encapsulate(cls, exc: any):
+ ''' Encapsulate an Exception in a ComparableError.
+
+ If this method is called on anything other than an Exception, the
+ parameter is returned unchanged.
+ '''
+ if issubclass(type(exc), Exception):
+ return cls(exc)
+ else:
+ return exc
+
+ @classmethod
+ def array_encapsulate(cls, arr: List[any]):
+ ''' Encapsulate an entire array of Exceptions
+
+ This method is useful for performing array comparisons.
+ '''
+ return [ComparableError.encapsulate(e) for e in arr]
+
+ def __eq__(self, other):
+ ''' Check whether the encapsulated Exceptions are equal.
+
+ Two Exceptions are defined as equal if:
+ - Both are of the same type
+ - If one has an `args` attribute:
+ - The other has an `args` attribute
+ - The two `args` attributes evaluate to equal.
+ '''
+ other = ComparableError.encapsulate(other) # ensure `other` is comparable
+ if type(self.exc) != type(other.exc):
+ return False
+
+ if hasattr(self.exc, 'args'):
+ if not hasattr(other.exc, 'args'):
+ return False
+ else:
+ return self.exc.args == other.exc.args
+ else:
+ return True
\ No newline at end of file
diff --git a/modules/types/Either.py b/modules/types/Either.py
index 76fdb9a..286aac3 100644
--- a/modules/types/Either.py
+++ b/modules/types/Either.py
@@ -16,6 +16,8 @@
along with this program. If not, see .
'''
+from modules.types.ComparableError import ComparableError
+
from typing import TypeVar, Callable
from abc import ABC, abstractmethod
@@ -81,12 +83,12 @@ def __eq__(self, other) -> bool:
if not other.is_right:
return False
- return self.val == other.val
+ return ComparableError.encapsulate(self.val) == ComparableError.encapsulate(other.val)
else:
if other.is_right:
return False
- return self.lval == other.lval
+ return ComparableError.encapsulate(self.lval) == ComparableError.encapsulate(other.lval)
diff --git a/modules/types/Option.py b/modules/types/Option.py
index aac3ecc..3f03beb 100644
--- a/modules/types/Option.py
+++ b/modules/types/Option.py
@@ -16,6 +16,8 @@
along with this program. If not, see .
'''
+from modules.types.ComparableError import ComparableError
+
from typing import TypeVar, Callable
from abc import ABC, abstractmethod
@@ -65,7 +67,7 @@ def flat_map(self, f: Callable[[A], 'Option[B]']) -> 'Option[B]':
def __eq__(self, other):
if self.has_val:
if other.has_val:
- return self.val == other.val
+ return ComparableError.encapsulate(self.val) == ComparableError.encapsulate(other.val)
else:
return False
else:
diff --git a/test/types/test_ComparableError.py b/test/types/test_ComparableError.py
new file mode 100644
index 0000000..2f90deb
--- /dev/null
+++ b/test/types/test_ComparableError.py
@@ -0,0 +1,27 @@
+from modules.types.ComparableError import ComparableError
+
+def test_get_exc():
+ x = TypeError("test error")
+ assert ComparableError(x).exc == x
+
+def test_encapsulate():
+ class CustomException(Exception):
+ pass
+
+ class NotAnException():
+ pass
+
+ assert type(ComparableError.encapsulate(TypeError())) == ComparableError
+ assert type(ComparableError.encapsulate(CustomException())) == ComparableError
+ assert type(ComparableError.encapsulate(12)) == int
+ assert type(ComparableError.encapsulate("test")) == str
+ assert type(ComparableError.encapsulate(NotAnException())) == NotAnException
+
+def test_eq():
+ assert ComparableError(TypeError("test")) == ComparableError(TypeError("test"))
+ assert ComparableError(TypeError("test")) != ComparableError(TypeError("test2"))
+ assert ComparableError(TypeError("test")) != ComparableError(ArithmeticError("test"))
+
+def test_auto_encapsulate():
+ assert ComparableError(TypeError("test")) == TypeError("test")
+ assert ComparableError(TypeError("test")) != TypeError("test2")
\ No newline at end of file
diff --git a/test/types/test_Either.py b/test/types/test_Either.py
index 57d7a11..52ee904 100644
--- a/test/types/test_Either.py
+++ b/test/types/test_Either.py
@@ -84,6 +84,20 @@ def test_flat_map_left_to_right():
assert Left("str").flat_map(transform) == Left("str")
assert Left(12).flat_map(transform) == Left(12)
+def test_exception_comparison():
+ # Exceptions use ComparableError's __eq__ operator to allow comparison by
+ # value.
+ assert Left(TypeError("test")) == Left(TypeError("test"))
+ assert Left(TypeError("test")) != Left(TypeError("test2"))
+ assert Left(TypeError("test")) != Left(ArithmeticError("test"))
+
+ assert Right(TypeError("test")) == Right(TypeError("test"))
+ assert Right(TypeError("test")) != Right(TypeError("test2"))
+ assert Right(TypeError("test")) != Right(ArithmeticError("test"))
+
+ assert Right(TypeError("test")) != Left(TypeError("test"))
+ assert Left(TypeError("test")) != Right(TypeError("test"))
+
def test_flat_map_long_chain():
assert Right("str").flat_map(
lambda x: Right(f"precious {x}")
diff --git a/test/types/test_Option.py b/test/types/test_Option.py
index bbd780f..afa4740 100644
--- a/test/types/test_Option.py
+++ b/test/types/test_Option.py
@@ -42,6 +42,16 @@ def test_eq():
assert Nothing() == Nothing()
+def test_exc_eq():
+ # ExSeOS-H monads use `ComparableError` to enable comparison of Exceptions
+ # by value. Exceptions of the same type with the same args are treated as
+ # equal for the purpose of comparing monads.
+ assert Some(TypeError("test")) == Some(TypeError("test"))
+ assert Some(TypeError("test")) != Some(TypeError("test2"))
+ assert Some(TypeError("test")) != Some(ArithmeticError("test"))
+ assert Some(TypeError("test")) != Nothing()
+ assert Nothing() != Some(TypeError("test"))
+
def test_map_some():
transform = lambda x: f"my {x}"