diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 9d653a6b..fd0314b2 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -2,6 +2,7 @@ import functools +import gc import inspect import os import re @@ -675,7 +676,8 @@ def set_host_device_count(n): def clear_buffer_memory( platform: str = None, array: bool = True, - compilation: bool = True, + compilation: bool = False, + transform: bool = True, name: bool = True, ): """Clear all on-device buffers. @@ -695,7 +697,9 @@ def clear_buffer_memory( array: bool Clear all buffer array. Default is True. compilation: bool - Clear compilation cache. Default is True. + Clear compilation cache. Default is False. + transform: bool + Clear transform cache. Default is True. name: bool Clear name cache. Default is True. @@ -705,9 +709,11 @@ def clear_buffer_memory( buf.delete() if compilation: jax.clear_caches() + if transform: naming.clear_stack_cache() if name: naming.clear_name_cache() + gc.collect() def disable_gpu_memory_preallocation(release_memory: bool = True):