From e6f08d7969f97edda21d2a5f0efd506f77315597 Mon Sep 17 00:00:00 2001 From: Surya2k1 Date: Wed, 29 Jan 2025 00:05:01 +0530 Subject: [PATCH 1/3] Fix deserialization issue with custom_loss function while reloading saved models of <3.7v --- keras/src/saving/object_registration.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/keras/src/saving/object_registration.py b/keras/src/saving/object_registration.py index 8c0f538917bd..976b28af4ff3 100644 --- a/keras/src/saving/object_registration.py +++ b/keras/src/saving/object_registration.py @@ -227,4 +227,13 @@ def from_config(cls, config, custom_objects=None): return custom_objects[name] elif module_objects and name in module_objects: return module_objects[name] + # Check if there are objects without Package name appended to them. + # For Backward compatibility of custom_loss functions in versions <3.7 + elif name is not None and any( + lambda key: key.contains(name) + for key in custom_objects_scope_dict.keys() + ): + for key in custom_objects_scope_dict.keys(): + if name in key: + return custom_objects_scope_dict[key] return None From 4bd7e38de89565364b40bd11bf372fdbe4af4162 Mon Sep 17 00:00:00 2001 From: Surya2k1 Date: Wed, 29 Jan 2025 14:24:04 +0530 Subject: [PATCH 2/3] Fix failing test cases for numpy,jax and TF backends --- keras/src/saving/object_registration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras/src/saving/object_registration.py b/keras/src/saving/object_registration.py index 976b28af4ff3..8d12cac31d5c 100644 --- a/keras/src/saving/object_registration.py +++ b/keras/src/saving/object_registration.py @@ -227,9 +227,9 @@ def from_config(cls, config, custom_objects=None): return custom_objects[name] elif module_objects and name in module_objects: return module_objects[name] - # Check if there are objects without Package name appended to them. - # For Backward compatibility of custom_loss functions in versions <3.7 - elif name is not None and any( + # # Check if there are objects without Package name appended to them. + # # For Backward compatibility of custom_loss functions in versions <3.7 + elif name not in (None, "Functional") and any( lambda key: key.contains(name) for key in custom_objects_scope_dict.keys() ): From cb987ec1f95d7797c08fd88d13dd5cfb9aad73e2 Mon Sep 17 00:00:00 2001 From: Surya2k1 Date: Wed, 29 Jan 2025 21:37:06 +0530 Subject: [PATCH 3/3] Fallback code for savedmodel from <=3.6v to >=3.7 --- keras/src/saving/object_registration.py | 9 --------- keras/src/saving/serialization_lib.py | 10 +++++++++- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/keras/src/saving/object_registration.py b/keras/src/saving/object_registration.py index 8d12cac31d5c..8c0f538917bd 100644 --- a/keras/src/saving/object_registration.py +++ b/keras/src/saving/object_registration.py @@ -227,13 +227,4 @@ def from_config(cls, config, custom_objects=None): return custom_objects[name] elif module_objects and name in module_objects: return module_objects[name] - # # Check if there are objects without Package name appended to them. - # # For Backward compatibility of custom_loss functions in versions <3.7 - elif name not in (None, "Functional") and any( - lambda key: key.contains(name) - for key in custom_objects_scope_dict.keys() - ): - for key in custom_objects_scope_dict.keys(): - if name in key: - return custom_objects_scope_dict[key] return None diff --git a/keras/src/saving/serialization_lib.py b/keras/src/saving/serialization_lib.py index 48c70808b405..9bba12aabc37 100644 --- a/keras/src/saving/serialization_lib.py +++ b/keras/src/saving/serialization_lib.py @@ -780,7 +780,15 @@ def _retrieve_class_or_fn( ) if obj is not None: return obj - + # Fall back code for reloading saved models of versions <=3.6 + # into versions >=3.7 + filtered_dict = { + k: v + for k, v in custom_objects.items() + if k.endswith(full_config["config"]) + } + if filtered_dict: + return next(iter(filtered_dict.values())) # Otherwise, attempt to retrieve the class object given the `module` # and `class_name`. Import the module, find the class. package = module.split(".", maxsplit=1)[0]