Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cascading flow surrogate posterior #1345

Open
wants to merge 59 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
e6a2713
fixed conflicts
gisilvs May 14, 2021
c501d2b
Revert "Revert "initial tests, updated init and build""
gisilvs May 14, 2021
b6be9d9
reverted commit
gisilvs May 14, 2021
dbf371b
Revert "removed cascading_flows from pr"
gisilvs May 14, 2021
c6118b1
reverted to latest version
gisilvs May 14, 2021
bcf95e1
fixed surrogate posterior type
gisilvs May 14, 2021
4d4b291
small fixes
gisilvs May 18, 2021
6cd8871
fixed global variables if no auxiliary variabled
gisilvs May 18, 2021
4690d0a
added number of layers parameter
gisilvs May 18, 2021
36ce254
fixed conflicts
gisilvs May 18, 2021
3b182aa
readded highway flow
gisilvs May 18, 2021
3e11546
fixed init
gisilvs May 20, 2021
2ff7130
removed highway flow from this branch
gisilvs May 20, 2021
4cc20e3
removed highway flow from this branch
gisilvs May 20, 2021
0b386c6
working on tests
gisilvs May 20, 2021
d8f4780
more testing
gisilvs May 20, 2021
79b0b99
pulled highway flow from master
gisilvs May 21, 2021
af9a5ba
fixed conflicts
gisilvs May 14, 2021
5cf8ce9
Revert "Revert "initial tests, updated init and build""
gisilvs May 14, 2021
755bca9
reverted commit
gisilvs May 14, 2021
bbc38a4
Revert "removed cascading_flows from pr"
gisilvs May 14, 2021
a89d60a
reverted to latest version
gisilvs May 14, 2021
ea80d7b
fixed surrogate posterior type
gisilvs May 14, 2021
cf11c70
small fixes
gisilvs May 18, 2021
d9e2828
fixed global variables if no auxiliary variabled
gisilvs May 18, 2021
80e8ee7
added number of layers parameter
gisilvs May 18, 2021
1360229
readded highway flow
gisilvs May 18, 2021
e1a2218
fixed init
gisilvs May 20, 2021
4f667ee
working on tests
gisilvs May 20, 2021
75d8b53
more testing
gisilvs May 20, 2021
65f5adb
small refsctoring and changed docstriings
gisilvs May 27, 2021
a2b025c
added dependency to build_trainable_highway_flow
gisilvs May 27, 2021
7bd8457
some refactoring
gisilvs May 28, 2021
17b5c6f
Merge remote-tracking branch 'upstream/master' into cascading_flow_su…
gisilvs May 28, 2021
78a18d8
merged conflicts
gisilvs May 28, 2021
ae34080
changed seed
gisilvs May 28, 2021
d0e287f
reverted to master
gisilvs May 28, 2021
1e9a486
removed substitution rule and updated dependencies
gisilvs Jun 3, 2021
6305084
removed sample_shape
gisilvs Jun 3, 2021
2f27b95
changed if statement and array slicing for value_out
gisilvs Jun 3, 2021
326a766
changed docstrings for target_dist
gisilvs Jun 14, 2021
1f295da
expanded cf to cascading flows and changed bijector
gisilvs Jun 14, 2021
487f7dd
removed testCFDistributionSubstitution
gisilvs Jun 14, 2021
e361a46
removed convex from name
gisilvs Jun 14, 2021
70ffe7b
fixed comment
gisilvs Jun 14, 2021
398d459
adjusted names
gisilvs Jun 15, 2021
2c44c48
fixed dimensions of prior
gisilvs Jun 15, 2021
5fb11ec
readded batchbroadcast
gisilvs Jun 16, 2021
5a20976
small fixes
gisilvs Jun 16, 2021
14e34ee
removed try except
gisilvs Jun 16, 2021
fa69f67
added support for distributions withc constrained support and test
gisilvs Jun 16, 2021
e296c6f
fixed output reshape
gisilvs Jun 18, 2021
55155e8
removed discrete test
gisilvs Jun 21, 2021
9eb8b97
working on batch shape
gisilvs Jun 21, 2021
0c0fb39
small bug fixed
gisilvs Jun 21, 2021
05c652e
Merge remote-tracking branch 'upstream/master' into cascading_flow_su…
gisilvs Jul 21, 2021
8d8777d
changed shapes to static and added auxiliary variables without global…
gisilvs Jul 23, 2021
d514f72
fixed constraining_bijector
gisilvs Aug 9, 2021
5eabcf8
working cf and cf with local aux vars
gisilvs Aug 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tensorflow_probability/python/experimental/vi/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ py_library(
srcs_version = "PY3",
deps = [
":automatic_structured_vi",
":cascading_flows",
":surrogate_posteriors",
"//tensorflow_probability/python/experimental/vi/util",
"//tensorflow_probability/python/internal:all_util",
Expand Down Expand Up @@ -67,6 +68,37 @@ py_library(
],
)

py_library(
name = "cascading_flows",
srcs = ["cascading_flows.py.py"],
srcs_version = "PY3",
deps = [
# tensorflow dep,
"//tensorflow_probability/python/bijectors:build_highway_flow_layer",
"//tensorflow_probability/python/bijectors:chain",
"//tensorflow_probability/python/bijectors:reshape",
"//tensorflow_probability/python/bijectors:scale",
"//tensorflow_probability/python/bijectors:shift",
"//tensorflow_probability/python/bijectors:split",
"//tensorflow_probability/python/distributions:batch_broadcast",
"//tensorflow_probability/python/distributions:beta",
"//tensorflow_probability/python/distributions:blockwise",
"//tensorflow_probability/python/distributions:chi2",
"//tensorflow_probability/python/distributions:exponential",
"//tensorflow_probability/python/distributions:gamma",
"//tensorflow_probability/python/distributions:half_normal",
"//tensorflow_probability/python/distributions:joint_distribution_auto_batched",
"//tensorflow_probability/python/distributions:joint_distribution_coroutine",
"//tensorflow_probability/python/distributions:normal",
"//tensorflow_probability/python/distributions:sample",
"//tensorflow_probability/python/distributions:transformed_distribution",
"//tensorflow_probability/python/distributions:truncated_normal",
"//tensorflow_probability/python/distributions:uniform",
"//tensorflow_probability/python/experimental/bijectors:build_trainable_highway_flow",
"//tensorflow_probability/python/internal:samplers",
],
)

py_library(
name = "surrogate_posteriors",
srcs = ["surrogate_posteriors.py"],
Expand Down Expand Up @@ -111,6 +143,22 @@ py_test(
],
)

py_test(
name = "cascading_flows_test",
size = "large",
srcs = ["cascading_flows_test.py"],
python_version = "PY3",
shard_count = 4,
srcs_version = "PY3",
deps = [
# absl/testing:parameterized dep,
# numpy dep,
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/internal:test_util",
],
)

py_test(
name = "surrogate_posteriors_test",
size = "large",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_probability/python/experimental/vi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tensorflow_probability.python.experimental.vi import util
from tensorflow_probability.python.experimental.vi.automatic_structured_vi import build_asvi_surrogate_posterior
from tensorflow_probability.python.experimental.vi.automatic_structured_vi import register_asvi_substitution_rule
from tensorflow_probability.python.experimental.vi.cascading_flows import build_cf_surrogate_posterior
from tensorflow_probability.python.experimental.vi.surrogate_posteriors import build_affine_surrogate_posterior
from tensorflow_probability.python.experimental.vi.surrogate_posteriors import build_affine_surrogate_posterior_from_base_distribution
from tensorflow_probability.python.experimental.vi.surrogate_posteriors import build_factored_surrogate_posterior
Expand All @@ -29,6 +30,7 @@
'build_affine_surrogate_posterior',
'build_affine_surrogate_posterior_from_base_distribution',
'build_asvi_surrogate_posterior',
'build_cf_surrogate_posterior',
'build_factored_surrogate_posterior',
'build_split_flow_surrogate_posterior',
'build_trainable_location_scale_distribution',
Expand Down
Loading