Skip to content

Commit

Permalink
Add locks to guard against non-threadsafe solver behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapriot committed Nov 14, 2023
1 parent 874f935 commit 35d4d5d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
19 changes: 18 additions & 1 deletion pydiso/mkl_solver.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
#cython: linetrace=True
cimport numpy as np
from cython cimport numeric
from cpython.pythread cimport (
PyThread_type_lock,
PyThread_allocate_lock,
PyThread_acquire_lock,
PyThread_release_lock,
PyThread_free_lock
)

import warnings
import numpy as np
Expand Down Expand Up @@ -184,7 +191,7 @@ cdef class MKLPardisoSolver:
cdef int_t _factored
cdef size_t shape[2]
cdef int_t _initialized

cdef PyThread_type_lock lock
cdef void * a

cdef object _data_type
Expand Down Expand Up @@ -253,6 +260,9 @@ cdef class MKLPardisoSolver:
raise ValueError("Matrix is not square")
self.shape = n_row, n_col

# allocate the lock
self.lock = PyThread_allocate_lock()

self._data_type = A.dtype
if matrix_type is None:
if np.issubdtype(self._data_type, np.complexfloating):
Expand Down Expand Up @@ -496,6 +506,7 @@ cdef class MKLPardisoSolver:
cdef long_t phase64=-1, nrhs64=0, error64=0

if self._initialized:
PyThread_acquire_lock(self.lock, 1)
if self._is_32:
pardiso(
self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
Expand All @@ -508,9 +519,12 @@ cdef class MKLPardisoSolver:
&phase64, &self._par64.n, self.a, NULL, NULL, NULL, &nrhs64,
self._par64.iparm, &self._par64.msglvl, NULL, NULL, &error64
)
PyThread_release_lock(self.lock)
err = error or error64
if err!=0:
raise PardisoError("Memmory release error "+_err_messages[err])
#dealloc lock
PyThread_free_lock(self.lock)

cdef _analyze(self):
#phase = 11
Expand Down Expand Up @@ -540,13 +554,16 @@ cdef class MKLPardisoSolver:
cdef int_t error=0
cdef long_t error64=0, phase64=phase, nrhs64=nrhs

PyThread_acquire_lock(self.lock, 1)
if self._is_32:
pardiso(self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
&phase, &self._par.n, self.a, &self._par.ia[0], &self._par.ja[0],
&self._par.perm[0], &nrhs, self._par.iparm, &self._par.msglvl, b, x, &error)
PyThread_release_lock(self.lock)
return error
else:
pardiso_64(self.handle, &self._par64.maxfct, &self._par64.mnum, &self._par64.mtype,
&phase64, &self._par64.n, self.a, &self._par64.ia[0], &self._par64.ja[0],
&self._par64.perm[0], &nrhs64, self._par64.iparm, &self._par64.msglvl, b, x, &error64)
PyThread_release_lock(self.lock)

Check warning on line 568 in pydiso/mkl_solver.pyx

View check run for this annotation

Codecov / codecov/patch

pydiso/mkl_solver.pyx#L568

Added line #L568 was not covered by tests
return error64
23 changes: 23 additions & 0 deletions tests/test_pydiso.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
set_mkl_threads,
set_mkl_pardiso_threads,
)
from concurrent.futures import ThreadPoolExecutor
import pytest
import sys

Expand Down Expand Up @@ -147,3 +148,25 @@ def test_rhs_size_error():
solver.solve(b_bad)
with pytest.raises(ValueError):
solver.solve(b, x_bad)

def test_threading():
"""
Here we test that calling the solver is safe from multiple threads.
There isn't actually any speedup because it acquires a lock on each call
to pardiso internally (because those calls are not thread safe).
"""
n = 200
n_rhs = 75
A = sp.diags([-1, 2, -1], (-1, 0, 1), shape=(n, n), format='csr')
Ainv = Solver(A)

x_true = np.random.rand(n, n_rhs)
rhs = A @ x_true

with ThreadPoolExecutor() as pool:
x_sol = np.stack(
list(pool.map(lambda i: Ainv.solve(rhs[:, i]), range(n_rhs))),
axis=1
)

np.testing.assert_allclose(x_true, x_sol)

0 comments on commit 35d4d5d

Please sign in to comment.