Skip to content

Commit

Permalink
Update references to JAX's GitHub repo
Browse files Browse the repository at this point in the history
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 702886981
  • Loading branch information
jakeharmon8 authored and copybara-github committed Dec 5, 2024
1 parent 8ca8408 commit fef0981
Show file tree
Hide file tree
Showing 11 changed files with 24 additions and 24 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ You can learn here how Trax works, how to create new models and how to train the

The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.

In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org/guide/tf_numpy).
In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org/guide/tf_numpy).


```python
Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks/tf_numpy_and_keras.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"\n",
"In Trax, all computations rely on accelerated math operations happening in the `fastmath` module. This module can use different backends for acceleration. One of them is [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which uses [TensorFlow 2](https://www.tensorflow.org/) to accelerate the computations.\n",
"\n",
"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/google/jax) which lowers to TensorFlow XLA.\n",
"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/jax-ml/jax) which lowers to TensorFlow XLA.\n",
"\n",
"You may see that `tensorflow-numpy` and `jax` backends show different speed and memory characteristics. You may also see different error messages when debugging since it might expose you to the internals of the backends. However for the most part, users can choose a backend and not worry about the internal details of these backends.\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks/trax_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
"\n",
"The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n",
"\n",
"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)."
"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org)."
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion trax/intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
"\n",
"The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n",
"\n",
"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)."
"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org)."
]
},
{
Expand Down
8 changes: 4 additions & 4 deletions trax/supervised/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_loop_with_initialized_model(self):

def test_train_save_restore_dense(self):
"""Saves and restores a checkpoint to check for equivalence."""
self.skipTest('Broken by https://github.com/google/jax/pull/11234')
self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
train_data = data.Serial(lambda _: _very_simple_data(),
data.CountAndSkip('simple_data'))
task = training.TrainTask(
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_restores_step(self):

def test_restores_memory_efficient_from_standard(self):
"""Training restores step from directory where it saved it."""
self.skipTest('Broken by https://github.com/google/jax/pull/11234')
self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
model = tl.Serial(tl.Dense(4), tl.Dense(1))
task_std = training.TrainTask(
_very_simple_data(), tl.L2Loss(), optimizers.Adam(.0001))
Expand All @@ -345,7 +345,7 @@ def test_restores_memory_efficient_from_standard(self):

def test_restores_from_smaller_model(self):
"""Training restores from a checkpoint created with smaller model."""
self.skipTest('Broken by https://github.com/google/jax/pull/11234')
self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
model1 = tl.Serial(tl.Dense(1))
task = training.TrainTask(
_very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))
Expand Down Expand Up @@ -374,7 +374,7 @@ def test_restore_fails_different_model(self):

def test_restores_step_bfloat16(self):
"""Training restores step from directory where it saved it, w/ bfloat16."""
self.skipTest('Broken by https://github.com/google/jax/pull/11234')
self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
model = tl.Serial(tl.Dense(1, use_bfloat16=True))
# We'll also use Adafactor with bfloat16 to check restoring bfloat slots.
opt = optimizers.Adafactor(.01, do_momentum=True, momentum_in_bfloat16=True)
Expand Down
2 changes: 1 addition & 1 deletion trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_two_operands_5(self):
self._check(s, x, y)

def test_two_operands_6(self):
# based on https://github.com/google/jax/issues/37#issuecomment-448572187
# based on https://github.com/jax-ml/jax/issues/37#issuecomment-448572187
r = self.rng()
x = r.randn(2, 1)
y = r.randn(2, 3, 4)
Expand Down
2 changes: 1 addition & 1 deletion trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def testFloatIndexingError(self):
with self.assertRaisesRegex(IndexError, error_regex):
npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))

def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245
array = jnp.ones(5)
self.assertAllClose(array, array[:10], check_dtypes=True)

Expand Down
16 changes: 8 additions & 8 deletions trax/tf_numpy/jax_tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,14 +2051,14 @@ def testAstype(self):
# TODO(mattjj): test other ndarray-like method overrides

def testOnpMean(self):
# from https://github.com/google/jax/issues/125
# from https://github.com/jax-ml/jax/issues/125
x = lnp.add(lnp.eye(3, dtype=lnp.float_), 0.)
ans = onp.mean(x)
self.assertAllClose(ans, onp.array(1./3), check_dtypes=False)

@jtu.disable
def testArangeOnFloats(self):
# from https://github.com/google/jax/issues/145
# from https://github.com/jax-ml/jax/issues/145
expected = onp.arange(0.0, 1.0, 0.1, dtype=lnp.float_)
ans = lnp.arange(0.0, 1.0, 0.1)
self.assertAllClose(expected, ans, check_dtypes=True)
Expand Down Expand Up @@ -2407,7 +2407,7 @@ def testSymmetrizeDtypePromotion(self):

@jtu.disable
def testIssue347(self):
# https://github.com/google/jax/issues/347
# https://github.com/jax-ml/jax/issues/347
def test_fail(x):
x = lnp.sqrt(lnp.sum(x ** 2, axis=1))
ones = lnp.ones_like(x)
Expand All @@ -2419,7 +2419,7 @@ def test_fail(x):
assert not onp.any(onp.isnan(result))

def testIssue453(self):
# https://github.com/google/jax/issues/453
# https://github.com/jax-ml/jax/issues/453
a = onp.arange(6) + 1
ans = lnp.reshape(a, (3, 2), order='F')
expected = onp.reshape(a, (3, 2), order='F')
Expand All @@ -2432,7 +2432,7 @@ def testIssue453(self):
(bool, lnp.bool_), (complex, lnp.complex_)]
for op in ["atleast_1d", "atleast_2d", "atleast_3d"]))
def testAtLeastNdLiterals(self, pytype, dtype, op):
# Fixes: https://github.com/google/jax/issues/634
# Fixes: https://github.com/jax-ml/jax/issues/634
onp_fun = lambda arg: getattr(onp, op)(arg).astype(dtype)
lnp_fun = lambda arg: getattr(lnp, op)(arg)
args_maker = lambda: [pytype(2)]
Expand Down Expand Up @@ -2550,7 +2550,7 @@ def testMathSpecialFloatValues(self, op, dtype):
rtol=tol)

def testIssue883(self):
# from https://github.com/google/jax/issues/883
# from https://github.com/jax-ml/jax/issues/883

@partial(npe.jit, static_argnums=(1,))
def f(x, v):
Expand Down Expand Up @@ -2907,7 +2907,7 @@ def testDisableNumpyRankPromotionBroadcasting(self):
FLAGS.jax_numpy_rank_promotion = prev_flag

def testStackArrayArgument(self):
# tests https://github.com/google/jax/issues/1271
# tests https://github.com/jax-ml/jax/issues/1271
@npe.jit
def foo(x):
return lnp.stack(x)
Expand Down Expand Up @@ -3120,7 +3120,7 @@ def testOpGradSpecialValue(self, op, special_value, order):

@jtu.disable
def testTakeAlongAxisIssue1521(self):
# https://github.com/google/jax/issues/1521
# https://github.com/jax-ml/jax/issues/1521
idx = lnp.repeat(lnp.arange(3), 10).reshape((30, 1))

def f(x):
Expand Down
2 changes: 1 addition & 1 deletion trax/tf_numpy/jax_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_method_wrapper(self, *args, **kwargs):
return test_method_wrapper
return skip

# TODO(phawkins): workaround for bug https://github.com/google/jax/issues/432
# TODO(phawkins): workaround for bug https://github.com/jax-ml/jax/issues/432
# Delete this code after the minimum jaxlib version is 0.1.46 or greater.
skip_on_mac_linalg_bug = partial(
unittest.skipIf,
Expand Down
8 changes: 4 additions & 4 deletions trax/tf_numpy/jax_tests/vmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
class VmapTest(tf.test.TestCase, parameterized.TestCase):

def test_vmap_in_axes_list(self):
# https://github.com/google/jax/issues/2367
# https://github.com/jax-ml/jax/issues/2367
dictionary = {'a': 5., 'b': tf_np.ones(2)}
x = tf_np.zeros(3)
y = tf_np.arange(3.)
Expand All @@ -41,7 +41,7 @@ def f(dct, x, y):
self.assertAllClose(out1, out2)

def test_vmap_in_axes_tree_prefix_error(self):
# https://github.com/google/jax/issues/795
# https://github.com/jax-ml/jax/issues/795
self.assertRaisesRegex(
ValueError,
'vmap in_axes specification must be a tree prefix of the corresponding '
Expand All @@ -63,14 +63,14 @@ def test_vmap_out_axes_leaf_types(self):
tf_np.array([1., 2.]))

def test_vmap_unbatched_object_passthrough_issue_183(self):
# https://github.com/google/jax/issues/183
# https://github.com/jax-ml/jax/issues/183
fun = lambda f, x: f(x)
vfun = extensions.vmap(fun, (None, 0))
ans = vfun(lambda x: x + 1, tf_np.arange(3))
self.assertAllClose(ans, np.arange(1, 4))

def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
# https://github.com/google/jax/issues/705
# https://github.com/jax-ml/jax/issues/705
with self.assertRaisesRegex(
ValueError, 'vmap must have at least one non-None value in in_axes'):
# If the output is mapped, there must be a non-None in_axes
Expand Down
2 changes: 1 addition & 1 deletion trax/tf_numpy_and_keras.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"\n",
"In Trax, all computations rely on accelerated math operations happening in the `fastmath` module. This module can use different backends for acceleration. One of them is [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which uses [TensorFlow 2](https://www.tensorflow.org/) to accelerate the computations.\n",
"\n",
"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/google/jax) which lowers to TensorFlow XLA.\n",
"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/jax-ml/jax) which lowers to TensorFlow XLA.\n",
"\n",
"You may see that `tensorflow-numpy` and `jax` backends show different speed and memory characteristics. You may also see different error messages when debugging since it might expose you to the internals of the backends. However for the most part, users can choose a backend and not worry about the internal details of these backends.\n",
"\n",
Expand Down

0 comments on commit fef0981

Please sign in to comment.