Skip to content

Commit

Permalink
Merge pull request #6 from jcapriot/transpose_solve
Browse files Browse the repository at this point in the history
add transpose option to solve call
  • Loading branch information
jcapriot authored Nov 14, 2023
2 parents bbd5658 + fee2501 commit 8a92a13
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
10 changes: 8 additions & 2 deletions pydiso/mkl_solver.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ cdef class MKLPardisoSolver:
def __call__(self, b):
return self.solve(b)

def solve(self, b, x=None):
def solve(self, b, x=None, transpose=False):
"""solve(self, b, x=None, transpose=False)
Solves the equation AX=B using the factored A matrix
Expand All @@ -354,6 +354,8 @@ cdef class MKLPardisoSolver:
x : numpy.ndarray, optional
A pre-allocated output array (of the same data type as A).
If None, a new array is constructed.
transpose : bool, optional
If True, it will solve A^TX=B using the factored A matrix.
Returns
-------
Expand Down Expand Up @@ -388,6 +390,10 @@ cdef class MKLPardisoSolver:

cdef int_t nrhs = b.shape[1] if b.ndim == 2 else 1

if transpose:
self.set_iparm(11, 2)
else:
self.set_iparm(11, 0)
self._solve(bp, xp, nrhs)
return x

Expand Down Expand Up @@ -420,7 +426,7 @@ cdef class MKLPardisoSolver:
if self._is_32:
self._par.iparm[i] = val
else:
self._par.iparm[i] = val
self._par64.iparm[i] = val

@property
def nnz(self):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def configuration(parent_package="", top_path=None):
python_requires=">=3.8",
setup_requires=[
"numpy>=1.8",
"cython>=3.0",
"cython>=0.29.31",
],
install_requires=[
'numpy>=1.8',
Expand Down
24 changes: 20 additions & 4 deletions tests/test_pydiso.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
set_mkl_pardiso_threads,
)
import pytest
import sys

np.random.seed(12345)
n = 40
Expand Down Expand Up @@ -39,6 +40,7 @@
}


@pytest.mark.xfail(sys.platform == "darwin", reason="Unexpected Thread bug in third party library")
def test_thread_setting():
n1 = get_mkl_max_threads()
n2 = get_mkl_pardiso_max_threads()
Expand Down Expand Up @@ -93,8 +95,22 @@ def test_solver(A, matrix_type):
x2 = solver.solve(b)

eps = np.finfo(dtype).eps
rel_err = np.linalg.norm(x-x2)/np.linalg.norm(x)
assert rel_err < 1E3*eps
np.testing.assert_allclose(x, x2, atol=1E3*eps)

@pytest.mark.parametrize("A, matrix_type", inputs)
def test_transpose_solver(A, matrix_type):
dtype = A.dtype
if np.issubdtype(dtype, np.complexfloating):
x = xc.astype(dtype)
else:
x = xr.astype(dtype)
b = A.T @ x

solver = Solver(A, matrix_type=matrix_type)
x2 = solver.solve(b, transpose=True)

eps = np.finfo(dtype).eps
np.testing.assert_allclose(x, x2, atol=1E3*eps)

def test_multiple_RHS():
A = A_real_dict["real_symmetric_positive_definite"]
Expand All @@ -105,8 +121,7 @@ def test_multiple_RHS():
x2 = solver.solve(b)

eps = np.finfo(np.float64).eps
rel_err = np.linalg.norm(x-x2)/np.linalg.norm(x)
assert rel_err < 1E3*eps
np.testing.assert_allclose(x, x2, atol=1E3*eps)


def test_matrix_type_errors():
Expand All @@ -119,6 +134,7 @@ def test_matrix_type_errors():
solver = Solver(A, matrix_type="real_symmetric_positive_definite")



def test_rhs_size_error():
A = A_real_dict["real_symmetric_positive_definite"]
solver = Solver(A, "real_symmetric_positive_definite")
Expand Down

0 comments on commit 8a92a13

Please sign in to comment.