Skip to content

Commit

Permalink
[taichi] Make taichi caches more transparent and add a clean-caches function
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jan 17, 2024
1 parent 02b85b2 commit 015dc60
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 4 deletions.
1 change: 1 addition & 0 deletions brainpy/_src/math/op_register/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from .numba_approach import (CustomOpByNumba,
register_op_with_numba,
compile_cpu_signature_with_numba)
from .taichi_aot_based import clean_caches
from .base import XLACustomOp
from .utils import register_general_batching
8 changes: 7 additions & 1 deletion brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
# from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_cpu_translation_rule,
register_taichi_gpu_translation_rule,)
register_taichi_gpu_translation_rule,
check_kernels_count,
clean_caches)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

Expand Down Expand Up @@ -138,6 +140,10 @@ def __init__(
if transpose_translation is not None:
ad.primitive_transposes[self.primitive] = transpose_translation

# check cache size and clean cache (the size of 3000 kernels is about 100MB)
if check_kernels_count() > 3000:
clean_caches()

def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
if outs is None:
outs = self.outs
Expand Down
40 changes: 37 additions & 3 deletions brainpy/_src/math/op_register/taichi_aot_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,34 @@ def encode_md5(source: str) -> str:

return md5.hexdigest()

# check kernels count
def check_kernels_count() -> int:
if not os.path.exists(kernels_aot_path):
return 0
kernels_count = 0
dir1 = os.listdir(kernels_aot_path)
for i in dir1:
dir2 = os.listdir(os.path.join(kernels_aot_path, i))
kernels_count += len(dir2)
return kernels_count

# clean caches
def clean_caches(kernels_name: list[str]):
if kernels_name is None:
if not os.path.exists(kernels_aot_path):
raise FileNotFoundError("The kernels cache folder does not exist. \
Please define a kernel using `taichi.kernel` \
and customize the operator using `bm.XLACustomOp` \
before calling the operator.")
os.removedirs(kernels_aot_path)
print('Clean all kernel\'s cache successfully')
return
for kernel_name in kernels_name:
try:
os.removedirs(os.path.join(kernels_aot_path, kernel_name))
except FileNotFoundError:
raise FileNotFoundError(f'Kernel {kernel_name} does not exist.')
print('Clean kernel\'s cache successfully')

# TODO
# not a very good way
Expand Down Expand Up @@ -151,6 +179,9 @@ def _build_kernel(
if ti.lang.impl.current_cfg().arch != arch:
raise RuntimeError(f"Arch {arch} is not available")

# get kernel name
kernel_name = kernel.__name__

# replace the name of the func
kernel.__name__ = f'taichi_kernel_{device}'

Expand All @@ -170,6 +201,9 @@ def _build_kernel(
mod.add_kernel(kernel, template_args=template_args_dict)
mod.save(kernel_path)

# rename kernel name
kernel.__name__ = kernel_name


### KERNEL CALL PREPROCESS ###

Expand Down Expand Up @@ -246,7 +280,7 @@ def _preprocess_kernel_call_cpu(
return in_out_info


def preprocess_kernel_call_gpu(
def _preprocess_kernel_call_gpu(
source_md5_encode: str,
ins: dict,
outs: dict,
Expand Down Expand Up @@ -312,7 +346,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):

# kernel to code
codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform)
source_md5_encode = encode_md5(codes)
source_md5_encode = kernel.__name__ + '/' + encode_md5(codes)

# create ins, outs dict from kernel's args
in_num = len(ins)
Expand All @@ -332,7 +366,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):
# returns
if platform in ['gpu', 'cuda']:
import_brainpylib_gpu_ops()
opaque = preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
opaque = _preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
return opaque
elif platform == 'cpu':
import_brainpylib_cpu_ops()
Expand Down

0 comments on commit 015dc60

Please sign in to comment.