Skip to content

Commit

Permalink
Defer return_scalar for OperatorVariables
Browse files Browse the repository at this point in the history
  • Loading branch information
guyer committed Jun 20, 2024
1 parent 908825d commit 6e6c68e
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
6 changes: 5 additions & 1 deletion fipy/variables/binaryOperatorVariable.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ def _calcValue_(self):
self.var[1] = physicalField.PhysicalField(value=self.var[1])
val1 = self.var[1]

return self.op(self.var[0].value, val1)
value = self.op(self.var[0].value, val1)
if self.return_scalar:
value = value[()]

return value

@property
def unit(self):
Expand Down
3 changes: 2 additions & 1 deletion fipy/variables/operatorVariable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def _OperatorVariableClass(baseClass=object):
class _OperatorVariable(baseClass):
def __init__(self, op, var, opShape=(), canInline=True, unit=None, inlineComment=None, valueMattersForUnit=None, *args, **kwargs):
def __init__(self, op, var, opShape=(), canInline=True, unit=None, inlineComment=None, valueMattersForUnit=None, return_scalar=False, *args, **kwargs):
self.op = op
self.var = var
self.opShape = opShape
Expand All @@ -23,6 +23,7 @@ def __init__(self, op, var, opShape=(), canInline=True, unit=None, inlineComment
self.valueMattersForUnit = [False for v in var]
else:
self.valueMattersForUnit = valueMattersForUnit
self.return_scalar = return_scalar
self.canInline = canInline #allows for certain functions to opt out of --inline
baseClass.__init__(self, value=None, *args, **kwargs)
self.name = ''
Expand Down
6 changes: 5 additions & 1 deletion fipy/variables/unaryOperatorVariable.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def _UnaryOperatorVariable(operatorClass=None):

class unOp(operatorClass):
def _calcValue_(self):
return self.op(self.var[0].value)
value = self.op(self.var[0].value)
if self.return_scalar:
value = value[()]

return value

@property
def unit(self):
Expand Down
23 changes: 13 additions & 10 deletions fipy/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,12 @@ def __makeVariable(v):
cannotInline = ["expi", "logical_and", "logical_or", "logical_not", "logical_xor", "sign",
"conjugate", "dot", "allclose", "allequal"]
if len(args) == 1:
result = args[0]._UnaryOperatorVariable(op=func, opShape=arr.shape, canInline=func.__name__ not in cannotInline)
result = args[0]._UnaryOperatorVariable(op=func, opShape=arr.shape, canInline=func.__name__ not in cannotInline, return_scalar=return_scalar)
elif len(args) == 2:
result = args[0]._BinaryOperatorVariable(op=func, other=args[1], opShape=arr.shape, canInline=func.__name__ not in cannotInline)
result = args[0]._BinaryOperatorVariable(op=func, other=args[1], opShape=arr.shape, canInline=func.__name__ not in cannotInline, return_scalar=return_scalar)
else:
result = NotImplemented

if return_scalar and result != NotImplemented:
# NumPy 2.0 compatibility
result = result[()]

return result

def __array__(self, dtype=None, copy=None):
Expand Down Expand Up @@ -989,7 +985,7 @@ def _OperatorVariableClass(self, baseClass=None):
return operatorVariable._OperatorVariableClass(baseClass=baseClass)

def _UnaryOperatorVariable(self, op, operatorClass=None, opShape=None, canInline=True, unit=None,
valueMattersForUnit=False):
valueMattersForUnit=False, return_scalar=False):
"""
Check that unit works for `unOp`
Expand All @@ -1008,6 +1004,9 @@ def _UnaryOperatorVariable(self, op, operatorClass=None, opShape=None, canInline
valueMattersForUnit : bool
Whether value of `self` should be used when determining unit,
e.g., ???
return_scalar : bool
Whether to reduce returned zero rank array to a scalar
# Introduced in NumPy 2.0
"""
operatorClass = operatorClass or self._OperatorVariableClass()
from fipy.variables import unaryOperatorVariable
Expand All @@ -1026,7 +1025,7 @@ def _UnaryOperatorVariable(self, op, operatorClass=None, opShape=None, canInline

return unOp(op=op, var=[self], opShape=opShape, canInline=canInline, unit=unit,
inlineComment=inline._operatorVariableComment(canInline=canInline),
valueMattersForUnit=[valueMattersForUnit])
valueMattersForUnit=[valueMattersForUnit], return_scalar=return_scalar)

def _shapeClassAndOther(self, opShape, operatorClass, other):
"""
Expand All @@ -1047,7 +1046,7 @@ def _shapeClassAndOther(self, opShape, operatorClass, other):
return (opShape, baseClass, other)

def _BinaryOperatorVariable(self, op, other, operatorClass=None, opShape=None, canInline=True, unit=None,
value0mattersForUnit=False, value1mattersForUnit=False):
value0mattersForUnit=False, value1mattersForUnit=False, return_scalar=False):
"""
Parameters
----------
Expand All @@ -1066,6 +1065,9 @@ def _BinaryOperatorVariable(self, op, other, operatorClass=None, opShape=None, c
value1MattersForUnit : bool
Whether value of `self` should be used when determining unit,
e.g., `__pow__`
return_scalar : bool
Whether to reduce returned zero rank array to a scalar
# Introduced in NumPy 2.0
"""
if not isinstance(other, Variable):
from fipy.variables.constant import _Constant
Expand All @@ -1087,7 +1089,8 @@ def _BinaryOperatorVariable(self, op, other, operatorClass=None, opShape=None, c

return binOp(op=op, var=[self, other], opShape=opShape, canInline=canInline, unit=unit,
inlineComment=inline._operatorVariableComment(canInline=canInline),
valueMattersForUnit=[value0mattersForUnit, value1mattersForUnit])
valueMattersForUnit=[value0mattersForUnit, value1mattersForUnit],
return_scalar=return_scalar)

def __add__(self, other):
from fipy.terms.term import Term
Expand Down

0 comments on commit 6e6c68e

Please sign in to comment.