-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[dependency] remove all numba and taichi dependency
- Loading branch information
Showing
19 changed files
with
2,204 additions
and
2,089 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,87 +1,131 @@ | ||
import os | ||
import sys | ||
from jax.lib import xla_client | ||
|
||
__all__ = [ | ||
'import_taichi', | ||
'import_brainpylib_cpu_ops', | ||
'import_brainpylib_gpu_ops', | ||
] | ||
|
||
_minimal_brainpylib_version = '0.1.10' | ||
_minimal_taichi_version = (1, 7, 0) | ||
|
||
taichi = None | ||
brainpylib_cpu_ops = None | ||
brainpylib_gpu_ops = None | ||
|
||
taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' | ||
f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' | ||
'> pip install taichi==1.7.0') | ||
os.environ["TI_LOG_LEVEL"] = "error" | ||
|
||
|
||
def import_taichi(): | ||
global taichi | ||
if taichi is None: | ||
with open(os.devnull, 'w') as devnull: | ||
old_stdout = sys.stdout | ||
sys.stdout = devnull | ||
try: | ||
import taichi as taichi # noqa | ||
except ModuleNotFoundError: | ||
raise ModuleNotFoundError(taichi_install_info) | ||
finally: | ||
sys.stdout = old_stdout | ||
|
||
if taichi.__version__ != _minimal_taichi_version: | ||
raise RuntimeError(taichi_install_info) | ||
return taichi | ||
|
||
|
||
def is_brainpylib_gpu_installed(): | ||
return False if brainpylib_gpu_ops is None else True | ||
|
||
|
||
def import_brainpylib_cpu_ops(): | ||
global brainpylib_cpu_ops | ||
if brainpylib_cpu_ops is None: | ||
try: | ||
from brainpylib import cpu_ops as brainpylib_cpu_ops | ||
|
||
for _name, _value in brainpylib_cpu_ops.registrations().items(): | ||
xla_client.register_custom_call_target(_name, _value, platform="cpu") | ||
|
||
import brainpylib | ||
if brainpylib.__version__ < _minimal_brainpylib_version: | ||
raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') | ||
if hasattr(brainpylib, 'check_brainpy_version'): | ||
brainpylib.check_brainpy_version() | ||
|
||
except ImportError: | ||
raise ImportError('Please install brainpylib. \n' | ||
'See https://brainpy.readthedocs.io for installation instructions.') | ||
|
||
return brainpylib_cpu_ops | ||
|
||
|
||
def import_brainpylib_gpu_ops(): | ||
global brainpylib_gpu_ops | ||
if brainpylib_gpu_ops is None: | ||
try: | ||
from brainpylib import gpu_ops as brainpylib_gpu_ops | ||
|
||
for _name, _value in brainpylib_gpu_ops.registrations().items(): | ||
xla_client.register_custom_call_target(_name, _value, platform="gpu") | ||
|
||
import brainpylib | ||
if brainpylib.__version__ < _minimal_brainpylib_version: | ||
raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') | ||
if hasattr(brainpylib, 'check_brainpy_version'): | ||
brainpylib.check_brainpy_version() | ||
|
||
except ImportError: | ||
raise ImportError('Please install GPU version of brainpylib. \n' | ||
'See https://brainpy.readthedocs.io for installation instructions.') | ||
|
||
return brainpylib_gpu_ops | ||
import os | ||
import sys | ||
from jax.lib import xla_client | ||
|
||
__all__ = [ | ||
'import_taichi', | ||
'import_taichi_else_None', | ||
'import_numba', | ||
'import_numba_else_None', | ||
'import_brainpylib_cpu_ops', | ||
'import_brainpylib_gpu_ops', | ||
] | ||
|
||
_minimal_brainpylib_version = '0.1.10' | ||
_minimal_taichi_version = (1, 7, 0) | ||
|
||
taichi = None | ||
numba = None | ||
brainpylib_cpu_ops = None | ||
brainpylib_gpu_ops = None | ||
|
||
taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' | ||
f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' | ||
'> pip install taichi==1.7.0') | ||
os.environ["TI_LOG_LEVEL"] = "error" | ||
|
||
|
||
def import_taichi(): | ||
global taichi | ||
if taichi is None: | ||
with open(os.devnull, 'w') as devnull: | ||
old_stdout = sys.stdout | ||
sys.stdout = devnull | ||
try: | ||
import taichi as taichi # noqa | ||
except ModuleNotFoundError: | ||
raise ModuleNotFoundError(taichi_install_info) | ||
finally: | ||
sys.stdout = old_stdout | ||
|
||
if taichi.__version__ != _minimal_taichi_version: | ||
raise RuntimeError(taichi_install_info) | ||
return taichi | ||
|
||
|
||
def import_taichi_else_None(): | ||
global taichi | ||
if taichi is None: | ||
with open(os.devnull, 'w') as devnull: | ||
old_stdout = sys.stdout | ||
sys.stdout = devnull | ||
try: | ||
import taichi as taichi # noqa | ||
except: | ||
return None | ||
finally: | ||
sys.stdout = old_stdout | ||
|
||
if taichi.__version__ != _minimal_taichi_version: | ||
raise RuntimeError(taichi_install_info) | ||
return taichi | ||
|
||
|
||
def import_numba(): | ||
global numba | ||
if numba is None: | ||
try: | ||
import numba as numba | ||
except ModuleNotFoundError: | ||
raise ModuleNotFoundError('We need numba. Please install numba by pip . \n' | ||
'> pip install numba' | ||
) | ||
return numba | ||
|
||
|
||
def import_numba_else_None(): | ||
global numba | ||
if numba is None: | ||
try: | ||
import numba as numba | ||
except: | ||
return None | ||
return numba | ||
|
||
|
||
def is_brainpylib_gpu_installed(): | ||
return False if brainpylib_gpu_ops is None else True | ||
|
||
|
||
def import_brainpylib_cpu_ops(): | ||
global brainpylib_cpu_ops | ||
if brainpylib_cpu_ops is None: | ||
try: | ||
from brainpylib import cpu_ops as brainpylib_cpu_ops | ||
|
||
for _name, _value in brainpylib_cpu_ops.registrations().items(): | ||
xla_client.register_custom_call_target(_name, _value, platform="cpu") | ||
|
||
import brainpylib | ||
if brainpylib.__version__ < _minimal_brainpylib_version: | ||
raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') | ||
if hasattr(brainpylib, 'check_brainpy_version'): | ||
brainpylib.check_brainpy_version() | ||
|
||
except ImportError: | ||
raise ImportError('Please install brainpylib. \n' | ||
'See https://brainpy.readthedocs.io for installation instructions.') | ||
|
||
return brainpylib_cpu_ops | ||
|
||
|
||
def import_brainpylib_gpu_ops(): | ||
global brainpylib_gpu_ops | ||
if brainpylib_gpu_ops is None: | ||
try: | ||
from brainpylib import gpu_ops as brainpylib_gpu_ops | ||
|
||
for _name, _value in brainpylib_gpu_ops.registrations().items(): | ||
xla_client.register_custom_call_target(_name, _value, platform="gpu") | ||
|
||
import brainpylib | ||
if brainpylib.__version__ < _minimal_brainpylib_version: | ||
raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') | ||
if hasattr(brainpylib, 'check_brainpy_version'): | ||
brainpylib.check_brainpy_version() | ||
|
||
except ImportError: | ||
raise ImportError('Please install GPU version of brainpylib. \n' | ||
'See https://brainpy.readthedocs.io for installation instructions.') | ||
|
||
return brainpylib_gpu_ops |
Oops, something went wrong.