Skip to content

Commit

Permalink
Update taichi_random_time_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 22, 2023
1 parent 5177634 commit 28d8e58
Showing 1 changed file with 0 additions and 34 deletions.
34 changes: 0 additions & 34 deletions brainpy/_src/math/tests/taichi_random_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,6 @@ def test_taichi_xorwow(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
seeds1, seeds2, result = taichi_xorwow(seeds1, seeds2)
out[i] = result

@ti.kernel
def test_taichi_lfsr88_0(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
out: ti.types.ndarray(ndim=1, dtype=ti.f32)):
s1, s2, s3, b = init_lfsr88_seeds_0(seed[0])
for i in range(out.shape[0]):
s1, s2, s3, b, result = taichi_lfsr88_0(s1, s2, s3, b)
out[i] = result

@ti.kernel
def test_taichi_xorwow_0(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
out: ti.types.ndarray(ndim=1, dtype=ti.f32)):
x, y, z, w, v, d = init_xorwow_seeds_0(seed[0])
# print(seeds1, seeds2)
for i in range(out.shape[0]):
x, y, z, w, v, d, result = taichi_xorwow_0(x, y, z, w, v, d)
out[i] = result

n = 100000000
seed = jnp.array([1234, ], dtype=jnp.uint32)

Expand All @@ -60,12 +43,6 @@ def test_taichi_xorwow_0(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
prim_xorwow = bm.XLACustomOp(cpu_kernel=test_taichi_xorwow,
gpu_kernel=test_taichi_xorwow)

prim_lfsr88_0 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88_0,
gpu_kernel=test_taichi_lfsr88_0)

prim_xorwow_0 = bm.XLACustomOp(cpu_kernel=test_taichi_xorwow_0,
gpu_kernel=test_taichi_xorwow_0)

out = jax.block_until_ready(prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time0 = time.time()
out = jax.block_until_ready(prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
Expand All @@ -76,17 +53,6 @@ def test_taichi_xorwow_0(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
out = jax.block_until_ready(prim_xorwow(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time3 = time.time()

out = jax.block_until_ready(prim_lfsr88_0(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time4 = time.time()
out = jax.block_until_ready(prim_lfsr88_0(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time5 = time.time()

out = jax.block_until_ready(prim_xorwow_0(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time6 = time.time()
out = jax.block_until_ready(prim_xorwow_0(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time7 = time.time()

print('lfsr88: ', time1 - time0)
print('xorwow: ', time3 - time2)
print('lfsr88_0: ', time5 - time4)
print('xorwow_0: ', time7 - time6)

0 comments on commit 28d8e58

Please sign in to comment.