Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean taichi AOT caches #643

Merged
merged 8 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions brainpy/_src/math/event/tests/test_event_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)


seed = 1234


Expand Down
7 changes: 6 additions & 1 deletion brainpy/_src/math/jitconn/tests/test_event_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)]
import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)


shapes = [(100, 200), (1000, 10)]


Expand Down
6 changes: 6 additions & 0 deletions brainpy/_src/math/jitconn/tests/test_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)


shapes = [(100, 200), (1000, 10)]


Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/op_register/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
compile_cpu_signature_with_numba)
from .base import XLACustomOp
from .utils import register_general_batching
from .taichi_aot_based import clean_caches, check_kernels_count
from .taichi_aot_based import clear_taichi_aot_caches, count_taichi_aot_kernels
from .base import XLACustomOp
from .utils import register_general_batching
76 changes: 50 additions & 26 deletions brainpy/_src/math/op_register/taichi_aot_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,25 @@
import re
import shutil
from functools import partial, reduce
from typing import Any, Sequence
from typing import Any, Sequence, Union

import jax.core
import numpy as np
from jax.interpreters import xla, mlir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call

from brainpy.errors import PackageMissingError
from brainpy._src.dependency_check import (import_taichi,
import_brainpylib_cpu_ops,
import_brainpylib_gpu_ops)
from brainpy.errors import PackageMissingError
from .utils import _shape_to_layout


### UTILS ###
taichi_cache_path = None


# --- UTILS ###

# get the path of home directory on Linux, Windows, Mac
def get_home_dir():
Expand All @@ -43,8 +46,18 @@ def encode_md5(source: str) -> str:

return md5.hexdigest()


# check kernels count
def check_kernels_count() -> int:
def count_taichi_aot_kernels() -> int:
"""
Count the number of AOT compiled kernels.

Returns
-------
kernels_count: int
The number of AOT compiled kernels.

"""
if not os.path.exists(kernels_aot_path):
return 0
kernels_count = 0
Expand All @@ -54,23 +67,37 @@ def check_kernels_count() -> int:
kernels_count += len(dir2)
return kernels_count

# clean caches
def clean_caches(kernels_name: list[str]=None):
if kernels_name is None:
if not os.path.exists(kernels_aot_path):
raise FileNotFoundError("The kernels cache folder does not exist. \
Please define a kernel using `taichi.kernel` \
and customize the operator using `bm.XLACustomOp` \
before calling the operator.")
shutil.rmtree(kernels_aot_path)
print('Clean all kernel\'s cache successfully')

def clear_taichi_aot_caches(kernels: Union[str, Sequence[str]] = None):
"""
Clean the cache of the AOT compiled kernels.

Parameters
----------
kernels: str or list of str
The name of the kernel to be cleaned. If None, all the kernels will be cleaned.
"""
if kernels is None:
global taichi_cache_path
if taichi_cache_path is None:
from taichi._lib.utils import import_ti_python_core
taichi_cache_path = import_ti_python_core().get_repo_dir()
# clean taichi cache
if os.path.exists(taichi_cache_path):
shutil.rmtree(taichi_cache_path)
# clean brainpy-taichi AOT cache
if os.path.exists(kernels_aot_path):
shutil.rmtree(kernels_aot_path)
return
for kernel_name in kernels_name:
try:
if isinstance(kernels, str):
kernels = [kernels]
if not isinstance(kernels, list):
raise TypeError(f'kernels_name must be a list of str, but got {type(kernels)}')
# clear brainpy kernel cache
for kernel_name in kernels:
if os.path.exists(os.path.join(kernels_aot_path, kernel_name)):
shutil.rmtree(os.path.join(kernels_aot_path, kernel_name))
except FileNotFoundError:
raise FileNotFoundError(f'Kernel {kernel_name} does not exist.')
print('Clean kernel\'s cache successfully')


# TODO
# not a very good way
Expand Down Expand Up @@ -104,7 +131,7 @@ def is_metal_supported():
return True


