diff --git a/brainpy/_src/math/tests/taichi_random_time_test.py b/brainpy/_src/math/tests/taichi_random_time_test.py index 45d836e9c..4d0cddf9b 100644 --- a/brainpy/_src/math/tests/taichi_random_time_test.py +++ b/brainpy/_src/math/tests/taichi_random_time_test.py @@ -34,23 +34,6 @@ def test_taichi_xorwow(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), seeds1, seeds2, result = taichi_xorwow(seeds1, seeds2) out[i] = result -@ti.kernel -def test_taichi_lfsr88_0(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), - out: ti.types.ndarray(ndim=1, dtype=ti.f32)): - s1, s2, s3, b = init_lfsr88_seeds_0(seed[0]) - for i in range(out.shape[0]): - s1, s2, s3, b, result = taichi_lfsr88_0(s1, s2, s3, b) - out[i] = result - -@ti.kernel -def test_taichi_xorwow_0(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), - out: ti.types.ndarray(ndim=1, dtype=ti.f32)): - x, y, z, w, v, d = init_xorwow_seeds_0(seed[0]) - # print(seeds1, seeds2) - for i in range(out.shape[0]): - x, y, z, w, v, d, result = taichi_xorwow_0(x, y, z, w, v, d) - out[i] = result - n = 100000000 seed = jnp.array([1234, ], dtype=jnp.uint32) @@ -60,12 +43,6 @@ def test_taichi_xorwow_0(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), prim_xorwow = bm.XLACustomOp(cpu_kernel=test_taichi_xorwow, gpu_kernel=test_taichi_xorwow) -prim_lfsr88_0 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88_0, - gpu_kernel=test_taichi_lfsr88_0) - -prim_xorwow_0 = bm.XLACustomOp(cpu_kernel=test_taichi_xorwow_0, - gpu_kernel=test_taichi_xorwow_0) - out = jax.block_until_ready(prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])) time0 = time.time() out = jax.block_until_ready(prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])) @@ -76,17 +53,6 @@ def test_taichi_xorwow_0(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), out = jax.block_until_ready(prim_xorwow(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])) time3 = time.time() -out = jax.block_until_ready(prim_lfsr88_0(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])) -time4 = time.time() -out = jax.block_until_ready(prim_lfsr88_0(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])) -time5 = time.time() - -out = jax.block_until_ready(prim_xorwow_0(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])) -time6 = time.time() -out = jax.block_until_ready(prim_xorwow_0(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)])) -time7 = time.time() print('lfsr88: ', time1 - time0) print('xorwow: ', time3 - time2) -print('lfsr88_0: ', time5 - time4) -print('xorwow_0: ', time7 - time6)