Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update references to JAX's GitHub repo #1805

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ You can learn here how Trax works, how to create new models and how to train the

The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.

In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org/guide/tf_numpy).
In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org/guide/tf_numpy).


```python
Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks/tf_numpy_and_keras.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"\n",
"In Trax, all computations rely on accelerated math operations happening in the `fastmath` module. This module can use different backends for acceleration. One of them is [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which uses [TensorFlow 2](https://www.tensorflow.org/) to accelerate the computations.\n",
"\n",
"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/google/jax) which lowers to TensorFlow XLA.\n",
"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/jax-ml/jax) which lowers to TensorFlow XLA.\n",
"\n",
"You may see that `tensorflow-numpy` and `jax` backends show different speed and memory characteristics. You may also see different error messages when debugging since it might expose you to the internals of the backends. However for the most part, users can choose a backend and not worry about the internal details of these backends.\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks/trax_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
"\n",
"The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n",
"\n",
"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)."
"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org)."
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion trax/intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
"\n",
"The basic units flowing through Trax models are *tensors* - multi-dimensional arrays, sometimes also known as numpy arrays, due to the most widely used package for tensor operations -- `numpy`. You should take a look at the [numpy guide](https://numpy.org/doc/stable/user/quickstart.html) if you don't know how to operate on tensors: Trax also uses the numpy API for that.\n",
"\n",
"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/google/jax) and [TensorFlow numpy](https://tensorflow.org)."
"In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the `trax.fastmath` package thanks to its backends -- [JAX](https://github.com/jax-ml/jax) and [TensorFlow numpy](https://tensorflow.org)."
]
},
{
Expand Down
8 changes: 4 additions & 4 deletions trax/supervised/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_loop_with_initialized_model(self):

def test_train_save_restore_dense(self):
"""Saves and restores a checkpoint to check for equivalence."""
self.skipTest('Broken by https://github.com/google/jax/pull/11234')
self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
train_data = data.Serial(lambda _: _very_simple_data(),
data.CountAndSkip('simple_data'))
task = training.TrainTask(
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_restores_step(self):

def test_restores_memory_efficient_from_standard(self):
"""Training restores step from directory where it saved it."""
self.skipTest('Broken by https://github.com/google/jax/pull/11234')
self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
model = tl.Serial(tl.Dense(4), tl.Dense(1))
task_std = training.TrainTask(
_very_simple_data(), tl.L2Loss(), optimizers.Adam(.0001))
Expand All @@ -345,7 +345,7 @@ def test_restores_memory_efficient_from_standard(self):

def test_restores_from_smaller_model(self):
"""Training restores from a checkpoint created with smaller model."""
self.skipTest('Broken by https://github.com/google/jax/pull/11234')
self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
model1 = tl.Serial(tl.Dense(1))
task = training.TrainTask(
_very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))
Expand Down Expand Up @@ -374,7 +374,7 @@ def test_restore_fails_different_model(self):

def test_restores_step_bfloat16(self):
"""Training restores step from directory where it saved it, w/ bfloat16."""
self.skipTest('Broken by https://github.com/google/jax/pull/11234')
self.skipTest('Broken by https://github.com/jax-ml/jax/pull/11234')
model = tl.Serial(tl.Dense(1, use_bfloat16=True))
# We'll also use Adafactor with bfloat16 to check restoring bfloat slots.
opt = optimizers.Adafactor(.01, do_momentum=True, momentum_in_bfloat16=True)
Expand Down
2 changes: 1 addition & 1 deletion trax/tf_numpy/jax_tests/lax_numpy_einsum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_two_operands_5(self):
self._check(s, x, y)

def test_two_operands_6(self):
# based on https://github.com/google/jax/issues/37#issuecomment-448572187
# based on https://github.com/jax-ml/jax/issues/37#issuecomment-448572187
r = self.rng()
x = r.randn(2, 1)
y = r.randn(2, 3, 4)
Expand Down
2 changes: 1 addition & 1 deletion trax/tf_numpy/jax_tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def testFloatIndexingError(self):
with self.assertRaisesRegex(IndexError, error_regex):
npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))

def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245
array = jnp.ones(5)
self.assertAllClose(array, array[:10], check_dtypes=True)

Expand Down
16 changes: 8 additions & 8 deletions trax/tf_numpy/jax_tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,14 +2051,14 @@ def testAstype(self):
# TODO(mattjj): test other ndarray-like method overrides

def testOnpMean(self):
# from https://github.com/google/jax/issues/125
# from https://github.com/jax-ml/jax/issues/125
x = lnp.add(lnp.eye(3, dtype=lnp.float_), 0.)
ans = onp.mean(x)
self.assertAllClose(ans, onp.array(1./3), check_dtypes=False)

@jtu.disable
def testArangeOnFloats(self):
# from https://github.com/google/jax/issues/145
# from https://github.com/jax-ml/jax/issues/145
expected = onp.arange(0.0, 1.0, 0.1, dtype=lnp.float_)
ans = lnp.arange(0.0, 1.0, 0.1)
self.assertAllClose(expected, ans, check_dtypes=True)
Expand Down Expand Up @@ -2407,7 +2407,7 @@ def testSymmetrizeDtypePromotion(self):

@jtu.disable
def testIssue347(self):
# https://github.com/google/jax/issues/347
# https://github.com/jax-ml/jax/issues/347
def test_fail(x):
x = lnp.sqrt(lnp.sum(x ** 2, axis=1))
ones = lnp.ones_like(x)
Expand All @@ -2419,7 +2419,7 @@ def test_fail(x):
assert not onp.any(onp.isnan(result))

def testIssue453(self):
# https://github.com/google/jax/issues/453
# https://github.com/jax-ml/jax/issues/453
a = onp.arange(6) + 1
ans = lnp.reshape(a, (3, 2), order='F')
expected = onp.reshape(a, (3, 2), order='F')
Expand All @@ -2432,7 +2432,7 @@ def testIssue453(self):
(bool, lnp.bool_), (complex, lnp.complex_)]
for op in ["atleast_1d", "atleast_2d", "atleast_3d"]))
def testAtLeastNdLiterals(self, pytype, dtype, op):
# Fixes: https://github.com/google/jax/issues/634
# Fixes: https://github.com/jax-ml/jax/issues/634
onp_fun = lambda arg: getattr(onp, op)(arg).astype(dtype)
lnp_fun = lambda arg: getattr(lnp, op)(arg)
args_maker = lambda: [pytype(2)]
Expand Down Expand Up @@ -2550,7 +2550,7 @@ def testMathSpecialFloatValues(self, op, dtype):
rtol=tol)

def testIssue883(self):
# from https://github.com/google/jax/issues/883
# from https://github.com/jax-ml/jax/issues/883

@partial(npe.jit, static_argnums=(1,))
def f(x, v):
Expand Down Expand Up @@ -2907,7 +2907,7 @@ def testDisableNumpyRankPromotionBroadcasting(self):
FLAGS.jax_numpy_rank_promotion = prev_flag

def testStackArrayArgument(self):
# tests https://github.com/google/jax/issues/1271
# tests https://github.com/jax-ml/jax/issues/1271
@npe.jit
def foo(x):
return lnp.stack(x)
Expand Down Expand Up @@ -3120,7 +3120,7 @@ def testOpGradSpecialValue(self, op, special_value, order):

@jtu.disable
def testTakeAlongAxisIssue1521(self):
# https://github.com/google/jax/issues/1521
# https://github.com/jax-ml/jax/issues/1521
idx = lnp.repeat(lnp.arange(3), 10).reshape((30, 1))

def f(x):
Expand Down
2 changes: 1 addition & 1 deletion trax/tf_numpy/jax_tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_method_wrapper(self, *args, **kwargs):
return test_method_wrapper
return skip

# TODO(phawkins): workaround for bug https://github.com/google/jax/issues/432
# TODO(phawkins): workaround for bug https://github.com/jax-ml/jax/issues/432
# Delete this code after the minimum jaxlib version is 0.1.46 or greater.
skip_on_mac_linalg_bug = partial(
unittest.skipIf,
Expand Down
8 changes: 4 additions & 4 deletions trax/tf_numpy/jax_tests/vmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
class VmapTest(tf.test.TestCase, parameterized.TestCase):

def test_vmap_in_axes_list(self):
# https://github.com/google/jax/issues/2367
# https://github.com/jax-ml/jax/issues/2367
dictionary = {'a': 5., 'b': tf_np.ones(2)}
x = tf_np.zeros(3)
y = tf_np.arange(3.)
Expand All @@ -41,7 +41,7 @@ def f(dct, x, y):
self.assertAllClose(out1, out2)

def test_vmap_in_axes_tree_prefix_error(self):
# https://github.com/google/jax/issues/795
# https://github.com/jax-ml/jax/issues/795
self.assertRaisesRegex(
ValueError,
'vmap in_axes specification must be a tree prefix of the corresponding '
Expand All @@ -63,14 +63,14 @@ def test_vmap_out_axes_leaf_types(self):
tf_np.array([1., 2.]))

def test_vmap_unbatched_object_passthrough_issue_183(self):
# https://github.com/google/jax/issues/183
# https://github.com/jax-ml/jax/issues/183
fun = lambda f, x: f(x)
vfun = extensions.vmap(fun, (None, 0))
ans = vfun(lambda x: x + 1, tf_np.arange(3))
self.assertAllClose(ans, np.arange(1, 4))

def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
# https://github.com/google/jax/issues/705
# https://github.com/jax-ml/jax/issues/705
with self.assertRaisesRegex(
ValueError, 'vmap must have at least one non-None value in in_axes'):
# If the output is mapped, there must be a non-None in_axes
Expand Down
2 changes: 1 addition & 1 deletion trax/tf_numpy_and_keras.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"\n",
"In Trax, all computations rely on accelerated math operations happening in the `fastmath` module. This module can use different backends for acceleration. One of them is [TensorFlow NumPy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which uses [TensorFlow 2](https://www.tensorflow.org/) to accelerate the computations.\n",
"\n",
"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/google/jax) which lowers to TensorFlow XLA.\n",
"The backend can be set using a call to `trax.fastmath.set_backend` as you'll see below. Currently available backends are `jax` (default), `tensorflow-numpy` and `numpy` (for debugging). The `tensorflow-numpy` backend uses [TensorFlow Numpy](https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) for executing `fastmath` functions on TensorFlow, while the `jax` backend calls [JAX](https://github.com/jax-ml/jax) which lowers to TensorFlow XLA.\n",
"\n",
"You may see that `tensorflow-numpy` and `jax` backends show different speed and memory characteristics. You may also see different error messages when debugging since it might expose you to the internals of the backends. However for the most part, users can choose a backend and not worry about the internal details of these backends.\n",
"\n",
Expand Down
Loading