Skip to content

Commit

Permalink
Merge pull request #50 from simpeg/conjugate
Browse files Browse the repository at this point in the history
Add conjugate solve
  • Loading branch information
jcapriot authored Oct 11, 2024
2 parents 32ade92 + 7d1dc30 commit 7f81ada
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
34 changes: 33 additions & 1 deletion pymatsolver/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class Base(ABC):
__numpy_ufunc__ = True
__array_ufunc__ = None

_is_conjugate = False

def __init__(
self, A, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs
):
Expand Down Expand Up @@ -251,7 +253,13 @@ def _transpose_class(self):
return self.__class__

def transpose(self):
"""Return the transposed solve operator."""
"""Return the transposed solve operator.
Returns
-------
pymatsolver.solvers.Base
"""

if self.is_symmetric:
return self
if self._transpose_class is None:
Expand All @@ -274,6 +282,23 @@ def T(self):
"""
return self.transpose()

def conjugate(self):
"""Return the complex conjugate version of this solver.
Returns
-------
pymatsolver.solvers.Base
"""
if self.is_real:
return self
else:
# make a shallow copy of myself
conjugated = copy.copy(self)
conjugated._is_conjugate = not self._is_conjugate
return conjugated

conj = conjugate

def _compute_accuracy(self, rhs, x):
resid_norm = np.linalg.norm(rhs - self.A @ x)
rhs_norm = np.linalg.norm(rhs)
Expand Down Expand Up @@ -308,6 +333,8 @@ def solve(self, rhs):
if ndim == 1:
if len(rhs) != n:
raise ValueError(f'Expected a vector of length {n}, got {len(rhs)}')
if self._is_conjugate:
rhs = rhs.conjugate()
x = self._solve_single(rhs)
else:
if ndim == 2 and rhs.shape[-1] == 1:
Expand All @@ -331,6 +358,8 @@ def solve(self, rhs):
# (which is more common for direct solvers).
rhs = rhs.transpose()
# should end up with shape (n, -1)
if self._is_conjugate:
rhs = rhs.conjugate()
x = self._solve_multiple(rhs)
if do_broadcast:
# undo the reshaping above
Expand All @@ -347,6 +376,9 @@ def solve(self, rhs):
#TODO remove this in v0.4.0.
if x.size == n:
x = x.reshape(-1)

if self._is_conjugate:
x = x.conjugate()
return x

@abstractmethod
Expand Down
46 changes: 46 additions & 0 deletions tests/test_conjugate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pytest
import pymatsolver
import numpy as np
import scipy.sparse as sp
import numpy.testing as npt


@pytest.mark.parametrize('solver_class', [pymatsolver.Solver, pymatsolver.SolverLU, pymatsolver.Pardiso, pymatsolver.Mumps])
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
@pytest.mark.parametrize('n_rhs', [1, 4])
def test_conjugate_solve(solver_class, dtype, n_rhs):
if solver_class is pymatsolver.Pardiso and not pymatsolver.AvailableSolvers['Pardiso']:
pytest.skip("pydiso not installed.")
if solver_class is pymatsolver.Mumps and not pymatsolver.AvailableSolvers['Mumps']:
pytest.skip("python-mumps not installed.")

n = 10
D = sp.diags(np.linspace(1, 10, n))
if dtype == np.float64:
L = sp.diags([1, -1], [0, -1], shape=(n, n))

sol = np.linspace(0.9, 1.1, n)
# non-symmetric real matrix
else:
# non-symmetric
L = sp.diags([1, -1j], [0, -1], shape=(n, n))
sol = np.linspace(0.9, 1.1, n) - 1j * np.linspace(0.9, 1.1, n)[::-1]

if n_rhs > 1:
sol = np.pad(sol[:, None], [(0, 0), (0, n_rhs - 1)], mode='constant')

A = D @ L @ D @ L.T

# double check it solves
rhs = A @ sol
Ainv = solver_class(A)
npt.assert_allclose(Ainv @ rhs, sol)

# is conjugate solve correct?
rhs_conj = A.conjugate() @ sol
Ainv_conj = Ainv.conjugate()
npt.assert_allclose(Ainv_conj @ rhs_conj, sol)

# is conjugate -> conjugate solve correct?
Ainv2 = Ainv_conj.conjugate()
npt.assert_allclose(Ainv2 @ rhs, sol)

0 comments on commit 7f81ada

Please sign in to comment.