Skip to content

Commit

Permalink
Handle mixtures of NamedTuple and Dict for constant datasets better (#97
Browse files Browse the repository at this point in the history
)

* Make type-inferrable

* Add function for converting to dict with string keys

* Convert to dict with string keys

* Avoid now-unnecessary conversion

* Test convert_to_constant_dataset

* Test dims matched to namedtuple keys

* Increment version number
  • Loading branch information
sethaxen authored Oct 22, 2020
1 parent b458349 commit 9c3a89a
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ArviZ"
uuid = "131c737c-5715-5e2e-ad31-c244f01c1dc7"
authors = ["Seth Axen <[email protected]>"]
version = "0.4.7"
version = "0.4.8"

[deps]
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Expand Down
16 changes: 9 additions & 7 deletions src/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct Dataset
end
end

Dataset(; kwargs...) = xarray.Dataset(; kwargs...)
Dataset(; kwargs...) = Dataset(xarray.Dataset(; kwargs...))
@inline Dataset(data::Dataset) = data

@inline PyObject(data::Dataset) = getfield(data, :o)
Expand Down Expand Up @@ -116,25 +116,27 @@ function convert_to_constant_dataset(
library = nothing,
attrs = nothing,
)
obj = convert(Dict, obj)
base = arviz.data.base
coords = coords === nothing ? Dict{String,Vector}() : coords
dims = dims === nothing ? Dict{String,Vector{String}}() : dims

data = Dict{String,Any}()
obj = _asstringkeydict(obj)
coords = _asstringkeydict(coords)
dims = _asstringkeydict(dims)
attrs = _asstringkeydict(attrs)

data = Dict{String,PyObject}()
for (key, vals) in obj
vals = _asarray(vals)
val_dims = get(dims, key, nothing)
(val_dims, val_coords) =
base.generate_dims_coords(size(vals), key; dims = val_dims, coords = coords)
data[string(key)] = xarray.DataArray(vals; dims = val_dims, coords = val_coords)
data[key] = xarray.DataArray(vals; dims = val_dims, coords = val_coords)
end

default_attrs = base.make_attrs()
if library !== nothing
default_attrs = merge(default_attrs, Dict("inference_library" => string(library)))
end
attrs = attrs === nothing ? default_attrs : merge(default_attrs, attrs)
attrs = merge(default_attrs, attrs)
return Dataset(data_vars = data, coords = coords, attrs = attrs)
end

Expand Down
3 changes: 1 addition & 2 deletions src/mcmcchains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,8 @@ function from_mcmcchains(
:predictions_constant_data => predictions_constant_data,
]
group_data === nothing && continue
group_dict = convert(Dict, group_data)
group_dataset =
convert_to_constant_dataset(group_dict; library = library, kwargs...)
convert_to_constant_dataset(group_data; library = library, kwargs...)
concat!(all_idata, InferenceData(; group => group_dataset))
end

Expand Down
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ snakecase(s) = replace(lowercase(s), " " => "_")
@inline _asarray(x) = [x]
@inline _asarray(x::AbstractArray) = x

_asstringkeydict(x) = Dict(String(k) => v for (k, v) in pairs(x))
_asstringkeydict(x::Dict{String}) = x
_asstringkeydict(::Nothing) = Dict{String,Any}()

function enforce_stat_types(dict)
return Dict(k => get(sample_stats_types, k, eltype(v)).(v) for (k, v) in dict)
end
Expand Down
91 changes: 91 additions & 0 deletions test/test_dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,24 @@
@test ArviZ.Dataset(dataset) === dataset
@test_throws ArgumentError ArviZ.Dataset(py"PyNullObject()")
@test hash(dataset) == hash(pydataset)

vars = Dict("x" => ("dimx", randn(3)), ("y" => (("dimy_1", "dimy_2"), randn(3, 2))))
coords =
Dict("dimx" => [1, 2, 3], "dimy_1" => ["a", "b", "c"], "dimy_2" => ["d", "e"])
attrs = Dict("prop1" => 1, "prop2" => "propval")
@inferred ArviZ.Dataset(data_vars = vars, coords = coords, attrs = attrs)
ds = ArviZ.Dataset(data_vars = vars, coords = coords, attrs = attrs)
@test ds isa ArviZ.Dataset
vars2, kwargs = ArviZ.dataset_to_dict(ds)
for (k, v) in vars
@test k keys(vars2)
@test vars2[k] v[2]
end
@test kwargs.coords == coords
for (k, v) in attrs
@test k keys(kwargs.attrs)
@test kwargs.attrs[k] == v
end
end

@testset "properties" begin
Expand Down Expand Up @@ -74,6 +92,79 @@ end
end
end

