-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
50 additions
and
3 deletions.
There are no files selected for viewing
53 changes: 50 additions & 3 deletions
53
brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,54 @@ | ||
import brainpy.math as bm | ||
import jax | ||
import jax.numpy as jnp | ||
import platform | ||
import pytest | ||
import taichi | ||
|
||
print(bm.check_kernels_count()) | ||
if not platform.platform().startswith('Windows'): | ||
pytest.skip(allow_module_level=True) | ||
|
||
bm.clean_caches() | ||
@taichi.func | ||
def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32: | ||
return weight[0] | ||
|
||
print(bm.check_kernels_count()) | ||
|
||
@taichi.func | ||
def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32): | ||
out[index] += weight_val | ||
|
||
@taichi.kernel | ||
def event_ell_cpu(indices: taichi.types.ndarray(ndim=2), | ||
vector: taichi.types.ndarray(ndim=1), | ||
weight: taichi.types.ndarray(ndim=1), | ||
out: taichi.types.ndarray(ndim=1)): | ||
weight_val = get_weight(weight) | ||
num_rows, num_cols = indices.shape | ||
taichi.loop_config(serialize=True) | ||
for i in range(num_rows): | ||
if vector[i]: | ||
for j in range(num_cols): | ||
update_output(out, indices[i, j], weight_val) | ||
|
||
prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) | ||
|
||
def test_taichi_clean_cache(): | ||
s = 1000 | ||
indices = bm.random.randint(0, s, (s, 1000)) | ||
vector = bm.random.rand(s) < 0.1 | ||
weight = bm.array([1.0]) | ||
|
||
out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) | ||
|
||
out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) | ||
|
||
print(out) | ||
bm.clear_buffer_memory() | ||
|
||
print('kernels: ', bm.check_kernels_count()) | ||
|
||
bm.clean_caches() | ||
|
||
print('kernels: ', bm.check_kernels_count()) | ||
|
||
# test_taichi_clean_cache() |