```python
import numpy as np  # regular ol' numpy
from trax import layers as tl  # core building block
from trax import shapes  # data signatures: dimensionality and type
from trax import fastmath  # uses jax, offers numpy on steroids
```
It errors out on import, even with just these basics. What am I doing wrong? Should I be pinning different versions of the packages?
### For bugs: reproduction and error logs
```
# Error logs:
...
1 # coding=utf-8
2 # Copyright 2021 The Trax Authors.
3 #
(...)
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
16 """Trax top level import."""
---> 18 from trax import data
19 from trax import fastmath
20 from trax import layers
File ./ds_work/miniconda3/envs/coursera-nlp/lib/python3.9/site-packages/trax/data/__init__.py:36, in <module>
16 """Functions and classes for obtaining and preprocesing data.
17
18 The ``trax.data`` module presents a flattened (no subpackages) public API.
(...)
...
217 'vjp': jax.vjp,
218 'vmap': jax.vmap,
219 }
AttributeError: module 'jax.ops' has no attribute 'index_add'
```
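The failing lookup is `jax.ops.index_add`. Newer JAX releases removed the `jax.ops.index_*` helpers in favor of the indexed-update syntax `x.at[idx].add(y)`, while trax 1.3.9 still looks up the old name at import time. As a stopgap, here is a minimal sketch of a compatibility shim that puts a thin `index_add` wrapper back onto `jax.ops` before trax is imported. It assumes (unverified) that `index_add` is the only removed JAX symbol trax 1.3.9 touches on import, and `restore_index_add` is a hypothetical helper name, not part of trax or JAX. Pinning a `jax`/`jaxlib` pair that still ships `jax.ops.index_add` is likely the more robust fix.

```python
"""Hypothetical shim: restore jax.ops.index_add on JAX versions that lack it.

Assumption: callers only need the old functional behaviour, i.e. a copy of
x with y added at idx, which newer JAX spells as x.at[idx].add(y).
"""

from typing import Any


def restore_index_add(jax_module: Any) -> bool:
    """Install a jax.ops.index_add wrapper if the installed JAX lacks one.

    Returns True if the wrapper was installed, False if it already existed.
    """
    ops = jax_module.ops
    if hasattr(ops, "index_add"):
        return False  # older JAX still provides it; nothing to patch

    def index_add(x, idx, y, indices_are_sorted=False, unique_indices=False):
        # Out-of-place update: returns a new array, x itself is unchanged.
        return x.at[idx].add(
            y,
            indices_are_sorted=indices_are_sorted,
            unique_indices=unique_indices,
        )

    ops.index_add = index_add
    return True


# Hypothetical usage; this must run before the first `import trax`:
#   import jax
#   restore_index_add(jax)
#   from trax import layers as tl
```

If patching only moves the failure to a different `AttributeError`, the same pattern could be repeated, but at that point pinning compatible versions is probably less work.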
### Description
I am trying to do something basic in my code: importing NumPy and the core Trax modules `layers`, `shapes`, and `fastmath`.
### Environment information
OS: CentOS
```
lsb_release
LSB Version: :core-4.1-amd64:core-4.1-ia32:core-4.1-noarch:cxx-4.1-amd64:cxx-4.1-ia32:cxx-4.1-noarch:desktop-4.1-amd64:desktop-4.1-ia32:desktop-4.1-noarch:languages-4.1-amd64:languages-4.1-noarch:printing-4.1-amd64:printing-4.1-noarch
$ pip freeze | grep trax
trax==1.3.9
$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.2
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.30.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0
$ pip freeze | grep jax
jax==0.4.4
jaxlib==0.4.4
$ python -V
Python 3.9.16
```
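A small stdlib-only sketch for collecting the relevant versions from inside the same interpreter that runs the failing import, plus a direct probe for the attribute the traceback reports as missing. It reads installed-distribution metadata via `importlib.metadata` and never imports trax; the helper names `installed_versions` and `jax_has_index_add` are hypothetical. With `jax==0.4.4` as listed above, the probe should report `False`, matching the `AttributeError`.

```python
"""Report package versions relevant to the jax.ops.index_add import error."""

from importlib import import_module
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

PACKAGES = ("trax", "jax", "jaxlib", "tensorflow")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def jax_has_index_add() -> Optional[bool]:
    """True/False if jax.ops is importable, None if JAX is not installed."""
    try:
        jax_ops = import_module("jax.ops")
    except ImportError:
        return None
    return hasattr(jax_ops, "index_add")


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
    print("jax.ops.index_add present:", jax_has_index_add())
```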