diff --git a/README.md b/README.md
index 33884979a..15ae9d46c 100644
--- a/README.md
+++ b/README.md
@@ -116,7 +116,7 @@ You can learn here how Trax works, how to create new models and how to train the
 
 The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.
 
-In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org/guide/tf_numpy).
+In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org/guide/tf_numpy).
 
 ```python
diff --git a/docs/source/notebooks/tf_numpy_and_keras.ipynb b/docs/source/notebooks/tf_numpy_and_keras.ipynb
index 70c9e38fa..96bcc0b1f 100644
--- a/docs/source/notebooks/tf_numpy_and_keras.ipynb
+++ b/docs/source/notebooks/tf_numpy_and_keras.ipynb
@@ -26,7 +26,7 @@
 "\n",
 "In Trax, all computations rely on accelerated math operations happening in the `fastmath` module. This module can use different backends for acceleration. One of them is [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which uses [TensorFlow 2](https://www.tensorflow.org/) to accelerate the computations.\n",
 "\n",
- "The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/google/jax) which lowers to TensorFlow XLA.\n",
+ "The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/jax-ml/jax) which lowers to TensorFlow XLA.\n",
 "\n",
 "You may see that `tensorflow-numpy` and `jax` backends show different speed and memory characteristics. You may also see different error messages when debugging since it might expose you to the internals of the backends. However for the most part, users can choose a backend and not worry about the internal details of these backends.\n",
 "\n",
diff --git a/docs/source/notebooks/trax_intro.ipynb b/docs/source/notebooks/trax_intro.ipynb
index 6641e295c..7b85099cc 100644
--- a/docs/source/notebooks/trax_intro.ipynb
+++ b/docs/source/notebooks/trax_intro.ipynb
@@ -228,7 +228,7 @@
 "\n",
 "The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n",
 "\n",
- "In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)."
+ "In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org)."
 ]
 },
 {
diff --git a/trax/intro.ipynb b/trax/intro.ipynb
index 6641e295c..7b85099cc 100644
--- a/trax/intro.ipynb
+++ b/trax/intro.ipynb
@@ -228,7 +228,7 @@
 "\n",
 "The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n",
 "\n",
- "In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)."
+ "In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org)."
 ]
 },
 {
diff --git a/trax/supervised/training_test.py b/trax/supervised/training_test.py
index 25ae75f73..f0adc73fd 100644
--- a/trax/supervised/training_test.py
+++ b/trax/supervised/training_test.py
@@ -143,7 +143,7 @@ def test_loop_with_initialized_model(self):
 
   def test_train_save_restore_dense(self):
     """Saves and restores a checkpoint to check for equivalence."""
-    self.skipTest('Broken by https://github.com/google/jax/pull/11234')
+    self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
     train_data = data.Serial(lambda _: _very_simple_data(),
                              data.CountAndSkip('simple_data'))
     task = training.TrainTask(
@@ -327,7 +327,7 @@ def test_restores_step(self):
 
   def test_restores_memory_efficient_from_standard(self):
     """Training restores step from directory where it saved it."""
-    self.skipTest('Broken by https://github.com/google/jax/pull/11234')
+    self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
     model = tl.Serial(tl.Dense(4), tl.Dense(1))
     task_std = training.TrainTask(
         _very_simple_data(), tl.L2Loss(), optimizers.Adam(.0001))
@@ -345,7 +345,7 @@ def test_restores_memory_efficient_from_standard(self):
 
   def test_restores_from_smaller_model(self):
     """Training restores from a checkpoint created with smaller model."""
-    self.skipTest('Broken by https://github.com/google/jax/pull/11234')
+    self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
     model1 = tl.Serial(tl.Dense(1))
     task = training.TrainTask(
         _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))
@@ -374,7 +374,7 @@ def test_restore_fails_different_model(self):
 
   def test_restores_step_bfloat16(self):
     """Training restores step from directory where it saved it, w/ bfloat16."""
-    self.skipTest('Broken by https://github.com/google/jax/pull/11234')
+    self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
     model = tl.Serial(tl.Dense(1, use_bfloat16=True))
     # We'll also use Adafactor with bfloat16 to check restoring bfloat slots.
     opt = optimizers.Adafactor(.01, do_momentum=True, momentum_in_bfloat16=True)
diff --git a/trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py b/trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py
index dcb2189cf..0b31e6497 100644
--- a/trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py
+++ b/trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py
@@ -89,7 +89,7 @@ def test_two_operands_5(self):
     self._check(s, x, y)
 
   def test_two_operands_6(self):
-    # based on https://github.com/google/jax/issues/37#issuecomment-448572187
+    # based on https://github.com/jax-ml/jax/issues/37#issuecomment-448572187
     r = self.rng()
     x = r.randn(2, 1)
     y = r.randn(2, 3, 4)
diff --git a/trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py b/trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py
index a71783c43..60e76c942 100644
--- a/trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py
+++ b/trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py
@@ -720,7 +720,7 @@ def testFloatIndexingError(self):
     with self.assertRaisesRegex(IndexError, error_regex):
       npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))
 
-  def testIndexOutOfBounds(self):  # https://github.com/google/jax/issues/2245
+  def testIndexOutOfBounds(self):  # https://github.com/jax-ml/jax/issues/2245
     array = jnp.ones(5)
     self.assertAllClose(array, array[:10], check_dtypes=True)
diff --git a/trax/tf_numpy/jax_tests/lax_numpy_test.py b/trax/tf_numpy/jax_tests/lax_numpy_test.py
index b3d1662a9..dd98677c5 100644
--- a/trax/tf_numpy/jax_tests/lax_numpy_test.py
+++ b/trax/tf_numpy/jax_tests/lax_numpy_test.py
@@ -2051,14 +2051,14 @@ def testAstype(self):
   # TODO(mattjj): test other ndarray-like method overrides
 
   def testOnpMean(self):
-    # from https://github.com/google/jax/issues/125
+    # from https://github.com/jax-ml/jax/issues/125
     x = lnp.add(lnp.eye(3, dtype=lnp.float_), 0.)
     ans = onp.mean(x)
     self.assertAllClose(ans, onp.array(1./3), check_dtypes=False)
 
   @jtu.disable
   def testArangeOnFloats(self):
-    # from https://github.com/google/jax/issues/145
+    # from https://github.com/jax-ml/jax/issues/145
     expected = onp.arange(0.0, 1.0, 0.1, dtype=lnp.float_)
     ans = lnp.arange(0.0, 1.0, 0.1)
     self.assertAllClose(expected, ans, check_dtypes=True)
@@ -2407,7 +2407,7 @@ def testSymmetrizeDtypePromotion(self):
 
   @jtu.disable
   def testIssue347(self):
-    # https://github.com/google/jax/issues/347
+    # https://github.com/jax-ml/jax/issues/347
     def test_fail(x):
       x = lnp.sqrt(lnp.sum(x ** 2, axis=1))
       ones = lnp.ones_like(x)
@@ -2419,7 +2419,7 @@ def test_fail(x):
     assert not onp.any(onp.isnan(result))
 
   def testIssue453(self):
-    # https://github.com/google/jax/issues/453
+    # https://github.com/jax-ml/jax/issues/453
     a = onp.arange(6) + 1
     ans = lnp.reshape(a, (3, 2), order='F')
     expected = onp.reshape(a, (3, 2), order='F')
@@ -2432,7 +2432,7 @@ def testIssue453(self):
                             (bool, lnp.bool_), (complex, lnp.complex_)]
       for op in ["atleast_1d", "atleast_2d", "atleast_3d"]))
   def testAtLeastNdLiterals(self, pytype, dtype, op):
-    # Fixes: https://github.com/google/jax/issues/634
+    # Fixes: https://github.com/jax-ml/jax/issues/634
     onp_fun = lambda arg: getattr(onp, op)(arg).astype(dtype)
     lnp_fun = lambda arg: getattr(lnp, op)(arg)
     args_maker = lambda: [pytype(2)]
@@ -2550,7 +2550,7 @@ def testMathSpecialFloatValues(self, op, dtype):
         rtol=tol)
 
   def testIssue883(self):
-    # from https://github.com/google/jax/issues/883
+    # from https://github.com/jax-ml/jax/issues/883
 
     @partial(npe.jit, static_argnums=(1,))
     def f(x, v):
@@ -2907,7 +2907,7 @@ def testDisableNumpyRankPromotionBroadcasting(self):
       FLAGS.jax_numpy_rank_promotion = prev_flag
 
   def testStackArrayArgument(self):
-    # tests https://github.com/google/jax/issues/1271
+    # tests https://github.com/jax-ml/jax/issues/1271
     @npe.jit
     def foo(x):
       return lnp.stack(x)
@@ -3120,7 +3120,7 @@ def testOpGradSpecialValue(self, op, special_value, order):
 
   @jtu.disable
   def testTakeAlongAxisIssue1521(self):
-    # https://github.com/google/jax/issues/1521
+    # https://github.com/jax-ml/jax/issues/1521
     idx = lnp.repeat(lnp.arange(3), 10).reshape((30, 1))
 
     def f(x):
diff --git a/trax/tf_numpy/jax_tests/test_util.py b/trax/tf_numpy/jax_tests/test_util.py
index 0110f0db6..9734bc1e3 100644
--- a/trax/tf_numpy/jax_tests/test_util.py
+++ b/trax/tf_numpy/jax_tests/test_util.py
@@ -316,7 +316,7 @@ def test_method_wrapper(self, *args, **kwargs):
     return test_method_wrapper
   return skip
 
-# TODO(phawkins): workaround for bug https://github.com/google/jax/issues/432
+# TODO(phawkins): workaround for bug https://github.com/jax-ml/jax/issues/432
 # Delete this code after the minimum jaxlib version is 0.1.46 or greater.
 skip_on_mac_linalg_bug = partial(
     unittest.skipIf,
diff --git a/trax/tf_numpy/jax_tests/vmap_test.py b/trax/tf_numpy/jax_tests/vmap_test.py
index bb9138d18..9f6f24c8c 100644
--- a/trax/tf_numpy/jax_tests/vmap_test.py
+++ b/trax/tf_numpy/jax_tests/vmap_test.py
@@ -28,7 +28,7 @@ class VmapTest(tf.test.TestCase, parameterized.TestCase):
 
   def test_vmap_in_axes_list(self):
-    # https://github.com/google/jax/issues/2367
+    # https://github.com/jax-ml/jax/issues/2367
     dictionary = {'a': 5., 'b': tf_np.ones(2)}
     x = tf_np.zeros(3)
     y = tf_np.arange(3.)
@@ -41,7 +41,7 @@ def f(dct, x, y):
     self.assertAllClose(out1, out2)
 
   def test_vmap_in_axes_tree_prefix_error(self):
-    # https://github.com/google/jax/issues/795
+    # https://github.com/jax-ml/jax/issues/795
     self.assertRaisesRegex(
         ValueError,
         'vmap in_axes specification must be a tree prefix of the corresponding '
@@ -63,14 +63,14 @@ def test_vmap_out_axes_leaf_types(self):
         tf_np.array([1., 2.]))
 
   def test_vmap_unbatched_object_passthrough_issue_183(self):
-    # https://github.com/google/jax/issues/183
+    # https://github.com/jax-ml/jax/issues/183
     fun = lambda f, x: f(x)
     vfun = extensions.vmap(fun, (None, 0))
     ans = vfun(lambda x: x + 1, tf_np.arange(3))
     self.assertAllClose(ans, np.arange(1, 4))
 
   def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
-    # https://github.com/google/jax/issues/705
+    # https://github.com/jax-ml/jax/issues/705
     with self.assertRaisesRegex(
         ValueError, 'vmap must have at least one non-None value in in_axes'):
       # If the output is mapped, there must be a non-None in_axes
diff --git a/trax/tf_numpy_and_keras.ipynb b/trax/tf_numpy_and_keras.ipynb
index 70c9e38fa..96bcc0b1f 100644
--- a/trax/tf_numpy_and_keras.ipynb
+++ b/trax/tf_numpy_and_keras.ipynb
@@ -26,7 +26,7 @@
 "\n",
 "In Trax, all computations rely on accelerated math operations happening in the `fastmath` module. This module can use different backends for acceleration. One of them is [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which uses [TensorFlow 2](https://www.tensorflow.org/) to accelerate the computations.\n",
 "\n",
- "The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/google/jax) which lowers to TensorFlow XLA.\n",
+ "The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/jax-ml/jax) which lowers to TensorFlow XLA.\n",
 "\n",
 "You may see that `tensorflow-numpy` and `jax` backends show different speed and memory characteristics. You may also see different error messages when debugging since it might expose you to the internals of the backends. However for the most part, users can choose a backend and not worry about the internal details of these backends.\n",
 "\n",
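For readers skimming the notebook and README text quoted in the context lines above, backend selection and gradient computation through `trax.fastmath` look roughly like this. This is a minimal illustrative sketch, not part of the change itself; it assumes a standard `trax` installation and uses only the names mentioned in that prose (`trax.fastmath.set_backend`, the `trax.fastmath` numpy API, and gradient computation via `fastmath`).

```python
# Illustrative sketch based on the prose above; assumes `trax` is installed.
from trax import fastmath
from trax.fastmath import numpy as fastnp

# Switch from the default `jax` backend to the TensorFlow NumPy backend
# (available backend names per the notebook text: 'jax', 'tensorflow-numpy', 'numpy').
fastmath.set_backend('tensorflow-numpy')

def loss(v):
  # Any numpy-style function of tensors can be differentiated by the active backend.
  return fastnp.sum(v * v)

grad_loss = fastmath.grad(loss)
print(grad_loss(fastnp.arange(4.0)))  # gradient of sum(v^2) is 2*v
```

Switching back is a matter of calling `set_backend('jax')` again; the numpy-style code itself stays unchanged, which is the point the quoted paragraphs make about not worrying over backend internals.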