diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 791c8d9f..b435415d 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -660,7 +660,7 @@ def searchsorted(self, v, side='left', sorter=None): """ return _return(self.value.searchsorted(v=_as_jax_array_(v), side=side, sorter=sorter)) - def sort(self, axis=-1, kind='quicksort', order=None): + def sort(self, axis=-1, stable=True, order=None): """Sort an array in-place. Parameters @@ -668,11 +668,8 @@ def sort(self, axis=-1, kind='quicksort', order=None): axis : int, optional Axis along which to sort. Default is -1, which means sort along the last axis. - kind : {'quicksort', 'mergesort', 'heapsort', 'stable'} - Sorting algorithm. The default is 'quicksort'. Note that both 'stable' - and 'mergesort' use timsort under the covers and, in general, the - actual implementation will vary with datatype. The 'mergesort' option - is retained for backwards compatibility. + stable : bool, optional + Whether to use a stable sorting algorithm. The default is True. order : str or list of str, optional When `a` is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can @@ -680,7 +677,8 @@ def sort(self, axis=-1, kind='quicksort', order=None): but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties. """ - self.value = self.value.sort(axis=axis, kind=kind, order=order) + self.value = self.value.sort(axis=axis, stable=stable, order=order) + def squeeze(self, axis=None): """Remove axes of length one from ``a``.""" diff --git a/brainpy/_src/math/object_transform/tests/test_jit.py b/brainpy/_src/math/object_transform/tests/test_jit.py index d52903d4..b614d081 100644 --- a/brainpy/_src/math/object_transform/tests/test_jit.py +++ b/brainpy/_src/math/object_transform/tests/test_jit.py @@ -52,7 +52,7 @@ def __call__(self, *args, **kwargs): def test_jit_with_static(self): a = bm.Variable(bm.ones(2)) - @bm.jit(static_argnums=1) + @bm.jit(static_argnums=0) def f(b, c): a.value *= b a.value /= c @@ -104,7 +104,7 @@ def __init__(self): self.a = bm.zeros(2) self.b = bm.Variable(bm.ones(2)) - self.call1 = bm.jit(self.call, static_argnums=0) + self.call1 = bm.jit(self.call, static_argnums=1) self.call2 = bm.jit(self.call, static_argnames=['fit']) def call(self, fit=True): @@ -157,7 +157,7 @@ class MyObj: def __init__(self): self.a = bm.Variable(bm.ones(2)) - @bm.cls_jit(static_argnums=1) + @bm.cls_jit(static_argnums=0) def f(self, b, c): self.a.value *= b self.a.value /= c