From 7691e9c6ec2bbb3d08b6a2dee5d30af1ebfd3d26 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 19 Oct 2021 18:57:42 -0400 Subject: [PATCH 1/2] Add pattern to fuse nested Subs --- funsor/terms.py | 16 +++++++++++++++- test/test_gaussian.py | 27 ++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/funsor/terms.py b/funsor/terms.py index ad3ee04e..2c39bab0 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -962,13 +962,27 @@ def _sample(self, sampled_vars, sample_inputs, rng_key=None): @lazy.register(Subs, Funsor, object) @eager.register(Subs, Funsor, object) -def eager_subs(arg, subs): +def eager_subs_funsor(arg, subs): assert isinstance(subs, tuple) if not any(k in arg.inputs for k, v in subs): return arg return substitute(arg, subs) +@lazy.register(Subs, Subs, object) +@eager.register(Subs, Subs, object) +def eager_subs_subs(arg, subs): + assert isinstance(subs, tuple) + subs = tuple((k, v) for k, v in subs if k in arg.inputs) + if not subs: + return arg + + # Fuse substitutions. + fused_subs = tuple((k, Subs(v, subs)) for k, v in arg.subs.items()) + fused_subs += subs + return Subs(arg.arg, fused_subs) + + @die.register(Subs, Funsor, tuple) def die_subs(arg, subs): expr = reflect.interpret(Subs, arg, subs) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index edb90ab8..20d61b24 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -23,8 +23,9 @@ _vm, ) from funsor.integrate import Integrate +from funsor.interpretations import eager, lazy from funsor.tensor import Einsum, Tensor, numeric_array -from funsor.terms import Number, Unary, Variable +from funsor.terms import Number, Subs, Unary, Variable from funsor.testing import ( assert_close, id_from_inputs, @@ -890,3 +891,27 @@ def test_eager_add(): actual = Contraction(ops.logaddexp, ops.add, frozenset({a}), (g1, g2)) assert isinstance(actual, Tensor) + + +@pytest.mark.parametrize("interp", [eager, lazy]) +def test_nested_subs_1(interp): + with interp: + g = Gaussian(randn(3), randn(2, 3), OrderedDict([("b", Real), ("a", Real)])) + a = ops.abs(Variable("aux_0", Real)) + b = ops.abs(Variable("aux_1", Real)) + g_ab = g(a=a, b=b) + g_a_b = g(a=a)(b=b) + g_b_a = g(b=b)(a=a) + + # Test subs fusion. + assert isinstance(g_ab, Subs) + assert isinstance(g_ab.arg, Gaussian) + assert isinstance(g_a_b, Subs) + assert isinstance(g_a_b.arg, Gaussian) + assert isinstance(g_b_a, Subs) + assert isinstance(g_b_a.arg, Gaussian) + + # Compare on ground data. + subs = {"aux_0": randn(()), "aux_1": randn(())} + assert_close(g_ab(**subs), g_a_b(**subs)) + assert_close(g_ab(**subs), g_b_a(**subs)) From cfb0d7c965a0658f76ba6e635b864ec929d03eb6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 19 Oct 2021 19:06:10 -0400 Subject: [PATCH 2/2] Relax tolerance --- test/test_gaussian.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 20d61b24..ee11d092 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -894,7 +894,7 @@ def test_eager_add(): @pytest.mark.parametrize("interp", [eager, lazy]) -def test_nested_subs_1(interp): +def test_nested_subs(interp): with interp: g = Gaussian(randn(3), randn(2, 3), OrderedDict([("b", Real), ("a", Real)])) a = ops.abs(Variable("aux_0", Real)) @@ -913,5 +913,5 @@ def test_nested_subs_1(interp): # Compare on ground data. subs = {"aux_0": randn(()), "aux_1": randn(())} - assert_close(g_ab(**subs), g_a_b(**subs)) - assert_close(g_ab(**subs), g_b_a(**subs)) + assert_close(g_ab(**subs), g_a_b(**subs), atol=1e-3, rtol=1e-3) + assert_close(g_ab(**subs), g_b_a(**subs), atol=1e-3, rtol=1e-3)