diff --git a/CHANGELOG.md b/CHANGELOG.md index de6fa307..2cb77896 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Changelog follow the https://keepachangelog.com/ standard (at least the headers) ## [Unreleased] +* Add `kd.nn.WrapperModule` to make an inner module transparent with + respect to its parameters. + ## [1.0.0] - 2024-11-21 * `kd.kontext.Path` now supports tensor slicing. So for example using keys like diff --git a/kauldron/modules/__init__.py b/kauldron/modules/__init__.py index ba37c706..1151e4a9 100644 --- a/kauldron/modules/__init__.py +++ b/kauldron/modules/__init__.py @@ -32,6 +32,7 @@ # Modules from kauldron.modules.adapter import ExternalModule +from kauldron.modules.adapter import WrapperModule from kauldron.modules.misc import Dropout from kauldron.modules.misc import DummyModel from kauldron.modules.misc import Identity diff --git a/kauldron/modules/adapter.py b/kauldron/modules/adapter.py index 2b203b32..b961b4fd 100644 --- a/kauldron/modules/adapter.py +++ b/kauldron/modules/adapter.py @@ -20,7 +20,25 @@ from kauldron.utils import train_property -class ExternalModule(nn.Module): +class WrapperModule(nn.Module): + """Base class to wrap a module. + + The wrapper module is transparent with respect to the inner parameters ( + `{'params': inner_params}` instead of nesting + `{'params': {'model': inner_params}}`). + """ + + model: nn.Module + + def __post_init__(self): + super().__post_init__() + # Share scope, to make the wrapper module transparent with respect to the + # parameters (instead of nesting `{'params': {'model': params}}`). + if self.scope is not None: + nn.share_scope(self, self.model) + + +class ExternalModule(WrapperModule): """Module that is defined outside Kauldron. This is a **very** thin wrapper around `flax.linen.Module` that add: @@ -52,7 +70,6 @@ class ExternalModule(nn.Module): can be inverted with `~` (e.g. 
`train_kwarg_name='~deterministic'`) """ - model: nn.Module keys: str | dict[str, str] train_kwarg_name: Optional[str] = None diff --git a/kauldron/modules/adapter_test.py b/kauldron/modules/adapter_test.py index 65fb021c..64ddd51d 100644 --- a/kauldron/modules/adapter_test.py +++ b/kauldron/modules/adapter_test.py @@ -14,6 +14,8 @@ """Test.""" +from typing import Any + from flax import linen as nn import jax import jax.numpy as jnp @@ -48,3 +50,18 @@ def test_external(): assert not np.array_equal(out_train, inputs) np.testing.assert_array_equal(out_eval, inputs) + + +def test_wrapper(): + class MyWrapper(kd.nn.WrapperModule): + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.model(*args, **kwargs) + + model = MyWrapper( + model=nn.Dense(2), + ) + + inputs = jnp.ones((5,)) + params = model.init(jax.random.PRNGKey(0), inputs) + assert list(params['params']) == ['kernel', 'bias']