From a5b0d95fa622f1e13a2ea449b98325915ea129a3 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Tue, 18 May 2021 13:18:51 -0400 Subject: [PATCH 1/6] Update to work with Gen #417 --- mle/run_all.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mle/run_all.jl b/mle/run_all.jl index 619f4b5..2b08845 100644 --- a/mle/run_all.jl +++ b/mle/run_all.jl @@ -8,15 +8,15 @@ end load_generated_functions() -init_param!(foo, :mu, -1) +init_parameter!((foo, :mu), -1) trace, = generate(foo, (), choicemap((:y, 3))) step_size = 0.01 for iter=1:1000 accumulate_param_gradients!(trace, 0.) - grad_val = get_param_grad(foo, :mu) - set_param!(foo, :mu, get_param(foo, :mu) + step_size * grad_val) - zero_param_grad!(foo, :mu) + gradient = get_gradient((foo, :mu)) + value = get_parameter_value((foo, :mu)) + init_parameter!((foo, :mu), value + step_size * gradient) end @assert abs(get_param(foo, :mu) - 3) < 1e-2 From d5b0db2ce1fae4e190bb8adff2ebb342f4cc29da Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 15:24:42 -0400 Subject: [PATCH 2/6] reweighted wake sleep MNIST example --- .travis.yml | 3 +- rws_mnist/Manifest.toml | 427 ++++++++++++++++++++++++++++++++++++++ rws_mnist/Project.toml | 3 + rws_mnist/mnist.jl | 25 +++ rws_mnist/model.jl | 153 ++++++++++++++ rws_mnist/run.jl | 120 +++++++++++ rws_mnist/run.sh | 11 + rws_mnist/rws.jl | 71 +++++++ rws_mnist/sbn.jl | 180 ++++++++++++++++ rws_mnist/sgd_momentum.jl | 48 +++++ 10 files changed, 1040 insertions(+), 1 deletion(-) create mode 100644 rws_mnist/Manifest.toml create mode 100644 rws_mnist/Project.toml create mode 100644 rws_mnist/mnist.jl create mode 100644 rws_mnist/model.jl create mode 100644 rws_mnist/run.jl create mode 100755 rws_mnist/run.sh create mode 100644 rws_mnist/rws.jl create mode 100644 rws_mnist/sbn.jl create mode 100644 rws_mnist/sgd_momentum.jl diff --git a/.travis.yml b/.travis.yml index 4677edd..c705b5c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,10
+3,11 @@ language: julia julia: - - 1.5 + - 1.6 # run these in parallel env: + - GEN_EXAMPLE=rws_mnist JULIA_NUM_THREADS=2 OPENBLAS_NUM_THREADS=1 - GEN_EXAMPLE=gp_structure GKS_ENCODING="utf8" GKSwstype="100" - GEN_EXAMPLE=regression - GEN_EXAMPLE=involutive_mcmc diff --git a/rws_mnist/Manifest.toml b/rws_mnist/Manifest.toml new file mode 100644 index 0000000..ef4c042 --- /dev/null +++ b/rws_mnist/Manifest.toml @@ -0,0 +1,427 @@ +# This file is machine-generated - editing it directly is not advised + +[[ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[Blosc]] +deps = ["Blosc_jll"] +git-tree-sha1 = "84cf7d0f8fd46ca6f1b3e0305b4b4a37afe50fd6" +uuid = "a74b3585-a348-5f62-a45c-50e91977d574" +version = "0.7.0" + +[[Blosc_jll]] +deps = ["Libdl", "Lz4_jll", "Pkg", "Zlib_jll", "Zstd_jll"] +git-tree-sha1 = "aa9ef39b54a168c3df1b2911e7797e4feee50fbe" +uuid = "0b7ba130-8d10-5ba8-a3d6-c5182647fed9" +version = "1.14.3+1" + +[[ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "0.9.44" + +[[CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.30.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + +[[DataAPI]] 
+git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.6.0" + +[[DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.9" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.0.2" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[Distributions]] +deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"] +git-tree-sha1 = "a837fdf80f333415b69684ba8e8ae6ba76de6aaa" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.24.18" + +[[DocStringExtensions]] +deps = ["LibGit2", "Markdown", "Pkg", "Test"] +git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.4" + +[[Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + +[[FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "cfb694feaddf4f0381ef3cc9d4c0d8fc6b7e2ea7" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.9.0" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.11.7" + +[[ForwardDiff]] +deps = 
["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.18" + +[[FunctionWrappers]] +git-tree-sha1 = "241552bc2209f0fa068b6415b1942cc0aa486bcc" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.2" + +[[FunctionalCollections]] +deps = ["Test"] +git-tree-sha1 = "04cb9cfaa6ba5311973994fe3496ddec19b6292a" +uuid = "de31a74c-ac4f-5751-b3fd-e18cd04993ca" +version = "0.5.0" + +[[Gen]] +deps = ["Compat", "DataStructures", "Distributions", "ForwardDiff", "FunctionalCollections", "JSON", "LinearAlgebra", "MacroTools", "Parameters", "Random", "ReverseDiff", "SpecialFunctions"] +git-tree-sha1 = "b6cd202d1dc5020e5c9a300338413010e3aed5c9" +repo-rev = "20210512-marcoct-gradopts" +repo-url = "https://github.com/probcomp/Gen.jl.git" +uuid = "ea4f424c-a589-11e8-07c0-fd5c91b9da4a" +version = "0.4.3" + +[[HDF5]] +deps = ["Blosc", "Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires"] +git-tree-sha1 = "1d18a48a037b14052ca462ea9d05dee3ac607d23" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +version = "0.15.5" + +[[HDF5_jll]] +deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "OpenSSL_jll", "Pkg", "Zlib_jll"] +git-tree-sha1 = "fd83fa0bde42e01952757f01149dd968c06c4dba" +uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.12.0+1" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[JLD]] +deps = ["FileIO", "HDF5", "Printf"] +git-tree-sha1 = "1d291ba1730de859903b480e6f85a0dc40c19dcb" +uuid = "4138dd39-2aa7-5051-a626-17a0bb65d9c8" +version = "0.12.3" + +[[JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.3.0" + +[[JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = 
"81690084b6198a2e1da36fcfda16eeca9f9f24e4" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.1" + +[[LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + +[[LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[LogExpFunctions]] +deps = ["DocStringExtensions", "LinearAlgebra"] +git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.2.4" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[Lz4_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "51b1db0732bbdcfabb60e36095cc3ed9c0016932" +uuid = "5ced341a-0733-55b8-9ab6-a4889d929147" +version = "1.9.2+2" + +[[MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.6" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + +[[Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.0.0" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[NaNMath]] +git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.5" + +[[NetworkOptions]] +uuid = 
"ca575930-c2e3-43a9-ace4-1e988b2c1908" + +[[OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "71bbbc616a1d710879f5a1021bcba65ffba6ce58" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "1.1.1+6" + +[[OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "b9b8b8ed236998f91143938a760c2112dceeb2b4" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.4+0" + +[[OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "f82a0e71f222199de8e9eb9a09977bd0767d52a0" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.0" + +[[Parameters]] +deps = ["OrderedCollections", "UnPack"] +git-tree-sha1 = "2276ac65f1e236e0a6ea70baff3f62ad4c625345" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.12.2" + +[[Parsers]] +deps = ["Dates"] +git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "1.1.0" + +[[Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Preferences]] +deps = ["TOML"] +git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.2.2" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "12fbe86da16df6679be7521dfb39fbc861e1dc7b" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.4.1" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid 
= "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.1.3" + +[[ReverseDiff]] +deps = ["DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] +git-tree-sha1 = "63ee24ea0689157a1113dbdab10c6cb011d519c4" +uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +version = "1.9.0" + +[[Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.7.0" + +[[Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.3.0+0" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.0.0" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] +git-tree-sha1 = "9146da51b38e9705b9f5ccfadc3ab10a482cae36" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "1.4.0" + +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "c635017268fd51ed944ec429bcc4ad010bcea900" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.2.0" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = 
"10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsAPI]] +git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.0.0" + +[[StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.8" + +[[StatsFuns]] +deps = ["LogExpFunctions", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "30cd8c360c54081f806b1ee14d2eecbef3c04c49" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "0.9.8" + +[[SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + +[[Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[UnPack]] +git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" +uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +version = "1.0.2" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[Zstd_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "cc4bf3fdde8b7e3e9fa0351bdeedba1cf3b7f6e6" +uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" +version = "1.5.0+0" + +[[nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/rws_mnist/Project.toml b/rws_mnist/Project.toml new file mode 100644 index 0000000..0e1c6a0 --- /dev/null +++ 
b/rws_mnist/Project.toml @@ -0,0 +1,3 @@ +[deps] +Gen = "ea4f424c-a589-11e8-07c0-fd5c91b9da4a" +JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8" diff --git a/rws_mnist/mnist.jl b/rws_mnist/mnist.jl new file mode 100644 index 0000000..59839c6 --- /dev/null +++ b/rws_mnist/mnist.jl @@ -0,0 +1,25 @@ +import JLD + +function to_bool_vectors(int_mat::Matrix) + data = Vector{Vector{Bool}}() + num_data = size(int_mat)[1] + @assert size(int_mat)[2] == 28 * 28 + for i in 1:num_data + push!(data, Vector{Bool}(int_mat[i,:] .== 1)) + end + return data +end + +function get_mnist_data() + + data = JLD.jldopen("$(@__DIR__)/mnist_salakhutdinov.jld") + + train_x = to_bool_vectors(read(data, "train_x")) + test_x = to_bool_vectors(read(data, "test_x")) + valid_x = to_bool_vectors(read(data, "valid_x")) + + train_data = train_x + test_data = vcat(test_x, valid_x) + + return (train_data, test_data) +end diff --git a/rws_mnist/model.jl b/rws_mnist/model.jl new file mode 100644 index 0000000..b8d1df0 --- /dev/null +++ b/rws_mnist/model.jl @@ -0,0 +1,153 @@ +using Gen: @gen, @param, @load_generated_functions + +@gen (static) function p() + + # prior on the third (top) hidden layer + @param h3_b::Vector{Float32} + + # second hidden layer + @param h2_W::Matrix{Float32} + @param h2_b::Vector{Float32} + + # first hidden layer + @param h1_W::Matrix{Float32} + @param h1_b::Vector{Float32} + + # visible layer + @param x_W::Matrix{Float32} + @param x_b::Vector{Float32} + + # TODO issue -- Gen wraps h3_n in a ReverseDiff, which causes a failure + h3_n = 10 + #h3_n = length(h3_b) + + #h2_n = length(h2_b) + #h1_n = length(h1_b) + #h1_n = 200 + #x_n = length(x_b) + + # sample third hidden layer (10 units) + h3 ~ sigmoid_belief_network([false], h3_b, zeros(Float32, h3_n, 1)) + + # sample second hidden layer (200 units) + #@assert size(h2_W) == (h2_n, h3_n) # TODO @assert is not supported in static modeling language + h2 ~ sigmoid_belief_network(h3, h2_b, h2_W) + + # sample first hidden layer (200 units) +
#@assert size(h1_W) == (h1_n, h2_n) + h1 ~ sigmoid_belief_network(h2, h1_b, h1_W) + #h1 ~ sigmoid_belief_network([false], h1_b, zeros(Float32, h1_n, 1)) + + # sample visible layer (28x28=784 units) + #@assert size(x_W) == (x_n, h1_n) + x ~ sigmoid_belief_network(h1, x_b, x_W) + + return nothing +end + +@gen (static) function q(x::Vector{Bool}) + #@assert sum(x) > 0 + + # first hidden layer + @param h1_W::Matrix{Float32} + @param h1_b::Vector{Float32} + + # second hidden layer + @param h2_W::Matrix{Float32} + @param h2_b::Vector{Float32} + + # third hidden layer + @param h3_W::Matrix{Float32} + @param h3_b::Vector{Float32} + + #x_n = length(x) + #h1_n = length(h1_b) + #h2_n = length(h2_b) + #h3_n = length(h3_b) + + # sample first hidden layer (200 units) + #@assert size(h1_W) == (h1_n, x_n) + h1 ~ sigmoid_belief_network(x, h1_b, h1_W) + + # sample second hidden layer (200 units) + #@assert size(h2_W) == (h2_n, h1_n) + h2 ~ sigmoid_belief_network(h1, h2_b, h2_W) + + # sample third hidden layer (10 units) + #@assert size(h3_W) == (h3_n, h2_n) + h3 ~ sigmoid_belief_network(h2, h3_b, h3_W) + + return nothing +end + +@load_generated_functions() + +import Random +using Gen: apply_update!, init_parameter!, get_parameter_value + +# see https://github.com/jbornschein/reweighted-ws/blob/e96414719d09ab4941dc77bab4cf4847acc6a8e7/learning/models/sbn.py#L94 +# and https://github.com/jbornschein/reweighted-ws/blob/e96414719d09ab4941dc77bab4cf4847acc6a8e7/learning/model.py#L22 +function make_W(n_output, n_input) + scale = Float32(sqrt(6f0) / sqrt(n_input + n_output)) + return scale * ((2*rand(Float32, n_output, n_input)) .- 1.0f0) / n_input +end + +# see https://github.com/jbornschein/reweighted-ws/blob/e96414719d09ab4941dc77bab4cf4847acc6a8e7/learning/models/sbn.py#L94 +make_b(n) = -ones(Float32, n) + +function initialize_p_params!(; x_n=28*28, h1_n=200, h2_n=200, h3_n=10) + init_parameter!((p, :h3_b), make_b(h3_n)) + + init_parameter!((p, :h2_W), make_W(h2_n, h3_n)) + 
init_parameter!((p, :h2_b), make_b(h2_n)) + + init_parameter!((p, :h1_W), make_W(h1_n, h2_n)) + init_parameter!((p, :h1_b), make_b(h1_n)) + + init_parameter!((p, :x_W), make_W(x_n, h1_n)) + init_parameter!((p, :x_b), make_b(x_n)) +end + +function initialize_q_params!(; x_n=28*28, h1_n=200, h2_n=200, h3_n=10) + init_parameter!((q, :h1_W), make_W(h1_n, x_n)) + init_parameter!((q, :h1_b), make_b(h1_n)) + + init_parameter!((q, :h2_W), make_W(h2_n, h1_n)) + init_parameter!((q, :h2_b), make_b(h2_n)) + + init_parameter!((q, :h3_W), make_W(h3_n, h2_n)) + init_parameter!((q, :h3_b), make_b(h3_n)) +end + +import Serialization + +function save_params(p, q, metadata, filename) + println("saving params to $filename") + p_params_dict = Dict( + name => Gen.get_parameter_value((p, name)) + for (_, name) in Gen.get_parameters(p, Gen.default_parameter_context)[Gen.default_julia_parameter_store]) + q_params_dict = Dict( + name => Gen.get_parameter_value((q, name)) + for (_, name) in Gen.get_parameters(q, Gen.default_parameter_context)[Gen.default_julia_parameter_store]) + data = Dict( + "metadata" => metadata, + "q_params" => q_params_dict, + "p_params" => p_params_dict) + Serialization.serialize(filename, data) + return nothing +end + +function load_params!(p, q, filename) + println("loading params from $filename") + data = Serialization.deserialize(filename) + println("got metadata: $(data["metadata"])") + for name in keys(data["p_params"]) + init_parameter!((p, name), data["p_params"][name]) + println("$name: $(size(Gen.get_parameter_value((p, name))))") + end + for name in keys(data["q_params"]) + init_parameter!((q, name), data["q_params"][name]) + println("$name: $(size(Gen.get_parameter_value((q, name))))") + end + return data["metadata"] +end diff --git a/rws_mnist/run.jl b/rws_mnist/run.jl new file mode 100644 index 0000000..ec62241 --- /dev/null +++ b/rws_mnist/run.jl @@ -0,0 +1,120 @@ +include("$(@__DIR__)/sbn.jl") +include("$(@__DIR__)/model.jl")
+include("$(@__DIR__)/rws.jl") +include("$(@__DIR__)/sgd_momentum.jl") +include("$(@__DIR__)/mnist.jl") + +import Random + +function get_minibatch(data::Vector{Vector{Bool}}, minibatch_size::Int) + n = length(data) + idx = Random.randperm(n)[1:minibatch_size] + return data[idx] +end + +import Gen + +function datum_to_choicemap(datum::Vector{Bool}) + choices = Gen.choicemap() + choices[:x => :outputs] = datum + return choices +end + +function estimate_log_marginal_likelihood(data::Vector{Vector{Bool}}, est_particles) + n = length(data) + log_ml_estimates = Vector{Float64}(undef, n) + Threads.@threads for i in 1:n + datum = data[i] + (_, log_ml_estimates[i]) = Gen.importance_resampling( + p, (), datum_to_choicemap(datum), q, + (datum,), est_particles) + end + return sum(log_ml_estimates) / n +end + +function do_checkpoint(train_data, test_data, est_size, est_particles, save_to, iter) + train_est_batch = get_minibatch(train_data, est_size) + train_lml_est = estimate_log_marginal_likelihood(train_est_batch, est_particles) + println("iter $iter, train set LML estimate: $train_lml_est") + if !isnothing(save_to) + metadata = Dict() + save_params(p, q, metadata, save_to) + end +end + +function train( + train_data, test_data, iters::Int; + est_size=256, est_particles=10000, + save_to=nothing, + checkpoint_period=10, + num_particles=5, + minibatch_size=24, + momentum_beta=0.95f0, + learning_rate=0.001f0) + + # training config + conf = SGDWithMomentumConf(momentum_beta, learning_rate) + p_update = Gen.init_optimizer(conf, p) + q_update = Gen.init_optimizer(conf, q) + + # do training with a fresh random minibatch of the training set per iteration + for iter in 1:iters + + # checkpoint + if checkpoint_period != Inf && (((iter-1) % checkpoint_period == 0) || (iter == iters)) + println("checkpointing..") + do_checkpoint(train_data, test_data, est_size, est_particles, save_to, iter) + end + + # select a minibatch + minibatch = get_minibatch(train_data, minibatch_size) + + # NOTE be sure to set environment
variable OPENBLAS_NUM_THREADS=1 + # and set number of threads with e.g. JULIA_NUM_THREADS=8 + Threads.@threads for datum in minibatch + + # compute stochastic gradient estimates + reweighted_wake_sleep_gradients!( + p, (), q, (), datum; + scale_gradient=1.0f0/minibatch_size, + data_to_choicemap=datum_to_choicemap, + num_particles=num_particles, + do_q_wake_phase=true, + do_q_sleep_phase=true, + get_data=(p_trace) -> p_trace[:x => :outputs]::Vector{Bool}, + get_latent_choices=(p_trace) -> get_choices(p_trace)) + + end + + # apply updates and reset the gradient estimates + apply_update!(p_update) + apply_update!(q_update) + end +end + +function do_training() + + Random.seed!(9) + + (train_data, test_data) = get_mnist_data() + + initialize_p_params!() + initialize_q_params!() + + max_epochs = 1 # one epoch is about 2000 iterations of SGD + iters_per_epoch = Int(round(50000 / 24)) + checkpoint_period = Int(round(10 * iters_per_epoch)) # every 10 epochs + iters = Int(round((iters_per_epoch * max_epochs))) + println("iters_per_epoch: $iters_per_epoch, running for $iters iters...") + train(train_data, test_data, iters; + save_to="params.jls", + num_particles=5, + checkpoint_period=1000,#checkpoint_period, + est_particles=100, + momentum_beta=0.95f0, + learning_rate=1f-3, + minibatch_size=32) + +end + +do_training() diff --git a/rws_mnist/run.sh b/rws_mnist/run.sh new file mode 100755 index 0000000..576ff11 --- /dev/null +++ b/rws_mnist/run.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -u +set -e + + +cd ${GEN_EXAMPLE} +if [ ! -f mnist_salakhutdinov.jld ]; then + wget https://gen-examples.s3.us-east-2.amazonaws.com/mnist_salakhutdinov.jld +fi +cd .. 
+julia -e 'using Pkg; Pkg.update(); dir=ENV["GEN_EXAMPLE"]; include("$dir/run.jl")' diff --git a/rws_mnist/rws.jl b/rws_mnist/rws.jl new file mode 100644 index 0000000..583b0dc --- /dev/null +++ b/rws_mnist/rws.jl @@ -0,0 +1,71 @@ +import Gen +using Gen: choicemap, ChoiceMap, simulate, generate, get_score, get_choices +using Gen: accumulate_param_gradients! + +function reweighted_wake_sleep_gradients!( + p, p_args, q, q_args, data; + scale_gradient=1.0, + data_to_choicemap=(data::ChoiceMap) -> data, + num_particles=5, + do_q_wake_phase=true, + do_q_sleep_phase=true, + get_data=nothing, + get_latent_choices=nothing) + + if do_q_sleep_phase && (isnothing(get_data) || isnothing(get_latent_choices)) + error("do_q_sleep_phase was set but either get_data or get_latent_choices was not provided") + end + + data_choicemap = data_to_choicemap(data) + + # run q multiple times + q_traces = Vector{Any}(undef, num_particles) + p_traces = Vector{Any}(undef, num_particles) + log_weights = Vector{Float64}(undef, num_particles) + Threads.@threads for i in 1:num_particles + q_traces[i] = simulate(q, (q_args..., data)) + (p_traces[i], should_be_score) = generate( + p, p_args, merge(data_choicemap, get_choices(q_traces[i]))) + @assert isapprox(should_be_score, get_score(p_traces[i])) + log_weights[i] = get_score(p_traces[i]) - get_score(q_traces[i]) + end + @assert !any(isnan.(log_weights)) + + # *** synchronization barrier *** + + # normalize weights + (_, log_normalized_weights) = Gen.normalize_weights(log_weights) + normalized_weights = exp.(log_normalized_weights) + @assert isapprox(sum(normalized_weights), 1.0) + + @sync begin + + # wake-phase update of p + for i in 1:num_particles + Threads.@spawn accumulate_param_gradients!( + p_traces[i], nothing, Float32(normalized_weights[i] * scale_gradient)) + end + + # wake-phase update of q + if do_q_wake_phase + for i in 1:num_particles + Threads.@spawn accumulate_param_gradients!( + q_traces[i], nothing, Float32(normalized_weights[i] * 
scale_gradient)) + end + end + + # sleep-phase update of q + Threads.@spawn if do_q_sleep_phase + p_trace = simulate(p, p_args) + (q_trace, should_be_score) = generate( + q, (q_args..., get_data(p_trace)), get_latent_choices(p_trace)) + @assert isapprox(should_be_score, get_score(q_trace)) + accumulate_param_gradients!(q_trace, nothing, Float32(scale_gradient)) + end + + # *** synchronization barrier *** + + end # @sync +end + +# TODO we need to add an optional check to generate that all constraints were visited diff --git a/rws_mnist/sbn.jl b/rws_mnist/sbn.jl new file mode 100644 index 0000000..4f66041 --- /dev/null +++ b/rws_mnist/sbn.jl @@ -0,0 +1,180 @@ +# sigmoid belief networks + +import Gen + +sigmoid(x) = 1.0f0 ./ (1.0f0 .+ exp.(-x)) + +mutable struct SBNTrace <: Gen.Trace + parents::Vector{Bool} + W::Matrix{Float32} + b::Vector{Float32} + outputs::Vector{Bool} + probs::Vector{Float32} + log_prob::Float64 +end + +Gen.get_args(trace::SBNTrace) = (trace.parents, trace.b, trace.W) +Gen.get_score(trace::SBNTrace) = trace.log_prob +Gen.get_retval(trace::SBNTrace) = trace.outputs +Gen.get_choices(trace::SBNTrace) = Gen.choicemap((:outputs, trace.outputs)) +Gen.project(trace::SBNTrace, ::Gen.EmptySelection) = 0.0 + +struct SigmoidBeliefNetwork <: Gen.GenerativeFunction{Vector{Bool},SBNTrace} end + +Gen.get_gen_fn(trace::SBNTrace) = SigmoidBeliefNetwork() + +""" + sigmoid_belief_network(parents::Vector{Bool}, b::Vector{Float32}, W::Matrix{Float32}) + +Samples an output vector of binary units given an input vector of binary units +""" +const sigmoid_belief_network = SigmoidBeliefNetwork() + +#@fastmath +function sbn_compute_probs(W::Matrix{Float32}, b::Vector{Float32}, parents) + n = length(b) + m = length(parents) + probs::Vector{Float32} = copy(b) + @inbounds @simd for i in 1:n # TODO compare against switching iteration order..
so that W is accessed column-major + @simd for j in 1:m + if parents[j] + probs[i] += W[i,j] + end + end + probs[i] = sigmoid(probs[i]) + end + return probs +end + +#@fastmath +@inbounds function sbn_logpdf_sum(probs, outputs) + total = 0.0 + n = length(probs) + for i in 1:n + if outputs[i] + total += log(probs[i]) + else + total += log(1.0f0 - probs[i]) + end + end + return total +end + +#@fastmath +@inbounds function sbn_W_grad!(W_grad, outputs, probs::Vector{Float32}, parents) + num_outputs = length(outputs) + num_parents = length(parents) + @assert size(W_grad) == (num_outputs, num_parents) + @simd for col in 1:num_parents + if parents[col] + @simd for row in 1:num_outputs + prob = probs[row] + W_grad[row, col] = (outputs[row] ? 1.0f0 - prob : -prob) + end + else + @simd for row in 1:num_outputs + W_grad[row, col] = 0.0f0 + end + end + end + return W_grad +end + + +function sbn_sample(probs::Vector{Float32}) + n = length(probs) + outputs = Vector{Bool}(undef, n) + @inbounds @simd for i in 1:n + outputs[i] = rand(Float32) < probs[i] + end + return outputs +end + +Gen.accepts_output_grad(::SigmoidBeliefNetwork) = false +Gen.has_argument_grads(::SigmoidBeliefNetwork) = (false, true, true) + +function Gen.simulate(::SigmoidBeliefNetwork, args::Tuple) + (parents, b, W) = args + probs = sbn_compute_probs(W, b, parents) + outputs = sbn_sample(probs) + log_prob = sbn_logpdf_sum(probs, outputs) + trace = SBNTrace(parents, W, b, outputs, probs, log_prob) + return trace +end + +function Gen.generate(::SigmoidBeliefNetwork, args::Tuple, choices::Gen.ChoiceMap) + (parents, b, W) = args + probs = sbn_compute_probs(W, b, parents) + if Gen.has_value(choices, :outputs) + outputs = choices[:outputs] + log_prob = sbn_logpdf_sum(probs, outputs) + log_weight = log_prob + else + outputs = sbn_sample(probs) + log_prob = sbn_logpdf_sum(probs, outputs) + log_weight = 0.0 + end + trace = SBNTrace(parents, W, b, outputs, probs, log_prob) + return (trace, log_weight) +end + +function 
Gen.accumulate_param_gradients!(trace::SBNTrace, retval_grad::Nothing, scale_factor) + parents = trace.parents + outputs = trace.outputs + probs = trace.probs + b_grad = sum(outputs .* (1.0f0 .- probs), dims=2) - sum( (.!outputs) .* probs, dims=2) + W = trace.W + W_grad = Matrix{Float32}(undef, size(W)[1], size(W)[2]) + W_grad = sbn_W_grad!(W_grad, outputs, probs, parents) + return (nothing, b_grad, W_grad) +end + +########################################### +# test logpdf_grad via finite differences # +########################################### + +function finite_diff_arr(f::Function, args::Tuple, i::Int, idx, dx::Real) + pos_args = Any[deepcopy(args)...] + pos_args[i][idx] += dx + neg_args = Any[deepcopy(args)...] + neg_args[i][idx] -= dx + return (f(pos_args...) - f(neg_args...)) / (2.0f0 * dx) +end + +function test_logpdf_grad() + + n = 5 + m = 3 + + W = randn(Float32, n, m) + b = randn(Float32, n) + + parents = Vector{Bool}(rand(m) .< 0.5) + outputs = Vector{Bool}(rand(n) .< 0.5) + + (trace, log_weight) = Gen.generate(sigmoid_belief_network, (parents, b, W), Gen.choicemap((:outputs, outputs))) + @assert isapprox(log_weight, Gen.get_score(trace)) + (_, b_grad, W_grad) = Gen.accumulate_param_gradients!(trace, nothing, NaN) + + @assert size(b_grad) == (n,) + @assert size(W_grad) == (n, m) + + f = (b, W) -> Gen.generate(sigmoid_belief_network, (parents, b, W), Gen.choicemap((:outputs, outputs)))[2] + dx = 1f-4 + + # check gradients with respect to b + b_grad_expected = Vector{Float32}(undef, n) + for i in 1:n + b_grad_expected[i] = finite_diff_arr(f, (b, W), 1, i, dx) + end + @assert isapprox(b_grad_expected, b_grad, rtol=2e-2) + + # check gradients with respect to W + W_grad_flat_expected = Vector{Float32}(undef, n*m) + for i in 1:(n*m) + W_grad_flat_expected[i] = finite_diff_arr(f, (b, W), 2, i, dx) + end + @assert isapprox(W_grad_flat_expected, W_grad[:], rtol=2e-2) +end + +test_logpdf_grad() diff --git a/rws_mnist/sgd_momentum.jl b/rws_mnist/sgd_momentum.jl 
new file mode 100644
index 0000000..e085f89
--- /dev/null
+++ b/rws_mnist/sgd_momentum.jl
@@ -0,0 +1,48 @@
+# SGD with momentum
+# grad = beta * prev_grad + (1-beta) * new_grad
+# params = params + alpha * grad
+
+import Gen
+
+struct SGDWithMomentumConf{T}
+    beta::T # momentum coefficient (weight on the previous update)
+    learning_rate::T # step size (alpha in the header comment)
+end
+
+struct SGDWithMomentumJulia
+    conf::SGDWithMomentumConf
+    store::Gen.JuliaParameterStore
+    parameter_ids::Vector{Tuple{Gen.GenerativeFunction,Symbol}}
+    prev_updates::Vector{Any} # prev_updates[i] is the last update applied to parameter_ids[i]
+end
+
+function Gen.init_optimizer(
+        conf::SGDWithMomentumConf,
+        parameter_ids::Vector,
+        store::Gen.JuliaParameterStore=Gen.julia_default_parameter_store)
+
+    # start every parameter's previous update (the momentum term) at zero
+    prev_updates = Vector{Any}(undef, length(parameter_ids))
+    for i in 1:length(parameter_ids)
+        (gen_fn, name) = parameter_ids[i]::Tuple{Gen.GenerativeFunction,Symbol}
+        prev_updates[i] = zero(Gen.get_parameter_value((gen_fn, name), store))
+    end
+
+    return SGDWithMomentumJulia(conf, store, parameter_ids, prev_updates)
+end
+
+function Gen.apply_update!(state::SGDWithMomentumJulia)
+    beta = state.conf.beta
+    learning_rate = state.conf.learning_rate
+    for i in 1:length(state.prev_updates)
+        id = state.parameter_ids[i]
+        value = Gen.get_parameter_value(id, state.store) # qualified: this file only does `import Gen`
+        grad = Gen.get_gradient(id, state.store)
+        update = beta * state.prev_updates[i] + (1f0 - beta) * learning_rate * grad # learning rate folded into the stored update
+        new_value = Gen.in_place_add!(value, update)
+        Gen.set_parameter_value!(id, new_value, state.store)
+        Gen.reset_gradient!(id, state.store) # clear the gradient accumulator before the next step
+        state.prev_updates[i] = update
+    end
+    return nothing
+end
From e8382329c46d9cb17129776087641b7380a790b6 Mon Sep 17 00:00:00 2001
From: Marco Cusumano-Towner
Date: Wed, 19 May 2021 16:25:43 -0400
Subject: [PATCH 3/6] Use gradients branch of Gen for MLE example

---
 mle/Manifest.toml | 194 +++++++++++++++++++++++++++++++---------------
 mle/run_all.jl | 2 +-
 2 files changed, 133 insertions(+), 63 deletions(-)

diff --git a/mle/Manifest.toml b/mle/Manifest.toml
index 131e621..2d5adac
100644 --- a/mle/Manifest.toml +++ b/mle/Manifest.toml @@ -1,19 +1,19 @@ # This file is machine-generated - editing it directly is not advised +[[ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + [[Artifacts]] -deps = ["Pkg"] -git-tree-sha1 = "c30985d8821e0cd73870b17b0ed0ce6dc44cb744" uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" -version = "1.3.0" [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "de4f08843c332d355852721adb1592bce7924da3" +git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.29" +version = "0.9.44" [[CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -23,20 +23,18 @@ version = "0.3.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b" +git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.25.0" +version = "3.30.0" [[CompilerSupportLibraries_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "8e695f735fca77e9708e795eda62afdb869cbb70" +deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.3.4+0" [[DataAPI]] -git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8" +git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.5.1" +version = "1.6.0" [[DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -70,26 +68,36 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[Distributions]] deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", 
"SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"] -git-tree-sha1 = "f0e06a5b5ccda38e2fb8f59d91316e657b67047d" +git-tree-sha1 = "a837fdf80f333415b69684ba8e8ae6ba76de6aaa" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.24.12" +version = "0.24.18" + +[[DocStringExtensions]] +deps = ["LibGit2", "Markdown", "Pkg", "Test"] +git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.4" + +[[Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "e384d3cff80ac79c7a541a817192841836e46331" +git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.11.2" +version = "0.11.7" [[ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "d48a40c0f54f29a5c8748cfb3225719accc72b77" +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.16" +version = "0.10.18" [[FunctionWrappers]] -git-tree-sha1 = "e4813d187be8c7b993cb7f85cbf2b7bfbaadc694" +git-tree-sha1 = "241552bc2209f0fa068b6415b1942cc0aa486bcc" uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" -version = "1.1.1" +version = "1.1.2" [[FunctionalCollections]] deps = ["Test"] @@ -98,19 +106,22 @@ uuid = "de31a74c-ac4f-5751-b3fd-e18cd04993ca" version = "0.5.0" [[Gen]] -deps = ["DataStructures", "Distributions", "ForwardDiff", "FunctionalCollections", "JSON", "LinearAlgebra", "MacroTools", "Parameters", "Random", "ReverseDiff", "SpecialFunctions"] -git-tree-sha1 = "00b42c1484d658ed9fd74ffe93ddc5259d0d010e" +deps = ["Compat", "DataStructures", 
"Distributions", "ForwardDiff", "FunctionalCollections", "JSON", "LinearAlgebra", "MacroTools", "Parameters", "Random", "ReverseDiff", "SpecialFunctions"] +git-tree-sha1 = "b6cd202d1dc5020e5c9a300338413010e3aed5c9" +repo-rev = "20210512-marcoct-gradopts" +repo-url = "https://github.com/probcomp/Gen.jl.git" uuid = "ea4f424c-a589-11e8-07c0-fd5c91b9da4a" -version = "0.4.1" +version = "0.4.3" [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[JLLWrappers]] -git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0" +deps = ["Preferences"] +git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.2.0" +version = "1.3.0" [[JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -118,10 +129,22 @@ git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.1" +[[LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + [[LibGit2]] -deps = ["Printf"] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -129,6 +152,12 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" deps = ["Libdl"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[LogExpFunctions]] +deps = ["DocStringExtensions", "LinearAlgebra"] +git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.2.4" + [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -142,36 +171,46 @@ version = "0.5.6" deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +[[MbedTLS_jll]] +deps 
= ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + [[Missings]] deps = ["DataAPI"] -git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c" +git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.5" +version = "1.0.0" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + [[NaNMath]] git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" version = "0.3.5" +[[NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" + [[OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" +git-tree-sha1 = "b9b8b8ed236998f91143938a760c2112dceeb2b4" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.3+4" +version = "0.5.4+0" [[OrderedCollections]] -git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23" +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.3.3" +version = "1.4.1" [[PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] -git-tree-sha1 = "95a4038d1011dfdbde7cecd2ad0ac411e53ab1bc" +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "f82a0e71f222199de8e9eb9a09977bd0767d52a0" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.10.1" +version = "0.11.0" [[Parameters]] deps = ["OrderedCollections", "UnPack"] @@ -181,14 +220,20 @@ version = "0.12.2" [[Parsers]] deps = ["Dates"] -git-tree-sha1 = "50c9a9ed8c714945e01cd53a21007ed3865ed714" +git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.0.15" +version = "1.1.0" [[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +deps = ["Artifacts", "Dates", 
"Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +[[Preferences]] +deps = ["TOML"] +git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.2.2" + [[Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -200,7 +245,7 @@ uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" version = "2.4.1" [[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[Random]] @@ -209,21 +254,21 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [[ReverseDiff]] deps = ["DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] -git-tree-sha1 = "ca062d55a167a81909772a325932e72c389f9724" +git-tree-sha1 = "63ee24ea0689157a1113dbdab10c6cb011d519c4" uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -version = "1.5.0" +version = "1.9.0" [[Rmath]] deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "86c5647b565873641538d8f812c04e4c9dbeb370" +git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.6.1" +version = "0.7.0" [[Rmath_jll]] -deps = ["Libdl", "Pkg"] -git-tree-sha1 = "d76185aa1f421306dec73c057aa384bad74188f0" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.2.2+1" +version = "0.3.0+0" [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -239,49 +284,62 @@ uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[SortingAlgorithms]] -deps = ["DataStructures", "Random", "Test"] -git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" 
+deps = ["DataStructures"] +git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "0.3.1" +version = "1.0.0" [[SparseArrays]] deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] -deps = ["ChainRulesCore", "OpenSpecFun_jll"] -git-tree-sha1 = "75394dbe2bd346beeed750fb02baa6445487b862" +deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"] +git-tree-sha1 = "9146da51b38e9705b9f5ccfadc3ab10a482cae36" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "1.2.1" +version = "1.4.0" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "9da72ed50e94dbff92036da395275ed114e04d49" +git-tree-sha1 = "c635017268fd51ed944ec429bcc4ad010bcea900" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.0.1" +version = "1.2.0" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[[StatsAPI]] +git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.0.0" + [[StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "7bab7d4eb46b225b35179632852b595a3162cb61" +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.2" +version = "0.33.8" [[StatsFuns]] -deps = ["Rmath", "SpecialFunctions"] -git-tree-sha1 = "3b9f665c70712af3264b61c27a7e1d62055dafd1" +deps = ["LogExpFunctions", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "30cd8c360c54081f806b1ee14d2eecbef3c04c49" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "0.9.6" +version = "0.9.8" [[SuiteSparse]] deps = ["Libdl", "LinearAlgebra", 
"Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" +[[TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + [[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[UUIDs]] @@ -295,3 +353,15 @@ version = "1.0.2" [[Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/mle/run_all.jl b/mle/run_all.jl index 2b08845..552ee09 100644 --- a/mle/run_all.jl +++ b/mle/run_all.jl @@ -16,7 +16,7 @@ for iter=1:1000 accumulate_param_gradients!(trace, 0.) gradient = get_gradient((foo, :mu)) value = get_parameter_value((foo, :mu)) - init_parameter!((foo, :mu), vlaue + step_size * gradient) + init_parameter!((foo, :mu), value + step_size * gradient) end @assert abs(get_param(foo, :mu) - 3) < 1e-2 From 0a00ad6f86afa058dbd059bd50025053b567a12c Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 16:28:39 -0400 Subject: [PATCH 4/6] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index cdfe79e..60a2fd0 100644 --- a/README.md +++ b/README.md @@ -26,3 +26,5 @@ NOTE: These examples take substantially longer to run than Gen's unit tests. 
- Minimal example of maximum likelihood estimation: `mle/` - Reversible jump MCMC in change point model: `coal/` + +- Learning a deep generative model of binarized MNIST digits using [reweighted wake sleep](https://arxiv.org/abs/1406.2751): `rws_mnist/` From 3941f4498dac42d3ff8546b79576040b8044e18d Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 16:53:06 -0400 Subject: [PATCH 5/6] add note about issue --- rws_mnist/model.jl | 32 ++++++++------------------------ 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/rws_mnist/model.jl b/rws_mnist/model.jl index b8d1df0..b6fb667 100644 --- a/rws_mnist/model.jl +++ b/rws_mnist/model.jl @@ -17,36 +17,26 @@ using Gen: @gen, @param, @load_generated_functions @param x_W::Matrix{Float32} @param x_b::Vector{Float32} - # TODO issue -- Gen wraps h3_n in a ReverseDiff, which causes a failure - h3_n = 10 - #h3_n = length(h3_b) - - #h2_n = length(h2_b) - #h1_n = length(h1_b) - #h1_n = 200 - #x_n = length(x_b) + # NOTE: we should instead use h3_n = length(h3_b), but we can't because of + # https://github.com/probcomp/Gen.jl/issues/418 + h3_n = 10 # sample third hidden layer (10 units) h3 ~ sigmoid_belief_network([false], h3_b, zeros(Float32, h3_n, 1)) # sample second hidden layer (200 units) - #@assert size(h2_W) == (h2_n, h3_n) # TODO @assert is not supported in static modeling language h2 ~ sigmoid_belief_network(h3, h2_b, h2_W) # sample first hidden layer (200 units) - #@assert size(h1_W) == (h1_n, h2_n) h1 ~ sigmoid_belief_network(h2, h1_b, h1_W) - #h1 ~ sigmoid_belief_network([false], h1_b, zeros(Float32, h1_n, 1)) # sample visible layer (28x28=784 units) - #@assert size(x_W) == (x_n, h1_n) x ~ sigmoid_belief_network(h1, x_b, x_W) return nothing end @gen (static) function q(x::Vector{Bool}) - #@assert sum(x) > 0 # first hidden layer @param h1_W::Matrix{Float32} @@ -60,21 +50,13 @@ end @param h3_W::Matrix{Float32} @param h3_b::Vector{Float32} - #x_n = length(x) - #h1_n = length(h1_b) - #h2_n 
= length(h2_b) - #h3_n = length(h3_b) - # sample first hidden layer (200 units) - #@assert size(h1_W) == (h1_n, x_n) h1 ~ sigmoid_belief_network(x, h1_b, h1_W) # sample second hidden layer (200 units) - #@assert size(h2_W) == (h2_n, h1_n) h2 ~ sigmoid_belief_network(h1, h2_b, h2_W) # sample third hidden layer (10 units) - #@assert size(h3_W) == (h3_n, h2_n) h3 ~ sigmoid_belief_network(h2, h3_b, h3_W) return nothing @@ -85,14 +67,16 @@ end import Random using Gen: apply_update!, init_parameter!, get_parameter_value -# see https://github.com/jbornschein/reweighted-ws/blob/e96414719d09ab4941dc77bab4cf4847acc6a8e7/learning/models/sbn.py#L94 -# and https://github.com/jbornschein/reweighted-ws/blob/e96414719d09ab4941dc77bab4cf4847acc6a8e7/learning/model.py#L22 +# parameter initializations based on: +# https://github.com/jbornschein/reweighted-ws/blob/e96414719d09ab4941dc77bab4cf4847acc6a8e7/learning/models/sbn.py#L94 +# https://github.com/jbornschein/reweighted-ws/blob/e96414719d09ab4941dc77bab4cf4847acc6a8e7/learning/model.py#L22 +# https://github.com/jbornschein/reweighted-ws/blob/e96414719d09ab4941dc77bab4cf4847acc6a8e7/learning/models/sbn.py#L94 + function make_W(n_output, n_input) scale = Float32(sqrt(6f0) / sqrt(n_input + n_output)) return scale * ((2*rand(Float32, n_output, n_input)) .- 1.0f0) / n_input end -# see https://github.com/jbornschein/reweighted-ws/blob/e96414719d09ab4941dc77bab4cf4847acc6a8e7/learning/models/sbn.py#L94 make_b(n) = -ones(Float32, n) function initialize_p_params!(; x_n=28*28, h1_n=200, h2_n=200, h3_n=10) From 217444199db5f39ef26f483b156665c26a05c558 Mon Sep 17 00:00:00 2001 From: Marco Cusumano-Towner Date: Wed, 19 May 2021 17:19:21 -0400 Subject: [PATCH 6/6] fix mle example --- mle/run_all.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mle/run_all.jl b/mle/run_all.jl index 552ee09..0c70c33 100644 --- a/mle/run_all.jl +++ b/mle/run_all.jl @@ -19,4 +19,4 @@ for iter=1:1000 init_parameter!((foo, :mu), value + step_size 
* gradient) end -@assert abs(get_param(foo, :mu) - 3) < 1e-2 +@assert abs(get_parameter_value((foo, :mu)) - 3) < 1e-2