diff --git a/pymatsolver/solvers.py b/pymatsolver/solvers.py index f8d79a4..4d0317c 100644 --- a/pymatsolver/solvers.py +++ b/pymatsolver/solvers.py @@ -47,6 +47,8 @@ class Base(ABC): __numpy_ufunc__ = True __array_ufunc__ = None + _is_conjugate = False + def __init__( self, A, is_symmetric=None, is_positive_definite=False, is_hermitian=None, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs ): @@ -251,7 +253,13 @@ def _transpose_class(self): return self.__class__ def transpose(self): - """Return the transposed solve operator.""" + """Return the transposed solve operator. + + Returns + ------- + pymatsolver.solvers.Base + """ + if self.is_symmetric: return self if self._transpose_class is None: @@ -274,6 +282,23 @@ def T(self): """ return self.transpose() + def conjugate(self): + """Return the complex conjugate version of this solver. + + Returns + ------- + pymatsolver.solvers.Base + """ + if self.is_real: + return self + else: + # make a shallow copy of myself + conjugated = copy.copy(self) + conjugated._is_conjugate = not self._is_conjugate + return conjugated + + conj = conjugate + def _compute_accuracy(self, rhs, x): resid_norm = np.linalg.norm(rhs - self.A @ x) rhs_norm = np.linalg.norm(rhs) @@ -308,6 +333,8 @@ def solve(self, rhs): if ndim == 1: if len(rhs) != n: raise ValueError(f'Expected a vector of length {n}, got {len(rhs)}') + if self._is_conjugate: + rhs = rhs.conjugate() x = self._solve_single(rhs) else: if ndim == 2 and rhs.shape[-1] == 1: @@ -331,6 +358,8 @@ def solve(self, rhs): # (which is more common for direct solvers). rhs = rhs.transpose() # should end up with shape (n, -1) + if self._is_conjugate: + rhs = rhs.conjugate() x = self._solve_multiple(rhs) if do_broadcast: # undo the reshaping above @@ -347,6 +376,9 @@ def solve(self, rhs): #TODO remove this in v0.4.0. if x.size == n: x = x.reshape(-1) + + if self._is_conjugate: + x = x.conjugate() return x @abstractmethod diff --git a/tests/test_conjugate.py b/tests/test_conjugate.py new file mode 100644 index 0000000..4dc2529 --- /dev/null +++ b/tests/test_conjugate.py @@ -0,0 +1,46 @@ +import pytest +import pymatsolver +import numpy as np +import scipy.sparse as sp +import numpy.testing as npt + + +@pytest.mark.parametrize('solver_class', [pymatsolver.Solver, pymatsolver.SolverLU, pymatsolver.Pardiso, pymatsolver.Mumps]) +@pytest.mark.parametrize('dtype', [np.float64, np.complex128]) +@pytest.mark.parametrize('n_rhs', [1, 4]) +def test_conjugate_solve(solver_class, dtype, n_rhs): + if solver_class is pymatsolver.Pardiso and not pymatsolver.AvailableSolvers['Pardiso']: + pytest.skip("pydiso not installed.") + if solver_class is pymatsolver.Mumps and not pymatsolver.AvailableSolvers['Mumps']: + pytest.skip("python-mumps not installed.") + + n = 10 + D = sp.diags(np.linspace(1, 10, n)) + if dtype == np.float64: + L = sp.diags([1, -1], [0, -1], shape=(n, n)) + + sol = np.linspace(0.9, 1.1, n) + # non-symmetric real matrix + else: + # non-symmetric + L = sp.diags([1, -1j], [0, -1], shape=(n, n)) + sol = np.linspace(0.9, 1.1, n) - 1j * np.linspace(0.9, 1.1, n)[::-1] + + if n_rhs > 1: + sol = np.pad(sol[:, None], [(0, 0), (0, n_rhs - 1)], mode='constant') + + A = D @ L @ D @ L.T + + # double check it solves + rhs = A @ sol + Ainv = solver_class(A) + npt.assert_allclose(Ainv @ rhs, sol) + + # is conjugate solve correct? + rhs_conj = A.conjugate() @ sol + Ainv_conj = Ainv.conjugate() + npt.assert_allclose(Ainv_conj @ rhs_conj, sol) + + # is conjugate -> conjugate solve correct? + Ainv2 = Ainv_conj.conjugate() + npt.assert_allclose(Ainv2 @ rhs, sol) \ No newline at end of file