diff --git a/modules/types/ComparableError.py b/modules/types/ComparableError.py new file mode 100644 index 0000000..6e036ac --- /dev/null +++ b/modules/types/ComparableError.py @@ -0,0 +1,70 @@ +''' + Chicory ML Workflow Manager + Copyright (C) 2024 Alexis Maya-Isabelle Shuping + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +''' + +from typing import List + +class ComparableError: + ''' Encapsulates an Exception, providing a sensible `__eq__` operation. + ''' + def __init__(self, exc: Exception): + self.__exc = exc + + @property + def exc(self): + return self.__exc + + @classmethod + def encapsulate(cls, exc: any): + ''' Encapsulate an Exception in a ComparableError. + + If this method is called on anything other than an Exception, the + parameter is returned unchanged. + ''' + if issubclass(type(exc), Exception): + return cls(exc) + else: + return exc + + @classmethod + def array_encapsulate(cls, arr: List[any]): + ''' Encapsulate an entire array of Exceptions + + This method is useful for performing array comparisons. + ''' + return [ComparableError.encapsulate(e) for e in arr] + + def __eq__(self, other): + ''' Check whether the encapsulated Exceptions are equal. + + Two Exceptions are defined as equal if: + - Both are of the same type + - If one has an `args` attribute: + - The other has an `args` attribute + - The two `args` attributes evaluate to equal. + ''' + other = ComparableError.encapsulate(other) # ensure `other` is comparable + if type(self.exc) != type(other.exc): + return False + + if hasattr(self.exc, 'args'): + if not hasattr(other.exc, 'args'): + return False + else: + return self.exc.args == other.exc.args + else: + return True \ No newline at end of file diff --git a/modules/types/Either.py b/modules/types/Either.py index 76fdb9a..286aac3 100644 --- a/modules/types/Either.py +++ b/modules/types/Either.py @@ -16,6 +16,8 @@ along with this program. If not, see . ''' +from modules.types.ComparableError import ComparableError + from typing import TypeVar, Callable from abc import ABC, abstractmethod @@ -81,12 +83,12 @@ def __eq__(self, other) -> bool: if not other.is_right: return False - return self.val == other.val + return ComparableError.encapsulate(self.val) == ComparableError.encapsulate(other.val) else: if other.is_right: return False - return self.lval == other.lval + return ComparableError.encapsulate(self.lval) == ComparableError.encapsulate(other.lval) diff --git a/modules/types/Option.py b/modules/types/Option.py index aac3ecc..3f03beb 100644 --- a/modules/types/Option.py +++ b/modules/types/Option.py @@ -16,6 +16,8 @@ along with this program. If not, see . ''' +from modules.types.ComparableError import ComparableError + from typing import TypeVar, Callable from abc import ABC, abstractmethod @@ -65,7 +67,7 @@ def flat_map(self, f: Callable[[A], 'Option[B]']) -> 'Option[B]': def __eq__(self, other): if self.has_val: if other.has_val: - return self.val == other.val + return ComparableError.encapsulate(self.val) == ComparableError.encapsulate(other.val) else: return False else: diff --git a/test/types/test_ComparableError.py b/test/types/test_ComparableError.py new file mode 100644 index 0000000..2f90deb --- /dev/null +++ b/test/types/test_ComparableError.py @@ -0,0 +1,27 @@ +from modules.types.ComparableError import ComparableError + +def test_get_exc(): + x = TypeError("test error") + assert ComparableError(x).exc == x + +def test_encapsulate(): + class CustomException(Exception): + pass + + class NotAnException(): + pass + + assert type(ComparableError.encapsulate(TypeError())) == ComparableError + assert type(ComparableError.encapsulate(CustomException())) == ComparableError + assert type(ComparableError.encapsulate(12)) == int + assert type(ComparableError.encapsulate("test")) == str + assert type(ComparableError.encapsulate(NotAnException())) == NotAnException + +def test_eq(): + assert ComparableError(TypeError("test")) == ComparableError(TypeError("test")) + assert ComparableError(TypeError("test")) != ComparableError(TypeError("test2")) + assert ComparableError(TypeError("test")) != ComparableError(ArithmeticError("test")) + +def test_auto_encapsulate(): + assert ComparableError(TypeError("test")) == TypeError("test") + assert ComparableError(TypeError("test")) != TypeError("test2") \ No newline at end of file diff --git a/test/types/test_Either.py b/test/types/test_Either.py index 57d7a11..52ee904 100644 --- a/test/types/test_Either.py +++ b/test/types/test_Either.py @@ -84,6 +84,20 @@ def test_flat_map_left_to_right(): assert Left("str").flat_map(transform) == Left("str") assert Left(12).flat_map(transform) == Left(12) +def test_exception_comparison(): + # Exceptions use ComparableError's __eq__ operator to allow comparison by + # value. + assert Left(TypeError("test")) == Left(TypeError("test")) + assert Left(TypeError("test")) != Left(TypeError("test2")) + assert Left(TypeError("test")) != Left(ArithmeticError("test")) + + assert Right(TypeError("test")) == Right(TypeError("test")) + assert Right(TypeError("test")) != Right(TypeError("test2")) + assert Right(TypeError("test")) != Right(ArithmeticError("test")) + + assert Right(TypeError("test")) != Left(TypeError("test")) + assert Left(TypeError("test")) != Right(TypeError("test")) + def test_flat_map_long_chain(): assert Right("str").flat_map( lambda x: Right(f"precious {x}") diff --git a/test/types/test_Option.py b/test/types/test_Option.py index bbd780f..afa4740 100644 --- a/test/types/test_Option.py +++ b/test/types/test_Option.py @@ -42,6 +42,16 @@ def test_eq(): assert Nothing() == Nothing() +def test_exc_eq(): + # ExSeOS-H monads use `ComparableError` to enable comparison of Exceptions + # by value. Exceptions of the same type with the same args are treated as + # equal for the purpose of comparing monads. + assert Some(TypeError("test")) == Some(TypeError("test")) + assert Some(TypeError("test")) != Some(TypeError("test2")) + assert Some(TypeError("test")) != Some(ArithmeticError("test")) + assert Some(TypeError("test")) != Nothing() + assert Nothing() != Some(TypeError("test")) + def test_map_some(): transform = lambda x: f"my {x}"