@testset "ArviZ.convert_to_constant_dataset" begin
@testset "ArviZ.convert_to_constant_dataset(::Dict)" begin
data = Dict("x" => randn(4, 5), "y" => ["a", "b", "c"])
dataset = ArviZ.convert_to_constant_dataset(data)
@test dataset isa ArviZ.Dataset
@test "x" dataset.keys()
@test "y" dataset.keys()
@test Set(dataset.coords) == Set(["x_dim_0", "x_dim_1", "y_dim_0"])
@test collect(dataset._variables["x"].values) == data["x"]
@test collect(dataset._variables["y"].values) == data["y"]
end

@testset "ArviZ.convert_to_constant_dataset(::Dict; kwargs...)" begin
data = Dict("x" => randn(4, 5), "y" => ["a", "b", "c"])
coords = Dict("xdim1" => 1:4, "xdim2" => 5:9, "ydim1" => ["d", "e", "f"])
dims = Dict("x" => ["xdim1", "xdim2"], "y" => ["ydim1"])
library = "MyLib"
dataset = ArviZ.convert_to_constant_dataset(data)
attrs = Dict("prop" => "propval")

dataset = ArviZ.convert_to_constant_dataset(
data;
coords = coords,
dims = dims,
library = library,
attrs = attrs,
)
@test dataset isa ArviZ.Dataset
@test "x" dataset.keys()
@test "y" dataset.keys()
@test Set(dataset.coords) == Set(["xdim1", "xdim2", "ydim1"])
@test collect(dataset._variables["xdim1"].values) == coords["xdim1"]
@test collect(dataset._variables["xdim2"].values) == coords["xdim2"]
@test collect(dataset._variables["ydim1"].values) == coords["ydim1"]
@test collect(dataset["x"].coords) == ["xdim1", "xdim2"]
@test collect(dataset["y"].coords) == ["ydim1"]
@test collect(dataset._variables["x"].values) == data["x"]
@test collect(dataset._variables["y"].values) == data["y"]
@test dataset.attrs["prop"] == attrs["prop"]
@test dataset.attrs["inference_library"] == library
end

@testset "ArviZ.convert_to_constant_dataset(::NamedTuple; kwargs...)" begin
data = (x = randn(4, 5), y = ["a", "b", "c"])
coords = (xdim1 = 1:4, xdim2 = 5:9, ydim1 = ["d", "e", "f"])
dims = (x = ["xdim1", "xdim2"], y = ["ydim1"])
library = "MyLib"
dataset = ArviZ.convert_to_constant_dataset(data)
attrs = (prop = "propval",)

dataset = ArviZ.convert_to_constant_dataset(
data;
coords = coords,
dims = dims,
library = library,
attrs = attrs,
)
@test dataset isa ArviZ.Dataset
@test "x" dataset.keys()
@test "y" dataset.keys()
@test Set(dataset.coords) == Set(["xdim1", "xdim2", "ydim1"])
@test collect(dataset._variables["xdim1"].values) == coords.xdim1
@test collect(dataset._variables["xdim2"].values) == coords.xdim2
@test collect(dataset._variables["ydim1"].values) == coords.ydim1
@test collect(dataset["x"].coords) == ["xdim1", "xdim2"]
@test collect(dataset["y"].coords) == ["ydim1"]
@test collect(dataset._variables["x"].values) == data.x
@test collect(dataset._variables["y"].values) == data.y
@test dataset.attrs["prop"] == attrs.prop
@test dataset.attrs["inference_library"] == library
end
end

@testset "dict to dataset roundtrip" begin
rng = MersenneTwister(42)
J = 8
Expand Down
20 changes: 20 additions & 0 deletions test/test_namedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,26 @@ end
sizes = dimsizes(ds)
@test length(sizes) == 1
@test "w" in keys(vardict(ds))
@test dimdict(ds)["w"] == ("wx",)
@test "inference_library" in keys(attributes(ds))
@test attributes(ds)["inference_library"] == "MyLib"

# ensure that dims are matched to named tuple keys
# https://github.com/arviz-devs/ArviZ.jl/issues/96
idata = from_namedtuple(
nt;
(group => (w = [1.0, 2.0],),)...,
dims = Dict("w" => ["wx"]),
coords = Dict("wx" => 1:2),
library = "MyLib",
)
@test idata isa InferenceData
@test group in ArviZ.groupnames(idata)
ds = getproperty(idata, group)
sizes = dimsizes(ds)
@test length(sizes) == 1
@test "w" in keys(vardict(ds))
@test dimdict(ds)["w"] == ("wx",)
@test "inference_library" in keys(attributes(ds))
@test attributes(ds)["inference_library"] == "MyLib"
end
Expand Down

2 comments on commit 9c3a89a

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/23482

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.8 -m "<description of version>" 9c3a89a5cacd19c947a354a7e67f93bd7385f379
git push origin v0.4.8

Please sign in to comment.