diff --git a/Project.toml b/Project.toml index 04d4d284b..4f4a5ecaa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.0" +version = "0.24.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/varinfo.jl b/src/varinfo.jl index 590626df3..83c914844 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -397,10 +397,9 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) push!(ranges, r) offset = r[end] # `dists`: only valid if they're the same. - dists_left = getdist(metadata_left, vn) - dists_right = getdist(metadata_right, vn) - @assert dists_left == dists_right - push!(dists, dists_left) + dist_right = getdist(metadata_right, vn) + # Give precedence to `metadata_right`. + push!(dists, dist_right) # `orders`: giving precedence to `metadata_right` push!(orders, getorder(metadata_right, vn)) # `flags` @@ -418,8 +417,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) push!(ranges, r) offset = r[end] # `dists` - dists_left = getdist(metadata_left, vn) - push!(dists, dists_left) + dist_left = getdist(metadata_left, vn) + push!(dists, dist_left) # `orders` push!(orders, getorder(metadata_left, vn)) # `flags` @@ -436,8 +435,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) push!(ranges, r) offset = r[end] # `dists` - dists_right = getdist(metadata_right, vn) - push!(dists, dists_right) + dist_right = getdist(metadata_right, vn) + push!(dists, dist_right) # `orders` push!(orders, getorder(metadata_right, vn)) # `flags` diff --git a/test/varinfo.jl b/test/varinfo.jl index aa790fd48..71e341767 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -594,6 +594,30 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end end end + + @testset "different models" begin + @model function demo_merge_different_y() + x ~ Uniform() + return y ~ Normal() + end + @model function demo_merge_different_z() + x ~ Normal() + return z ~ Normal() + end + model_left = demo_merge_different_y() + model_right = demo_merge_different_z() + + varinfo_left = VarInfo(model_left) + varinfo_right = VarInfo(model_right) + + varinfo_merged = merge(varinfo_left, varinfo_right) + vns = [@varname(x), @varname(y), @varname(z)] + check_varinfo_keys(varinfo_merged, vns) + + # Right has precedence. + @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] + @test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal + end end @testset "VarInfo with selectors" begin