Add factory for dynamic TransformOp #427
Conversation
Recording discussion from slack:
Yes, I think I better understand your point now. This PR now adds a default. EDIT: to get this free behavior, you will need to add a
@fritzo This will resolve current issues with TransformedDistribution and init_strategy, so I want to make a PR for it. But I am a bit worried whether it will be enough for funsor in complicated situations like the normalizing flows that Eli pointed out in our last discussion. Do we need batch shape or something else? Btw, can you change the Travis settings to let me rerun it? Currently, I can rerun jobs in pyro and numpyro but not in funsor.
I believe that would be sufficient metadata to implement.
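For intuition, the shape metadata being discussed can be sketched minimally in plain Python. The class and attribute names below are hypothetical stand-ins modeled on the `Constraint.event_dim` and `Transform.forward_shape()` interfaces referenced later in this PR; this is not the real Pyro API.

```python
class ShapedTransform:
    """Hypothetical sketch of a transform carrying shape metadata."""
    event_dim = 0  # number of rightmost dims the transform acts on

    def forward_shape(self, shape):
        # shape of forward(x) given the shape of x; identity by default
        return tuple(shape)


class HaarLikeTransform(ShapedTransform):
    """A Haar-style transform couples values along the rightmost dim."""
    event_dim = 1

    def forward_shape(self, shape):
        if len(shape) < self.event_dim:
            raise ValueError("input has too few dimensions")
        return tuple(shape)  # Haar is shape-preserving
```

With `event_dim` and `forward_shape()` available, a caller can compute the output domain of a transformed value without evaluating the transform itself.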
LGTM, the implementation fits my intuition (for now, I couldn't evaluate the caching and weakref details).
Code all looks good to me, but it would be good to add new TransformedDistribution test cases to our generic distribution tests to make sure sampling and other methods also behave correctly. If getting those working turns out to be a hassle, we can defer to a follow-up PR.
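For intuition, here is the kind of property such a generic test would check, sketched in plain Python (no torch) for a Normal base with an exp transform; the function names are illustrative, not funsor's API.

```python
import math

def normal_log_prob(x, loc=0.0, scale=1.0):
    # log density of Normal(loc, scale) at x
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale * math.sqrt(2.0 * math.pi))

def transformed_log_prob(y):
    # log density of TransformedDistribution(Normal(0, 1), ExpTransform()) at y > 0:
    # base log_prob at the inverse image, minus log|det J| of the transform
    x = math.log(y)           # inverse transform
    log_abs_det_jacobian = x  # d(exp x)/dx = exp(x), so log|J| = x
    return normal_log_prob(x) - log_abs_det_jacobian

# agrees with the closed-form log-normal density at y = 1
assert abs(transformed_log_prob(1.0) - (-0.5 * math.log(2.0 * math.pi))) < 1e-12
```

A generic test would assert this change-of-variables identity (and sample-shape consistency) for each converted `TransformedDistribution`.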
@pytest.fixture
def dist():
This seems like a useful pattern; we may want to copy it in the distribution tests in a later PR.
True,
xfail_param(False, reason="bug in to_funsor(TransformedDistribution)"),
])
def test_haar_transform(shape, to_event):
Ditto: try adding a HaarTransform test case to test/test_distribution_generic.py to test sampling, to_funsor/to_data conversion, and other distribution methods?
Can you suggest a fix for this broken test_generic_distribution_to_funsor()? (Note: to run it, you'll need the infer-shapes branch of Pyro.)
$ FUNSOR_BACKEND=torch pytest -vx test/test_distribution_generic.py -k Haar --pdb
===================================================== test session starts ======================================================
platform darwin -- Python 3.7.0, pytest-6.1.2, py-1.9.0, pluggy-0.13.1 -- /Users/fobermey/opt/miniconda3/envs/pyro/bin/python
cachedir: .pytest_cache
benchmark: 3.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /Users/fobermey/github/pyro-ppl/funsor, configfile: setup.cfg
plugins: forked-1.2.0, nbval-0.9.6, xdist-2.1.0, benchmark-3.2.3
collected 2244 items / 2211 deselected / 33 selected
test/test_distribution_generic.py::test_generic_distribution_to_funsor[dist.TransformedDistribution( dist.Normal(loc=case.loc, scale=1.).to_event(1), dist.transforms.HaarTransform(dim=-1)) (('loc', 'rand(() + (3,))'),)] FAILED [ 3%]
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> traceback >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
case = <test.test_distribution_generic.DistTestCase object at 0x7fb9b08c8cc0>
@pytest.mark.parametrize("case", TEST_CASES, ids=str)
def test_generic_distribution_to_funsor(case):
HIGHER_ORDER_DISTS = [
backend_dist.Independent,
backend_dist.TransformedDistribution,
] + ([backend_dist.torch_distribution.ExpandedDistribution] if get_backend() == "torch"
else [backend_dist.ExpandedDistribution])
raw_dist = case.get_dist()
expected_value_domain = case.expected_value_domain
dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
with interpretation(normalize_with_subs):
funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
assert funsor_dist.inputs["value"] == expected_value_domain
while isinstance(funsor_dist, funsor.cnf.Contraction):
funsor_dist = [term for term in funsor_dist.terms
> if isinstance(term, (funsor.distribution.Distribution, funsor.terms.Independent))][0]
E IndexError: list index out of range
test/test_distribution_generic.py:590: IndexError
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> entering PDB >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> PDB post_mortem (IO-capturing turned off) >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
> /Users/fobermey/github/pyro-ppl/funsor/test/test_distribution_generic.py(590)test_generic_distribution_to_funsor()
-> if isinstance(term, (funsor.distribution.Distribution, funsor.terms.Independent))][0]
(Pdb) print(funsor_dist.pretty())
Contraction(ops.nullop, ops.add,
frozenset(),
(Unary(ops.neg,
Binary(ops.log_abs_det_jacobian,
│Unary(ops._InverseTransform,
│ Variable('value', Reals[3])),
│Variable('value', Reals[3]))),
Contraction(ops.add, ops.nullop,
frozenset({Variable('_pyro_event_dim_-1_...
(Normal(
│ Tensor(
│ torch.tensor([0.4366315007209778, 0.8093...
│ (('_pyro_event_dim_-1__BOUND_1',
│ │Bint[3, ],),),
│ 'real'),
│ Tensor(
│ torch.tensor([1.0, 1.0, 1.0], dtype=torc...
│ (('_pyro_event_dim_-1__BOUND_1',
│ │Bint[3, ],),),
│ 'real'),
│ Binary(ops.GetitemOp(0),
│ Unary(ops._InverseTransform,
│ Variable('value', Reals[3])),
│ Variable('_pyro_event_dim_-1__BOUND_1', ...
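The failure mode in the traceback above is easy to reproduce in isolation: a filtering list comprehension followed by `[0]` raises `IndexError` whenever no term matches the pattern. The stand-in types below are placeholders, not funsor classes.

```python
# Stand-ins for funsor Contraction terms; neither matches the "pattern"
# (here: isinstance dict), mirroring a Contraction with no Distribution term.
terms = ("unary_term", "binary_term")

def first_matching_term(terms):
    matches = [t for t in terms if isinstance(t, dict)]  # empty list
    return matches[0]  # IndexError: list index out of range

try:
    first_matching_term(terms)
except IndexError as e:
    print(f"IndexError: {e}")

# A defensive alternative returns a sentinel instead of raising:
first_or_none = next((t for t in terms if isinstance(t, dict)), None)
assert first_or_none is None
```

The real fix, of course, is for the converted funsor to actually contain a `Distribution` or `Independent` term, as the reviewer discusses below.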
I think the underlying reason is that the way we use a lazy funsor.terms.Independent to convert IndependentDistributions to funsors (via the pattern funsor.distribution.indepdist_to_funsor) is sort of fiddly and affects pattern matching, especially when combined with transforms. Now that #402 is done, we should try to represent them directly by ops.expand-ing their parameters to the appropriate event_shape. I'll see if I can get this working in a separate PR.
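The expand-based representation mentioned above needs right-aligned shape broadcasting to stretch each parameter to the full event_shape. A minimal sketch of that shape arithmetic, with numpy/torch semantics (the helper name is illustrative):

```python
import itertools

def broadcast_shape(*shapes):
    """Right-aligned shape broadcasting, as in numpy/torch: a sketch of the
    arithmetic needed to ops.expand parameters to a full event_shape."""
    result = []
    # walk dims right-to-left, padding shorter shapes with size-1 dims
    for dims in itertools.zip_longest(*(tuple(reversed(s)) for s in shapes),
                                      fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes are not broadcastable: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))

# e.g. a Normal with scalar scale () and batched loc (3,) expands to (3,)
assert broadcast_shape((), (3,)) == (3,)
assert broadcast_shape((2, 1), (3,)) == (2, 3)
```

Representing an Independent distribution by expanded parameters keeps the converted funsor a plain Distribution term, avoiding the fragile lazy-Independent pattern matching.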
LGTM
pair coded with @eb8680 @fehiepsi @ordabayevy
This introduces classes WrappedTransformOp and LogAbsDetJacobianOp for ops dynamically created from backend Transform objects and TransformedDistributions. Unlike statically created ops, instances of these ops will never have custom rules; hence we do not create unique subclasses for each new op. I have also refactored Op, CachedOpMeta, and WrappedOpMeta to more finely control whether and how we cache op instance creation.
Forwards compatibility
This PR xfails until the following PRs are implemented upstream. I have tested locally with these changes. It is safe to merge this PR before the upstream changes have merged, because the upstream interfaces are pretty stable:
- Constraint.event_dim
- Transform.forward_shape()
- Constraint.event_dim and Transform.forward_shape()
- Constraint.event_dim
- Transform.forward_shape()
Tested
- WrappedTransformOp and LogAbsDetJacobianOp
- WrappedTransformOp with PowerTransform (passes locally)
- HaarTransform (passes locally)
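The op-instance caching described in the summary can be sketched with two metaclasses: one keyed on constructor arguments, one keyed on the identity of the wrapped backend object. This is a simplified illustration of the CachedOpMeta/WrappedOpMeta design, not funsor's actual implementation; class names carry a "Sketch" suffix to make that explicit.

```python
import weakref

class CachedOpMetaSketch(type):
    """Sketch: cache op instances by constructor args, so constructing
    an op twice with the same args yields the same object."""
    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)
        cls._instance_cache = {}

    def __call__(cls, *args):
        try:
            return cls._instance_cache[args]
        except KeyError:
            instance = super().__call__(*args)
            cls._instance_cache[args] = instance
            return instance

class WrappedOpMetaSketch(type):
    """Sketch: cache ops by the identity of the wrapped backend object.
    Values are held weakly so cached ops can be garbage collected.
    (Keying on id() is a simplification; reused ids would need care.)"""
    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)
        cls._instance_cache = weakref.WeakValueDictionary()

    def __call__(cls, fn):
        op = cls._instance_cache.get(id(fn))
        if op is None:
            op = super().__call__(fn)
            cls._instance_cache[id(fn)] = op
        return op

class NamedOp(metaclass=CachedOpMetaSketch):
    def __init__(self, name):
        self.name = name

class WrappedOp(metaclass=WrappedOpMetaSketch):
    def __init__(self, fn):
        self.fn = fn

assert NamedOp("exp") is NamedOp("exp")  # same args -> same cached op
```

The weak-valued cache matters for dynamically created ops: a transform wrapped once per model iteration should not leak one op instance per call.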