diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py
index 6f2dbd4f..a715abfe 100644
--- a/brainpy/_src/math/op_register/__init__.py
+++ b/brainpy/_src/math/op_register/__init__.py
@@ -2,5 +2,6 @@
 from .numba_approach import (CustomOpByNumba,
                              register_op_with_numba,
                              compile_cpu_signature_with_numba)
+from .taichi_aot_based import clean_caches
 from .base import XLACustomOp
 from .utils import register_general_batching
diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py
index cb05ece8..18285524 100644
--- a/brainpy/_src/math/op_register/base.py
+++ b/brainpy/_src/math/op_register/base.py
@@ -14,7 +14,9 @@
 # from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
 from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
 from .taichi_aot_based import (register_taichi_cpu_translation_rule,
-                               register_taichi_gpu_translation_rule,)
+                               register_taichi_gpu_translation_rule,
+                               check_kernels_count,
+                               clean_caches)
 from .utils import register_general_batching
 from brainpy._src.math.op_register.ad_support import defjvp
 
@@ -138,6 +140,10 @@ def __init__(
     if transpose_translation is not None:
       ad.primitive_transposes[self.primitive] = transpose_translation
 
+    # check the cache size and clean it when needed (3000 cached kernels take roughly 100 MB)
+    if check_kernels_count() > 3000:
+      clean_caches()
+
   def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
     if outs is None:
       outs = self.outs
diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py
index ab7b9801..bb9aecc8 100644
--- a/brainpy/_src/math/op_register/taichi_aot_based.py
+++ b/brainpy/_src/math/op_register/taichi_aot_based.py
@@ -36,6 +36,34 @@ def encode_md5(source: str) -> str:
   return md5.hexdigest()
 
 
+# check kernels count
+def check_kernels_count() -> int:
+  if not os.path.exists(kernels_aot_path):
+    return 0
+  kernels_count = 0
+  for kernel_dir in os.listdir(kernels_aot_path):
+    kernels_count += len(os.listdir(os.path.join(kernels_aot_path, kernel_dir)))
+  return kernels_count
+
+
+# clean caches
+def clean_caches(kernels_name: list[str] = None):
+  import shutil  # local import; ``os.removedirs`` cannot delete non-empty cache folders
+  if kernels_name is None:
+    if not os.path.exists(kernels_aot_path):
+      raise FileNotFoundError("The kernels cache folder does not exist. "
+                              "Please define a kernel using `taichi.kernel` "
+                              "and customize the operator using `bm.XLACustomOp` "
+                              "before calling the operator.")
+    shutil.rmtree(kernels_aot_path)
+    print("Cleaned all kernel caches successfully")
+    return
+  for kernel_name in kernels_name:
+    try:
+      shutil.rmtree(os.path.join(kernels_aot_path, kernel_name))
+    except FileNotFoundError:
+      raise FileNotFoundError(f'Kernel {kernel_name} does not exist.')
+  print("Cleaned kernel caches successfully")
 
 
 # TODO
 # not a very good way
@@ -151,6 +179,9 @@ def _build_kernel(
   if ti.lang.impl.current_cfg().arch != arch:
     raise RuntimeError(f"Arch {arch} is not available")
 
+  # save the original kernel name
+  kernel_name = kernel.__name__
+
   # replace the name of the func
   kernel.__name__ = f'taichi_kernel_{device}'
 
@@ -170,6 +201,9 @@ def _build_kernel(
   mod.add_kernel(kernel, template_args=template_args_dict)
   mod.save(kernel_path)
 
+  # restore the original kernel name
+  kernel.__name__ = kernel_name
+
 
 ### KERNEL CALL PREPROCESS ###
 
@@ -246,7 +280,7 @@ def _preprocess_kernel_call_cpu(
   return in_out_info
 
 
-def preprocess_kernel_call_gpu(
+def _preprocess_kernel_call_gpu(
     source_md5_encode: str,
     ins: dict,
     outs: dict,
@@ -312,7 +346,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):
 
   # kernel to code
   codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform)
-  source_md5_encode = encode_md5(codes)
+  source_md5_encode = kernel.__name__ + '/' + encode_md5(codes)
 
   # create ins, outs dict from kernel's args
   in_num = len(ins)
@@ -332,7 +366,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):
   # returns
   if platform in ['gpu', 'cuda']:
     import_brainpylib_gpu_ops()
-    opaque = preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
+    opaque = _preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
     return opaque
   elif platform == 'cpu':
     import_brainpylib_cpu_ops()
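
For reference, a minimal usage sketch of the new cache-management helpers. One assumption to note: this diff only re-exports `clean_caches` from `brainpy._src.math.op_register`, so any shorter public alias (e.g. under `brainpy.math`) is not shown here, and the kernel name below is hypothetical.

    from brainpy._src.math.op_register import clean_caches
    from brainpy._src.math.op_register.taichi_aot_based import check_kernels_count

    # inspect how many Taichi AOT kernel variants are currently cached
    print(check_kernels_count())

    # remove the cache of selected kernels only ('my_event_kernel' is hypothetical)
    clean_caches(kernels_name=['my_event_kernel'])

    # remove the entire kernel cache
    clean_caches()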
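The new cache key `kernel.__name__ + '/' + encode_md5(codes)` implies a two-level on-disk layout, which is what makes both helpers work: each kernel gets a folder named after the function, holding one md5-named subfolder per compiled variant. Sketched under that assumption (names and hashes are illustrative):

    kernels_aot_path/
      my_event_kernel/        # kernel.__name__
        3f2a9c.../            # md5 of the generated source, one per variant
        9b77d1.../
      other_kernel/
        51d0aa.../

`check_kernels_count()` therefore sums the md5 subfolders across all kernel folders, and `clean_caches(['my_event_kernel'])` can delete one kernel's variants without touching the rest.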
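The save/restore pair added to `_build_kernel` exists because the Taichi AOT export mutates the user's function: `kernel.__name__` is temporarily set to `taichi_kernel_{device}`, but the original name must survive since `_compile_kernel` now uses it as the cache-key prefix. A self-contained sketch of the same pattern as a context manager (`temporary_name` is a hypothetical helper, not part of this diff):

    from contextlib import contextmanager

    @contextmanager
    def temporary_name(fn, name: str):
      # swap in the export name, and always restore the original,
      # even if the AOT build raises
      original, fn.__name__ = fn.__name__, name
      try:
        yield fn
      finally:
        fn.__name__ = original

    # usage:
    #   with temporary_name(kernel, f'taichi_kernel_{device}'):
    #     mod.add_kernel(kernel, template_args=template_args_dict)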