Skip to content

Commit

Permalink
Add special numpy array handling in Variable
Browse files Browse the repository at this point in the history
  • Loading branch information
ashuping committed Oct 23, 2024
1 parent b81f6f3 commit b519ddf
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
27 changes: 26 additions & 1 deletion exseos/types/Variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from abc import ABC, abstractmethod
import logging
import numpy as np
from typing import TypeVar, Generic

A: TypeVar = TypeVar("A")
Expand Down Expand Up @@ -104,12 +105,36 @@ def __eq__(self, other: "Variable") -> bool:
if not issubclass(type(other), Variable):
return False

if self.is_bound:
if not other.is_bound:
return False

if type(self.val.val) is not type(other.val.val):
return False

# Some types must be handled specially
match type(self.val.val):
case np.ndarray:
# Need to check dimensions and dtype
if self.val.val.shape != other.val.val.shape:
return False

if self.val.val.dtype != other.val.val.dtype:
return False

# Numpy arrays have to have `.all()` called at the end
if not (self.val.val == other.val.val).all():
return False
case _:
# All other types are compared directly
if self.val.val != other.val.val:
return False

return all(
[
self.name == other.name,
self.is_bound == other.is_bound,
self.desc == other.desc,
(self.val == other.val) if self.is_bound else True,
self.var_type == other.var_type,
self.default == other.default,
]
Expand Down
9 changes: 9 additions & 0 deletions test/types/test_Variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from exseos.types.Option import Nothing, Some

from abc import ABC
import numpy as np
import pytest


Expand All @@ -42,6 +43,7 @@ def test_eq():
assert BoundVariable("x", 2) != UnboundVariable("x")

assert BoundVariable("x", 2, int) == BoundVariable("x", 2, int)
assert BoundVariable("x", 1, int) != BoundVariable("x", '1', str)
assert BoundVariable("x", 1, int) != BoundVariable("x", 1, str)
assert UnboundVariable("y", int) == UnboundVariable("y", int)
assert UnboundVariable("y", int) != UnboundVariable("y", str)
Expand Down Expand Up @@ -70,6 +72,13 @@ def test_eq():
assert UnboundVariable("y") != "y"


def test_numpy_arr_eq():
assert BoundVariable('a', np.array([1, 2, 3])) == BoundVariable('a', np.array([1, 2, 3]))
assert BoundVariable('a', np.array([1, 2, 3])) != BoundVariable('a', np.array([3, 2, 1]))
assert BoundVariable('a', np.array([1, 2, 3])) != BoundVariable('a', np.array([1, 2]))
assert BoundVariable('a', np.array([1, 2, 3])) != BoundVariable('a', np.array(['1', '2', '3']))


def test_val():
assert BoundVariable("x", 2).val == Some(2)
assert BoundVariable("x", None).val == Some(None)
Expand Down

0 comments on commit b519ddf

Please sign in to comment.