### VARIABLES ###
# --- VARIABLES ###
home_path = get_home_dir()
kernels_aot_path = os.path.join(home_path, '.brainpy', 'kernels')
is_metal_device = is_metal_supported()
Expand All @@ -122,7 +149,7 @@ def _check_kernel_exist(source_md5_encode: str) -> bool:
return False


### KERNEL AOT BUILD ###
# --- KERNEL AOT BUILD ###


def _array_to_field(dtype, shape) -> Any:
Expand Down Expand Up @@ -212,7 +239,7 @@ def _build_kernel(
kernel.__name__ = kernel_name


### KERNEL CALL PREPROCESS ###
# --- KERNEL CALL PREPROCESS ###

# convert type to number
type_number_map = {
Expand Down Expand Up @@ -334,9 +361,6 @@ def _preprocess_kernel_call_gpu(
return opaque





def _XlaOp_to_ShapedArray(c, xla_op):
xla_op = c.get_shape(xla_op)
return jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type())
Expand Down Expand Up @@ -376,7 +400,7 @@ def _compile_kernel(abs_ins, kernel, platform: str, **kwargs):
try:
os.removedirs(os.path.join(kernels_aot_path, source_md5_encode))
except Exception:
raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e
raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e
raise RuntimeError(f'Failed to build kernel:\n\n {codes}') from e

# returns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def test_taichi_clean_cache():
print(out)
bm.clear_buffer_memory()

print('kernels: ', bm.check_kernels_count())
print('kernels: ', bm.count_taichi_aot_kernels())

bm.clean_caches()
bm.clear_taichi_aot_caches()

print('kernels: ', bm.check_kernels_count())
print('kernels: ', bm.count_taichi_aot_kernels())

# test_taichi_clean_cache()
7 changes: 6 additions & 1 deletion brainpy/_src/math/sparse/tests/test_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
import brainpy as bp
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)


seed = 1234


Expand Down
4 changes: 2 additions & 2 deletions brainpy/math/op_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from brainpy._src.math.op_register import (
CustomOpByNumba,
compile_cpu_signature_with_numba,
clean_caches,
check_kernels_count,
clear_taichi_aot_caches,
count_taichi_aot_kernels,
)

from brainpy._src.math.op_register.base import XLACustomOp
Expand Down
32 changes: 31 additions & 1 deletion docs/apis/brainpy.math.op_register.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ General Operator Customization Interface



CPU Operator Customization with Taichi
-------------------------------------

.. currentmodule:: brainpy.math
.. automodule:: brainpy.math

.. autosummary::
:toctree: generated/

clear_taichi_aot_caches
count_taichi_aot_kernels






CPU Operator Customization with Numba
-------------------------------------

Expand All @@ -34,7 +51,6 @@ CPU Operator Customization with Numba
:template: classtemplate.rst

CustomOpByNumba
XLACustomOp


.. autosummary::
Expand All @@ -43,3 +59,17 @@ CPU Operator Customization with Numba
register_op_with_numba
compile_cpu_signature_with_numba



Operator Autograd Customization
-------------------------------

.. currentmodule:: brainpy.math
.. automodule:: brainpy.math

.. autosummary::
:toctree: generated/

defjvp


4 changes: 2 additions & 2 deletions examples/dynamics_simulation/ei_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def __init__(self):
super().__init__()
self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 1.))
self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2})
self.delay = bp.VarDelay(self.N.spike, entries={'delay': 0.})
self.syn1 = bp.dyn.Expon(size=3200, tau=5.)
self.syn2 = bp.dyn.Expon(size=800, tau=10.)
self.E = bp.dyn.VanillaProj(
Expand All @@ -228,7 +228,7 @@ def __init__(self):
)

def update(self, input):
spk = self.delay.at('I')
spk = self.delay.at('delay')
self.E(self.syn1(spk[:3200]))
self.I(self.syn2(spk[3200:]))
self.delay(self.N(input))
Expand Down
Loading