Skip to content

Commit

Permalink
Fixing cython import
Browse files Browse the repository at this point in the history
  • Loading branch information
mikaem committed Feb 7, 2024
1 parent 27e0a23 commit 7244f6f
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 17 deletions.
6 changes: 3 additions & 3 deletions shenfun/forms/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import sympy as sp
from shenfun.config import config
from shenfun.optimization.cython import evaluate
from shenfun.optimization import cython
from shenfun.spectralbase import BoundaryConditions
from mpi4py_fft import DistArray

Expand Down Expand Up @@ -775,10 +775,10 @@ def eval(self, x, output_array=None):
work = np.dot(P, bv)

elif len(x) == 2:
work = evaluate.evaluate_2D(work, bv, M, r2c, last_conj_index, sl)
work = cython.evaluate.evaluate_2D(work, bv, M, r2c, last_conj_index, sl)

elif len(x) == 3:
work = evaluate.evaluate_3D(work, bv, M, r2c, last_conj_index, sl)
work = cython.evaluate.evaluate_3D(work, bv, M, r2c, last_conj_index, sl)

sc = self.scales()[vec][base_j]
if not hasattr(sc, 'free_symbols'):
Expand Down
6 changes: 3 additions & 3 deletions shenfun/fourier/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import numpy as np
from mpi4py_fft import fftw
from shenfun.spectralbase import SpectralBase, Transform, islicedict, slicedict
from shenfun.optimization.cython import convolve
from shenfun.optimization import cython
from shenfun.config import config

bases = ['R2C', 'C2C']
Expand Down Expand Up @@ -470,7 +470,7 @@ def convolve(self, u, v, uv=None, fast=True):
uv = np.zeros(N+1, dtype=u.dtype)
Np = N if not N % 2 == 0 else N+1
k1 = np.fft.fftfreq(Np, 1./Np).astype(int)
convolve.convolve_real_1D(u, v, uv, k1)
cython.convolve.convolve_real_1D(u, v, uv, k1)

return uv

Expand Down Expand Up @@ -636,6 +636,6 @@ def convolve(self, u, v, uv=None, fast=True):

Np = N if not N % 2 == 0 else N+1
k = np.fft.fftfreq(Np, 1./Np).astype(int)
convolve.convolve_1D(u, v, uv, k)
cython.convolve.convolve_1D(u, v, uv, k)

return uv
8 changes: 4 additions & 4 deletions shenfun/hermite/matrices.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from shenfun.matrixbase import SpectralMatrix, SpectralMatDict
from shenfun.optimization.cython import Matvec
from shenfun.optimization import cython
from shenfun.la import TDMA
from . import bases

Expand Down Expand Up @@ -74,15 +74,15 @@ def matvec(self, v, c, format='cython', axis=0):
c.fill(0)
if format == 'cython' and v.ndim == 3:
ld = self[-2]*np.ones(M-2)
Matvec.Tridiagonal_matvec3D_ptr(v, c, ld, self[0], ld, axis)
cython.Matvec.Tridiagonal_matvec3D_ptr(v, c, ld, self[0], ld, axis)
self.scale_array(c, self.scale)
elif format == 'cython' and v.ndim == 2:
ld = self[-2]*np.ones(M-2)
Matvec.Tridiagonal_matvec2D_ptr(v, c, ld, self[0], ld, axis)
cython.Matvec.Tridiagonal_matvec2D_ptr(v, c, ld, self[0], ld, axis)
self.scale_array(c, self.scale)
elif format == 'cython' and v.ndim == 1:
ld = self[-2]*np.ones(M-2)
Matvec.Tridiagonal_matvec(v, c, ld, self[0], ld)
cython.Matvec.Tridiagonal_matvec(v, c, ld, self[0], ld)
self.scale_array(c, self.scale)
elif format == 'self':
if axis > 0:
Expand Down
5 changes: 4 additions & 1 deletion shenfun/legendre/dlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
from mpi4py_fft import fftw
from mpi4py_fft.fftw.utilities import FFTW_MEASURE, FFTW_PRESERVE_INPUT
from shenfun.optimization import runtimeoptimizer
from shenfun.optimization.cython import Leg2Cheb, Cheb2Leg, Lambda
from shenfun.optimization import cython
from shenfun.spectralbase import islicedict, slicedict
from shenfun.forms.arguments import FunctionSpace
from . import fastgl

__all__ = ['DLT', 'leg2cheb', 'cheb2leg', 'Leg2chebHaleTownsend',
'Leg2Cheb', 'Cheb2Leg', 'FMMLeg2Cheb', 'FMMCheb2Leg']

Leg2Cheb = getattr(cython, 'Leg2Cheb', None)
Cheb2Leg = getattr(cython, 'Cheb2Leg', None)
Lambda = getattr(cython, 'Lambda', None)

class DLT:
r"""Discrete Legendre Transform
Expand Down
5 changes: 4 additions & 1 deletion shenfun/optimization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
"""
import importlib
from functools import wraps
from . import cython
from shenfun.config import config

try:
from . import cython
except ModuleNotFoundError:
cython = None
try:
from . import numba
except ModuleNotFoundError:
Expand Down
6 changes: 3 additions & 3 deletions shenfun/tensorproductspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from shenfun.fourier.bases import R2C, C2C
from shenfun.utilities import apply_mask
from shenfun.forms.arguments import Function, Array
from shenfun.optimization.cython import evaluate
from shenfun.optimization import cython
from shenfun.spectralbase import slicedict, islicedict, SpectralBase
from shenfun.coordinates import Coordinates

Expand Down Expand Up @@ -837,10 +837,10 @@ def _eval_cython(self, points, coefficients, output_array):
last_conj_index = M
sl = self.local_slice()[axis].start
if len(self) == 2:
output_array = evaluate.evaluate_2D(output_array, coefficients, P, r2c, last_conj_index, sl)
output_array = cython.evaluate.evaluate_2D(output_array, coefficients, P, r2c, last_conj_index, sl)

elif len(self) == 3:
output_array = evaluate.evaluate_3D(output_array, coefficients, P, r2c, last_conj_index, sl)
output_array = cython.evaluate.evaluate_3D(output_array, coefficients, P, r2c, last_conj_index, sl)

output_array = np.atleast_1d(output_array)
output_array = comm.allreduce(output_array)
Expand Down
5 changes: 3 additions & 2 deletions shenfun/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
import sympy as sp
from scipy.fftpack import dct
from scipy.integrate import quad
from shenfun.optimization import runtimeoptimizer
from shenfun.optimization.cython import Lambda
from shenfun.optimization import runtimeoptimizer, cython
from shenfun.config import config
from .findbasis import get_bc_basis, get_stencil_matrix, n

Expand All @@ -21,6 +20,8 @@
'mayavi_show', 'quiver3D', 'get_bc_basis', 'get_stencil_matrix',
'scalar_product', 'n', 'cross', 'reset_profile', 'Lambda']

Lambda = getattr(cython, 'Lambda', None)

def dx(u, weighted=False):
r"""Compute integral of u over domain
Expand Down

0 comments on commit 7244f6f

Please sign in to comment.