Skip to content

Commit

Permalink
Update taichi_random_time_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 22, 2023
1 parent 28d8e58 commit b88b6c2
Showing 1 changed file with 37 additions and 34 deletions.
71 changes: 37 additions & 34 deletions brainpy/_src/math/tests/taichi_random_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import brainpy.math as bm
import taichi as ti
import matplotlib.pyplot as plt
import pytest

pytestmark = pytest.mark.skip(reason="Skipped due to MacOS limitation, manual execution required for testing.")
from brainpy._src.math.taichi_random import (taichi_lcg_rand as rand,
taichi_uniform_int_distribution as randint,
taichi_uniform_real_distribution as uniform,
Expand All @@ -17,42 +19,43 @@

bm.set_platform('gpu')

@ti.kernel
def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
out: ti.types.ndarray(ndim=1, dtype=ti.f32)):
seeds = init_lfsr88_seeds(seed[0])
for i in range(out.shape[0]):
seeds, result = taichi_lfsr88(seeds)
out[i] = result

@ti.kernel
def test_taichi_xorwow(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
out: ti.types.ndarray(ndim=1, dtype=ti.f32)):
seeds1, seeds2 = init_xorwow_seeds(seed[0])
# print(seeds1, seeds2)
for i in range(out.shape[0]):
seeds1, seeds2, result = taichi_xorwow(seeds1, seeds2)
out[i] = result
def main():
@ti.kernel
def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
out: ti.types.ndarray(ndim=1, dtype=ti.f32)):
seeds = init_lfsr88_seeds(seed[0])
for i in range(out.shape[0]):
seeds, result = taichi_lfsr88(seeds)
out[i] = result

n = 100000000
seed = jnp.array([1234, ], dtype=jnp.uint32)

prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88,
gpu_kernel=test_taichi_lfsr88)

prim_xorwow = bm.XLACustomOp(cpu_kernel=test_taichi_xorwow,
gpu_kernel=test_taichi_xorwow)
@ti.kernel
def test_taichi_xorwow(seed: ti.types.ndarray(ndim=1, dtype=ti.u32),
out: ti.types.ndarray(ndim=1, dtype=ti.f32)):
seeds1, seeds2 = init_xorwow_seeds(seed[0])
# print(seeds1, seeds2)
for i in range(out.shape[0]):
seeds1, seeds2, result = taichi_xorwow(seeds1, seeds2)
out[i] = result

n = 100000000
seed = jnp.array([1234, ], dtype=jnp.uint32)

prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88,
gpu_kernel=test_taichi_lfsr88)

prim_xorwow = bm.XLACustomOp(cpu_kernel=test_taichi_xorwow,
gpu_kernel=test_taichi_xorwow)

out = jax.block_until_ready(prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time0 = time.time()
out = jax.block_until_ready(prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time1 = time.time()
out = jax.block_until_ready(prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time0 = time.time()
out = jax.block_until_ready(prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time1 = time.time()

out = jax.block_until_ready(prim_xorwow(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time2 = time.time()
out = jax.block_until_ready(prim_xorwow(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time3 = time.time()
out = jax.block_until_ready(prim_xorwow(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time2 = time.time()
out = jax.block_until_ready(prim_xorwow(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]))
time3 = time.time()


print('lfsr88: ', time1 - time0)
print('xorwow: ', time3 - time2)
print('lfsr88: ', time1 - time0)
print('xorwow: ', time3 - time2)

0 comments on commit b88b6c2

Please sign in to comment.