The biggest single issue blocking wider use of Funsor in Pyro and NumPyro right now is the incomplete coverage of distributions.
At a high level, the goal is to be able to perform all distribution operations that appear in any Pyro or NumPyro model (e.g. sampling and scoring) directly on Funsors. Distribution funsors would be obtained by using `to_funsor` to convert distributions to Funsors automatically at the start, and `funsor.to_data` would convert the final results back to raw PyTorch/JAX objects. Wherever possible, the Funsor wrappers should also avoid the need for user-facing higher-order distributions, such as `Independent` or `TransformedDistribution`, in favor of idiomatic Funsor operations or broadcasting semantics.
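Here is a small, library-free illustration of the broadcasting point. It is plain Python, not Funsor or PyTorch code, and the helper names are hypothetical. Wrapping a batch of univariate Normals in `Independent(..., 1)` only changes how log-densities are aggregated: they are summed over the reinterpreted dimension. In Funsor, the same effect can presumably be obtained by keeping that dimension as a named input and summing over it, with no separate wrapper distribution.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log density of a univariate Normal(loc, scale) evaluated at x."""
    return (
        -((x - loc) ** 2) / (2.0 * scale**2)
        - math.log(scale)
        - 0.5 * math.log(2.0 * math.pi)
    )


def batched_log_prob(xs, locs, scales):
    """Broadcasting view: one log density per batch element."""
    return [normal_log_prob(x, m, s) for x, m, s in zip(xs, locs, scales)]


def independent_log_prob(xs, locs, scales):
    """Independent(Normal, 1) view: the batch dimension is reinterpreted as an
    event dimension, so the per-element log densities are summed."""
    return math.fsum(batched_log_prob(xs, locs, scales))


xs, locs, scales = [0.5, -1.0, 2.0], [0.0, 0.0, 1.0], [1.0, 2.0, 0.5]
print(batched_log_prob(xs, locs, scales))
print(independent_log_prob(xs, locs, scales))
```

The only difference between the two views is the final summation. That is why a named dimension plus an explicit sum can stand in for the wrapper.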
We have gotten pretty far with the generic wrappers in `funsor.distribution` and `funsor.{jax,torch}.distributions`. However, finishing the job and achieving full coverage of `pyro.distributions` remains a challenge for three reasons: the size of the distribution API and the number of distributions; many small impedance mismatches and legacy design choices in PyTorch (e.g. the data type of `Bernoulli`); and the difficulty of programmatic access and automation (e.g. there is no generic tool for constructing random valid instances of a distribution given a `batch_shape`).
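One way to approach the missing generic tool is to drive random parameter generation from each distribution's argument constraints, which PyTorch exposes through `arg_constraints`. The sketch below is plain Python and only an assumption about how such a helper could look. The constraint strings, the `ARG_CONSTRAINTS` table and `random_params` are hypothetical stand-ins for `torch.distributions.constraints` objects and real distribution classes. Nested lists stand in for tensors of shape `batch_shape`.

```python
import random

# Hypothetical constraint names mirroring a few entries of
# torch.distributions.constraints (real, positive, unit_interval).
ARG_CONSTRAINTS = {
    "Normal": {"loc": "real", "scale": "positive"},
    "Gamma": {"concentration": "positive", "rate": "positive"},
    # Only one of Bernoulli's two alternative parameterizations (probs/logits).
    "Bernoulli": {"probs": "unit_interval"},
}


def random_scalar(constraint, rng):
    """Draw one scalar that satisfies the named constraint."""
    if constraint == "real":
        return rng.gauss(0.0, 1.0)
    if constraint == "positive":
        return rng.expovariate(1.0) + 1e-3  # strictly positive
    if constraint == "unit_interval":
        return rng.uniform(0.01, 0.99)  # stay away from the endpoints
    raise ValueError(f"unsupported constraint: {constraint!r}")


def random_batch(constraint, batch_shape, rng):
    """Nested lists of shape batch_shape filled with valid scalars."""
    if not batch_shape:
        return random_scalar(constraint, rng)
    return [random_batch(constraint, batch_shape[1:], rng) for _ in range(batch_shape[0])]


def random_params(dist_name, batch_shape, seed=0):
    """Random valid parameters for dist_name with the given batch_shape."""
    rng = random.Random(seed)
    return {
        name: random_batch(constraint, tuple(batch_shape), rng)
        for name, constraint in ARG_CONSTRAINTS[dist_name].items()
    }


params = random_params("Normal", (2, 3))
print({k: len(v) for k, v in params.items()})  # {'loc': 2, 'scale': 2}
```

A test harness would presumably build the backend distribution from such parameters and then check that converting it to a funsor and back recovers the original.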
I've tried to collect the remaining Funsor-specific tasks in this issue so that we can better measure progress toward this goal. We may also need to do additional work upstream in Pyro, NumPyro or PyTorch distributions.

Transforms and `TransformedDistribution`s (some design discussion in #309):

- `Transform`s in Pyro
- `TransformedDistribution`s to and from `funsor.Distribution`s on the PyTorch backend
- `Transform`s
- Conversion of `TransformedDistribution`s to and from funsors on the JAX backend (can copy #365, Support conversion of Transforms and TransformedDistributions to and from Funsors)
- `Transform`s like `AbsTransform`
- `Transform`s with ground parameters, notably `AffineTransform` and `PowerTransform` with `batch_shape == ()`
- `Transform`s with non-ground parameters, e.g. `AffineTransform` with `len(batch_shape) > 0`
- `ConditionalTransform`s and `TransformedModule`s in Pyro/PyTorch

Other basic distribution modifiers:

- `to_funsor` conversion for custom distributions
- `Delta` conversion working
- `funsor.Independent` to and from `Independent` distributions
- `Independent` distributions in Funsor
- `ExpandedDistribution`s to funsors
- `ExpandedDistribution`s to funsors
- `_IndependentConstraint`
- `arg_constraints`
- `IndependentDistribution`s directly to base Funsor distributions
- `MaskedMixture` and `MixtureSameFamily`
- Masking:
  - `mask is False` (discussed in #459, Support mask=False for backend MaskedDistribution)

Direct TFP distribution wrappers:

- Direct wrapping of TFP distributions in the JAX backend, which would probably involve a new subclass `class TFPDistribution(funsor.distribution.Distribution)` with TFP-specific implementations of `_infer_value_domain` and `_infer_param_domain`
- Direct TFP `Bijector` wrappers

Atomic distribution computations beyond sampling and scoring implemented in the backend libraries:

- `TorchDistribution.mean`
- `TorchDistribution.variance`

Test harnesses for distribution wrappers (testing the correctness of the underlying distribution functionality is out of scope here; we are mostly interested in ensuring that results are converted to Funsors correctly):

- `to_funsor` and `to_data`
- `enumerate_support`

Miscellaneous:

- `Bernoulli` (discussed in #348, Casting between real and bint)
- Generic support for sampling via `Funsor.sample()` in `funsor.pyro.FunsorDistribution`. Done properly, this may eliminate the need for first-class conjugate distribution implementations in backends, e.g. `DirichletMultinomial`

Lower priority, possibly unnecessary:

- Some Funsor analogue of `Distribution.expand()`, e.g. via lazy units as in #235 (Add lazy multidimensional Unit terms) or substitution of indexed `Variable` terms (`f(v=v['i'])`)
- Generic tests for conjugate distribution pairs
  - deprecate existing, case-by-case tests

Conversion of `TransformedDistribution`s to and from funsors on the JAX backend (can copy #365)

Note that this will be a bit trickier than expected. In NumPyro, `Transform.inv` is just a regular method rather than something that returns an `_InverseTransform`, and NumPyro does not implement `_InverseTransform`.
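To make the `Transform.inv` mismatch concrete, here is a toy, backend-free sketch in plain Python. No NumPyro or PyTorch code is involved, and all class and function names are hypothetical. In PyTorch, `Transform.inv` returns an `_InverseTransform` object, so a converter can treat "the inverse of a transform" like any other transform. When the inverse exists only as a plain `inv` method, as in the NumPyro note, the converter needs an adapter that supplies that object itself:

```python
import math


class ToyExpTransform:
    """Toy forward transform whose inverse is only a plain method,
    mimicking the NumPyro convention described in the note on Transform.inv."""

    def __call__(self, x):
        return math.exp(x)

    def inv(self, y):
        return math.log(y)


class ToyInverseTransform:
    """Hypothetical adapter playing the role of PyTorch's _InverseTransform:
    the inverse becomes a first-class object that can be called and converted."""

    def __init__(self, base):
        self.base = base

    def __call__(self, y):
        # Forward direction of the inverse = inverse direction of the base.
        return self.base.inv(y)

    def inverse(self, x):
        # Inverse direction of the inverse = forward direction of the base.
        return self.base(x)


def as_inverse_transform(t):
    """Wrap t so that its inverse can be passed around as an object."""
    return ToyInverseTransform(t)


inv_exp = as_inverse_transform(ToyExpTransform())
print(inv_exp(1.0))  # 0.0
```

The design point is only that the wrapper, not the backend, has to supply the inverse-as-object abstraction when the backend does not provide one.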