Skip to content

Commit

Permalink
undo some types
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapriot committed Mar 1, 2024
1 parent d018d0b commit 88f7ea8
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions pydiso/mkl_solver.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ import numpy as np
import scipy.sparse as sp
import os

cdef extern from 'mkl.h':
ctypedef long long MKL_INT64
ctypedef unsigned long long MKL_UINT64
ctypedef int MKL_INT

ctypedef MKL_INT int_t
ctypedef MKL_INT64 long_t

ctypedef long long MKL_INT64
ctypedef unsigned long long MKL_UINT64
ctypedef int MKL_INT
cdef extern from 'mkl.h':
int MKL_DOMAIN_PARDISO

ctypedef struct MKLVersion:
Expand All @@ -44,17 +47,17 @@ cdef extern from 'mkl.h':

ctypedef void * _MKL_DSS_HANDLE_t

void pardiso(_MKL_DSS_HANDLE_t, const MKL_INT*, const MKL_INT*, const MKL_INT*,
const MKL_INT *, const MKL_INT *, const void *, const MKL_INT *,
const MKL_INT *, MKL_INT *, const MKL_INT *, MKL_INT *,
const MKL_INT *, void *, void *, MKL_INT *) nogil
void pardiso(_MKL_DSS_HANDLE_t, const int*, const int*, const int*,
const int *, const int *, const void *, const int *,
const int *, int *, const int_t *, int *,
const int *, void *, void *, int *) nogil

void pardiso_64(_MKL_DSS_HANDLE_t, const long long int *, const long long int *, const long long int *,
const long long int *, const long long int *, const void *, const long long int *,
const long long int *, long long int *, const long long int *, long long int *,
const long long int *, void *, void *, long long int *) nogil
void pardiso_64(_MKL_DSS_HANDLE_t, const long_t *, const long_t *, const long_t *,
const long_t *, const long_t *, const void *, const long_t *,
const long_t *, long_t *, const long_t *, long_t *,
const long_t *, void *, void *, long_t *) nogil

if sizeof(MKL_INT) == 4:
if sizeof(int_t) == 4:
_np_mkl_int = np.int32
else:
_np_mkl_int = np.int64
Expand Down Expand Up @@ -172,14 +175,14 @@ def get_mkl_version():
return vers

cdef class _PardisoParams:
cdef MKL_INT iparm[64]
cdef MKL_INT n, mtype, maxfct, mnum, msglvl
cdef MKL_INT[:] ia, ja, perm
cdef int_t iparm[64]
cdef int_t n, mtype, maxfct, mnum, msglvl
cdef int_t[:] ia, ja, perm

cdef class _PardisoParams64:
cdef MKL_INT64 iparm[64]
cdef MKL_INT64 n, mtype, maxfct, mnum, msglvl
cdef MKL_INT64[:] ia, ja, perm
cdef long_t iparm[64]
cdef long_t n, mtype, maxfct, mnum, msglvl
cdef long_t[:] ia, ja, perm

ctypedef fused _par_params:
_PardisoParams
Expand Down Expand Up @@ -303,7 +306,7 @@ cdef class MKLPardisoSolver:
integer_len = A.indices.itemsize
# we only need to call the 64 bit version if
# sizeof(MKL_INT) == 4 and A.indices.itemsize == 8
self._call32 = not (sizeof(MKL_INT) == 4 and integer_len == 8)
self._call32 = not (sizeof(int_t) == 4 and integer_len == 8)
if self._call32:
self._par = _PardisoParams()
self._initialize(self._par, A, matrix_type, verbose)
Expand Down Expand Up @@ -413,7 +416,7 @@ cdef class MKLPardisoSolver:
if bp == xp:
raise PardisoError("b and x must be different arrays")

cdef MKL_INT nrhs = b.shape[1] if b.ndim == 2 else 1
cdef int_t nrhs = b.shape[1] if b.ndim == 2 else 1

if transpose:
self.set_iparm(11, 2)
Expand All @@ -440,7 +443,7 @@ cdef class MKLPardisoSolver:
else:
return np.array(self._par64.iparm)

def set_iparm(self, MKL_INT i, MKL_INT val):
def set_iparm(self, int i, int val):
if i > 63 or i < 0:
raise IndexError(f"index {i} is out of bounds for size 64 array")
if i not in [
Expand Down Expand Up @@ -514,8 +517,8 @@ cdef class MKLPardisoSolver:

def __del__(self):
# Need to call pardiso with phase=-1 to release memory
cdef MKL_INT phase=-1, nrhs=0, error=0
cdef MKL_INT64 phase64=-1, nrhs64=0, error64=0
cdef int_t phase=-1, nrhs=0, error=0
cdef long_t phase64=-1, nrhs64=0, error64=0

# Only need to deallocate if the handle itself was ever allocated.
if self._initialized():
Expand Down Expand Up @@ -561,7 +564,7 @@ cdef class MKLPardisoSolver:

self._factored = True

cdef _solve(self, void* b, void* x, MKL_INT nrhs_in):
cdef _solve(self, void* b, void* x, int_t nrhs_in):
#phase = 33
if(not self._factored):
raise PardisoError("Cannot solve without a previous factorization.")
Expand All @@ -571,9 +574,9 @@ cdef class MKLPardisoSolver:
raise PardisoError("Solve step error, "+_err_messages[err])

@cython.boundscheck(False)
cdef MKL_INT _run_pardiso(self, MKL_INT phase, void* b=NULL, void* x=NULL, MKL_INT nrhs=0) nogil:
cdef MKL_INT error=0
cdef MKL_INT64 error64=0, phase64=phase, nrhs64=nrhs
cdef int _run_pardiso(self, int_t phase, void* b=NULL, void* x=NULL, int_t nrhs=0) nogil:
cdef int_t error=0
cdef long_t error64=0, phase64=phase, nrhs64=nrhs

PyThread_acquire_lock(self.lock, 1)
if self._call32:
Expand All @@ -585,6 +588,6 @@ cdef class MKLPardisoSolver:
&phase64, &self._par64.n, self.a, &self._par64.ia[0], &self._par64.ja[0],
&self._par64.perm[0], &nrhs64, self._par64.iparm, &self._par64.msglvl, b, x, &error64)
PyThread_release_lock(self.lock)
error = error or <MKL_INT> error64
error = error or <int_t> error64
return error

0 comments on commit 88f7ea8

Please sign in to comment.