diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 74190cb2..3f3a8446 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -67,10 +67,9 @@ def _size2shape(size): def _check_shape(name, shape, *param_shapes): - shape = core.as_named_shape(shape) if param_shapes: - shape_ = lax.broadcast_shapes(shape.positional, *param_shapes) - if shape.positional != shape_: + shape_ = lax.broadcast_shapes(shape, *param_shapes) + if shape != shape_: msg = ("{} parameter shapes must be broadcast-compatible with shape " "argument, and the result of broadcasting the shapes must equal " "the shape argument, but got result {} for shape argument {}.")