Skip to content

Commit

Permalink
complex
Browse files Browse the repository at this point in the history
  • Loading branch information
gertjanvanzwieten committed Apr 17, 2024
1 parent ff5743a commit 4f18cde
Showing 1 changed file with 9 additions and 36 deletions.
45 changes: 9 additions & 36 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,10 +852,6 @@ def _node(self, cache, subgraph, times):

# simplifications
_eig = lambda self, symmetric: None
_loopsum = lambda self, loop_index: None # NOTE: type of `loop_index` is `_LoopIndex`
_real = lambda self: None
_imag = lambda self: None
_conjugate = lambda self: None

representations = ()

Expand Down Expand Up @@ -1108,9 +1104,6 @@ def evalf(func, length):
def _derivative(self, var, seen):
return insertaxis(derivative(self.func, var, seen), self.ndim-1, self.length)

def _loopsum(self, index):
return InsertAxis(loop_sum(self.func, index), self.length)

@cached_property
def _assparse(self):
return tuple((*(InsertAxis(idx, self.length) for idx in indices), prependaxes(Range(self.length), values.shape), InsertAxis(values, self.length)) for *indices, values in self.func._assparse)
Expand Down Expand Up @@ -1278,9 +1271,6 @@ def _axes_for(self, ndim, axis):
axes[axis:axis] = range(funcaxis, funcaxis + ndim)
return tuple(axes)

def _loopsum(self, index):
return Transpose(loop_sum(self.func, index), self.axes)

@cached_property
def _assparse(self):
return tuple((*(indices[i] for i in self.axes), values) for *indices, values in self.func._assparse)
Expand Down Expand Up @@ -2573,9 +2563,8 @@ def return_type(T):
return complex

def _simplified(self):
retval = self.args[0]._conjugate()
if retval is not None:
return retval
for f, in self.args[0]._as(FloatToComplex):
return FloatToComplex(f)
return super()._simplified()


Expand All @@ -2588,9 +2577,8 @@ def return_type(T):
return float

def _simplified(self):
retval = self.args[0]._real()
if retval is not None:
return retval
for f, in self.args[0]._as(FloatToComplex):
return f
return super()._simplified()


Expand All @@ -2603,9 +2591,8 @@ def return_type(T):
return float

def _simplified(self):
retval = self.args[0]._imag()
if retval is not None:
return retval
for f, in self.args[0]._as(FloatToComplex):
return zeros_like(self)
return super()._simplified()


Expand Down Expand Up @@ -2652,15 +2639,6 @@ def return_type(T):
raise TypeError(f'Expected an array with dtype float but got {T.__name__}.')
return complex

def _real(self):
return self.args[0]

def _imag(self):
return zeros_like(self.args[0])

def _conjugate(self):
return self

def _derivative(self, var, seen):
if var.dtype == complex:
raise ValueError('The complex derivative does not exist.')
Expand Down Expand Up @@ -3101,9 +3079,6 @@ def evalf(arr):
def _derivative(self, var, seen):
return diagonalize(derivative(self.func, var, seen), self.ndim-2, self.ndim-1)

def _loopsum(self, index):
return Diagonalize(loop_sum(self.func, index))

@cached_property
def _assparse(self):
return tuple((*indices, indices[-1], values) for *indices, values in self.func._assparse)
Expand Down Expand Up @@ -3343,9 +3318,6 @@ def evalf(f):
def _derivative(self, var, seen):
return ravel(derivative(self.func, var, seen), axis=self.ndim-1)

def _loopsum(self, index):
return Ravel(loop_sum(self.func, index))

@cached_property
def _assparse(self):
return tuple((*indices[:-2], indices[-2]*self.func.shape[-1]+indices[-1], values) for *indices, values in self.func._assparse)
Expand Down Expand Up @@ -4320,9 +4292,10 @@ def _node_loop_body(self, cache, subgraph, times):
def _simplified(self):
if isinstance(self.func, Zeros):
return zeros_like(self)
elif self.index not in self.func.arguments:
if self.index not in self.func.arguments:
return self.func * astype(self.index.length, self.func.dtype)
return self.func._loopsum(self.index)
if simple := self._as_any(insertaxis, diagonalize, ravel):
return simple

@cached_property
def _assparse(self):
Expand Down

0 comments on commit 4f18cde

Please sign in to comment.