diff --git a/imops/_configs.py b/imops/_configs.py index 61a1aaf7..4b81b2f6 100644 --- a/imops/_configs.py +++ b/imops/_configs.py @@ -1,30 +1,63 @@ +from functools import partial from itertools import product from .backend import Cupy, Cython, Numba, Scipy -scipy_configs = [Scipy()] -radon_configs = [Cython(fast) for fast in [False, True]] -numeric_configs = [ - Scipy(), - *[Cython(fast) for fast in [False, True]], +scipy_backends = [Scipy] +radon_backends = [partial(Cython, fast=fast) for fast in [False, True]] +numeric_backends = [ + Scipy, + *[partial(Cython, fast=fast) for fast in [False, True]], ] -measure_configs = [ - Scipy(), - *[Cython(fast) for fast in [False, True]], +measure_backends = [ + Scipy, + *[partial(Cython, fast=fast) for fast in [False, True]], ] -morphology_configs = [ - Scipy(), - *[Cython(fast) for fast in [False, True]], +morphology_backends = [ + Scipy, + *[partial(Cython, fast=fast) for fast in [False, True]], ] -zoom_configs = [ - Scipy(), - *[Cython(fast) for fast in [False, True]], - *[Numba(*flags) for flags in product([False, True], repeat=3)], - Cupy(), +zoom_backends = [ + Scipy, + *[partial(Cython, fast=fast) for fast in [False, True]], + *[partial(Numba, *flags) for flags in product([False, True], repeat=3)], + Cupy, ] -interp1d_configs = [ - Scipy(), - *[Cython(fast) for fast in [False, True]], - *[Numba(*flags) for flags in product([False, True], repeat=3)], +interp1d_backends = [ + Scipy, + *[partial(Cython, fast=fast) for fast in [False, True]], + *[partial(Numba, *flags) for flags in product([False, True], repeat=3)], ] + + +def is_available(backend): + try: + backend() + return True + except ModuleNotFoundError: + return False + + +def available_backends(backends): + return [backend for backend in backends if is_available(backend)] + + +def to_repr(backend): + if not isinstance(backend, partial): + return backend + + backend_name = backend.func.__name__ + backend_args = ', '.join(map(str, backend.args)) + backend_kwargs = ', '.join(f'{key}={value}' for key, value in backend.keywords.items()) + + if not backend_args: + if not backend_kwargs: + return backend_name + else: + return f'{backend_name}({backend_kwargs})' + else: + if not backend_kwargs: + return f'{backend_name}({backend_args})' + + return f'{backend_name}({backend_args}, {backend_kwargs})' diff --git a/tests/test_zoom.py b/tests/test_zoom.py index 9cfc3d93..32a67304 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -6,7 +6,7 @@ from numpy.testing import assert_allclose from scipy.ndimage import zoom as scipy_zoom -from imops._configs import zoom_configs +from imops._configs import available_backends, to_repr, zoom_backends from imops.backend import SINGLE_THREADED_BACKENDS, Backend from imops.utils import ZOOM_SRC_DIM, get_c_contiguous_permutaion, inverse_permutation, make_immutable from imops.zoom import _zoom, zoom, zoom_to_shape @@ -32,9 +32,12 @@ class Alien1(Backend): pass -@pytest.fixture(params=zoom_configs, ids=map(str, zoom_configs)) +tested_backends = available_backends(zoom_backends) + + +@pytest.fixture(params=tested_backends, ids=map(to_repr, tested_backends)) def backend(request): - return request.param + return request.param() @pytest.fixture(params=[0, 1])