From ac6e04773b39cd2f2dd5b1a148fbea8a43702b64 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 14 Jan 2025 04:15:14 -0800 Subject: [PATCH] jax.numpy.clip: update use of deprecated arguments. - a is now positional-only - a_min is now min - a_max is now max The old argument names have been deprecated since JAX v0.4.27. PiperOrigin-RevId: 715321661 --- chirp/models/hubert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chirp/models/hubert.py b/chirp/models/hubert.py index 1e1dcfbf..6fd8b1ca 100644 --- a/chirp/models/hubert.py +++ b/chirp/models/hubert.py @@ -77,7 +77,7 @@ def compute_mask_indices( num_mask = mask_prob * sz / jnp.array(mask_length, float) + rounding_offset num_mask = jnp.full(bsz, num_mask).astype(int) max_masks = sz - mask_length + 1 - num_mask = jnp.clip(num_mask, a_min=min_masks, a_max=max_masks) + num_mask = jnp.clip(num_mask, min=min_masks, max=max_masks) # First, sample a set of start indices for the max possible number of masks. # Do this sampling separately for each batch sample, to allow `replace`=False.