
JAX is a Python library for high-performance numerical computing and an alternative to PyTorch. It is built around functional programming: functions are expected to be pure and arrays are immutable. As a result, JAX often does things differently from PyTorch.
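As a small illustration of what "functional" means in practice, here is a minimal sketch (not part of the original notes) showing that JAX arrays are immutable. In-place item assignment fails, and the `.at[...].set(...)` syntax returns a new array instead of modifying the old one.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# JAX arrays are immutable, so NumPy/PyTorch-style in-place assignment raises a TypeError.
try:
    x[0] = 1.0
except TypeError as err:
    print("in-place update failed:", type(err).__name__)

# The functional alternative: build a new array with the update applied.
y = x.at[0].set(1.0)
print(x)  # original is unchanged: [0. 0. 0.]
print(y)  # new array: [1. 0. 0.]
```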

Generating Random Numbers with PRNG Keys

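```python
import jax.random as random

# JAX has no global random state: every random function takes an explicit key,
# created here from the integer seed 0.
key = random.PRNGKey(0)
x = random.uniform(key, shape=(3, 3))
print(x)

# Split the key to get fresh keys instead of reusing the old one.
key, subkey = random.split(key)
x = random.uniform(key, shape=(3, 3))
print(x)

y = random.uniform(subkey, shape=(3, 3))
print(y)
```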
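A consequence of explicit keys is that random draws are fully deterministic. The sketch below (an illustration added here, not from the original notes) checks that calling `random.uniform` twice with the same key gives identical arrays, while a key obtained from `random.split` gives different values. This is why a key should be split rather than reused whenever fresh randomness is needed.

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

# The same key always produces the same numbers.
a = random.uniform(key, shape=(3, 3))
b = random.uniform(key, shape=(3, 3))
same = bool(jnp.array_equal(a, b))

# A key split off from the original produces different numbers.
key, subkey = random.split(key)
c = random.uniform(subkey, shape=(3, 3))
different = not bool(jnp.array_equal(a, c))

print(same, different)
```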