diff --git a/tests/test_lora.py b/tests/test_lora.py
index f7d852531..b6933f935 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -74,8 +74,8 @@ def __call__(self, x):
         @staticmethod
         def init(*, key):
             k1, k2 = jax.random.split(key)
-            first = hnn.Linear.init(In, Mid, key=k1)
-            second = hnn.Linear.init(Mid, In, key=k2)
+            first = hnn.Linear.init(In, Mid, key=k1, out_first=True)
+            second = hnn.Linear.init(Mid, In, key=k2, out_first=True)
             return Module(first, second)

     Layers = hax.Axis("Layers", 3)
@@ -91,7 +91,7 @@ def init(*, key):

     assert loraized.stacked.first.lora.lora_A.weight.axes == (Layers, hax.Axis("LORA_R", 8), In)
     assert loraized.stacked.first.lora.lora_B.weight.axes == (Layers, Mid, hax.Axis("LORA_R", 8))
-    assert loraized.stacked.second.weight.axes == (Layers, Mid, In)
+    assert loraized.stacked.second.weight.axes == (Layers, In, Mid)

     input = hax.random.normal(k0, (In,))
     assert not hax.all(hax.isclose(module.fold(input), loraized.fold(input)))