From dea0fca781952c1a047bd0ad26e9fdc780bd7436 Mon Sep 17 00:00:00 2001 From: "Andrew C. Sweet" Date: Wed, 29 Jan 2025 13:45:33 -0800 Subject: [PATCH 1/2] patch for pytest to run with mlx --- keras/src/backend/__init__.py | 3 ++- keras/src/backend/mlx/__init__.py | 1 + keras/src/backend/mlx/export.py | 10 ++++++++++ keras/src/backend/mlx/nn.py | 5 +++++ keras/src/export/saved_model.py | 4 ++++ 5 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 keras/src/backend/mlx/export.py diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 3ab350d30aaa..b11ca34d22ee 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -57,7 +57,8 @@ distribution_lib = None elif backend() == "mlx": from keras.src.backend.mlx import * # noqa: F403 - + from keras.src.backend.mlx.core import Variable as BackendVariable + distribution_lib = None else: raise ValueError(f"Unable to import backend : {backend()}") diff --git a/keras/src/backend/mlx/__init__.py b/keras/src/backend/mlx/__init__.py index 0b2dd40f4dab..556667843c80 100644 --- a/keras/src/backend/mlx/__init__.py +++ b/keras/src/backend/mlx/__init__.py @@ -1,5 +1,6 @@ """MLX backend APIs.""" +from keras.src.backend.common.name_scope import name_scope from keras.src.backend.mlx import core from keras.src.backend.mlx import image from keras.src.backend.mlx import linalg diff --git a/keras/src/backend/mlx/export.py b/keras/src/backend/mlx/export.py new file mode 100644 index 000000000000..695f02969e94 --- /dev/null +++ b/keras/src/backend/mlx/export.py @@ -0,0 +1,10 @@ +class MlxExportArchive: + def track(self, resource): + raise NotImplementedError( + "`track` is not implemented in the mlx backend." + ) + + def add_endpoint(self, name, fn, input_signature=None, **kwargs): + raise NotImplementedError( + "`add_endpoint` is not implemented in the mlx backend." + ) diff --git a/keras/src/backend/mlx/nn.py b/keras/src/backend/mlx/nn.py index cfb794019257..b9755caae27d 100644 --- a/keras/src/backend/mlx/nn.py +++ b/keras/src/backend/mlx/nn.py @@ -115,6 +115,11 @@ def gelu_tanh_approx(x): return f(x) +def celu(x, alpha=1.0): + x = convert_to_tensor(x) + return nn.celu(x, alpha=alpha) + + def softmax(x, axis=-1): x = convert_to_tensor(x) return mx.softmax(x, axis=axis) diff --git a/keras/src/export/saved_model.py b/keras/src/export/saved_model.py index f52c73e54618..f0900eacbe8a 100644 --- a/keras/src/export/saved_model.py +++ b/keras/src/export/saved_model.py @@ -29,6 +29,10 @@ from keras.src.backend.openvino.export import ( OpenvinoExportArchive as BackendExportArchive, ) +elif backend.backend() == "mlx": + from keras.src.backend.mlx.export import ( + MlxExportArchive as BackendExportArchive, + ) else: raise RuntimeError( f"Backend '{backend.backend()}' must implement a layer mixin class." From 896e9adfdd4060c545899d1938a870e33abd8884 Mon Sep 17 00:00:00 2001 From: "Andrew C. Sweet" Date: Wed, 29 Jan 2025 13:52:04 -0800 Subject: [PATCH 2/2] formatting --- keras/src/backend/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index b11ca34d22ee..5a91433bbf1c 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -58,7 +58,7 @@ elif backend() == "mlx": from keras.src.backend.mlx import * # noqa: F403 from keras.src.backend.mlx.core import Variable as BackendVariable - + distribution_lib = None else: raise ValueError(f"Unable to import backend : {backend()}")