Skip to content

Commit

Permalink
Merge pull request #1270 from gisilvs:cascading_flow_vi
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 374964013
  • Loading branch information
tensorflower-gardener committed May 20, 2021
2 parents 5838e0e + 430d151 commit 4d8ce35
Show file tree
Hide file tree
Showing 4 changed files with 621 additions and 1 deletion.
32 changes: 32 additions & 0 deletions tensorflow_probability/python/experimental/bijectors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ multi_substrate_py_library(
srcs_version = "PY3",
deps = [
":distribution_bijectors",
":highway_flow",
":scalar_function_with_inferred_inverse",
":sharded",
"//tensorflow_probability/python/bijectors:ldj_ratio",
Expand Down Expand Up @@ -104,6 +105,37 @@ multi_substrate_py_test(
],
)

multi_substrate_py_library(
name = "highway_flow",
srcs = ["highway_flow.py"],
srcs_version = "PY3",
deps = [
":scalar_function_with_inferred_inverse",
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/bijectors",
"//tensorflow_probability/python/internal:samplers",
"//tensorflow_probability/python/util",
],
)

multi_substrate_py_test(
name = "highway_flow_test",
size = "medium",
srcs = ["highway_flow_test.py"],
disabled_substrates = ["numpy"],
jax_size = "medium",
python_version = "PY3",
srcs_version = "PY3",
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/bijectors:bijector_test_util",
"//tensorflow_probability/python/internal:test_util",
],
)

multi_substrate_py_library(
name = "sharded",
srcs = ["sharded.py"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
from tensorflow_probability.python.bijectors.ldj_ratio import forward_log_det_jacobian_ratio
from tensorflow_probability.python.bijectors.ldj_ratio import inverse_log_det_jacobian_ratio
from tensorflow_probability.python.experimental.bijectors.distribution_bijectors import make_distribution_bijector
from tensorflow_probability.python.experimental.bijectors.highway_flow import build_trainable_highway_flow
from tensorflow_probability.python.experimental.bijectors.highway_flow import HighwayFlow
from tensorflow_probability.python.experimental.bijectors.scalar_function_with_inferred_inverse import ScalarFunctionWithInferredInverse
from tensorflow_probability.python.experimental.bijectors.sharded import Sharded


__all__ = [
'build_trainable_highway_flow',
'forward_log_det_jacobian_ratio',
'inverse_log_det_jacobian_ratio',
'make_distribution_bijector',
'HighwayFlow',
'ScalarFunctionWithInferredInverse',
'Sharded',
]
Loading

0 comments on commit 4d8ce35

Please sign in to comment.