Skip to content

Commit

Permalink
added params to multivariate and matrix vars
Browse files Browse the repository at this point in the history
- Also mixtures and truncated distributions
- Tests enforce contract typeof(d)(params(d)...) == d for isa(d,
Distribution)
- Similar contracts hold for mixtures and truncated distributions
  • Loading branch information
jmxpearson committed Jan 22, 2016
1 parent 4cd35c5 commit 8ff3fe2
Show file tree
Hide file tree
Showing 19 changed files with 79 additions and 31 deletions.
4 changes: 2 additions & 2 deletions src/matrix/inversewishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ end

#### Constructors

function InverseWishart{ST<:AbstractPDMat}(df::Real, Ψ::ST)
function InverseWishart{ST <: AbstractPDMat}(df::Real, Ψ::ST)
p = dim(Ψ)
df > p - 1 || error("df should be greater than dim - 1.")
InverseWishart{ST}(df, Ψ, _invwishart_c0(df, Ψ))
Expand All @@ -35,7 +35,7 @@ insupport(d::InverseWishart, X::Matrix{Float64}) = size(X) == size(d) && isposde

dim(d::InverseWishart) = dim(d.Ψ)
size(d::InverseWishart) = (p = dim(d); (p, p))

params(d::InverseWishart) = (d.df, d.Ψ, d.c0)

#### Show

Expand Down
4 changes: 2 additions & 2 deletions src/matrix/wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ end

#### Constructors

function Wishart{ST<:AbstractPDMat}(df::Real, S::ST)
function Wishart{ST <: AbstractPDMat}(df::Real, S::ST)
p = dim(S)
df > p - 1 || error("df should be greater than dim - 1.")
Wishart{ST}(df, S, _wishart_c0(df, S))
Expand All @@ -35,7 +35,7 @@ insupport(d::Wishart, X::Matrix{Float64}) = size(X) == size(d) && isposdef(X)

dim(d::Wishart) = dim(d.S)
size(d::Wishart) = (p = dim(d); (p, p))

params(d::Wishart) = (d.df, d.S, d.c0)

#### Show

Expand Down
1 change: 1 addition & 0 deletions src/mixtures/mixturemodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ components(d::MixtureModel) = d.components
component(d::MixtureModel, k::Int) = d.components[k]

probs(d::MixtureModel) = probs(d.prior)
params(d::MixtureModel) = ([params(c) for c in d.components], params(d.prior)[1])

function mean(d::UnivariateMixture)
K = ncomponents(d)
Expand Down
2 changes: 2 additions & 0 deletions src/mixtures/unigmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ mean(d::UnivariateGMM) = dot(d.means, probs(d))

rand(d::UnivariateGMM) = (k = rand(d.prior); d.means[k] + randn() * d.stds[k])

params(d::UnivariateGMM) = (d.means, d.stds, d.prior)

immutable UnivariateGMMSampler <: Sampleable{Univariate,Continuous}
means::Vector{Float64}
stds::Vector{Float64}
Expand Down
1 change: 1 addition & 0 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Base.show(io::IO, d::Dirichlet) = show(io, d, (:alpha,))

length(d::Dirichlet) = length(d.alpha)
mean(d::Dirichlet) = d.alpha .* inv(d.alpha0)
params(d::Dirichlet) = (d.alpha,)

function var(d::Dirichlet)
α = d.alpha
Expand Down
1 change: 1 addition & 0 deletions src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ Base.show(io::IO, d::MvNormal) =

length(d::MvNormal) = length(d.μ)
mean(d::MvNormal) = convert(Vector{Float64}, d.μ)
params(d::MvNormal) = (d.μ, d.Σ)

var(d::MvNormal) = diag(d.Σ)
cov(d::MvNormal) = full(d.Σ)
Expand Down
1 change: 1 addition & 0 deletions src/multivariate/mvnormalcanon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ canonform{C}(d::MvNormal{C,ZeroVector{Float64}}) = MvNormalCanon(inv(d.Σ))

length(d::MvNormalCanon) = length(d.μ)
mean(d::MvNormalCanon) = convert(Vector{Float64}, d.μ)
params(d::MvNormalCanon) = (d.μ, d.h, d.J)

var(d::MvNormalCanon) = diag(inv(d.J))
cov(d::MvNormalCanon) = full(inv(d.J))
Expand Down
10 changes: 6 additions & 4 deletions src/multivariate/mvtdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ end

function GenericMvTDist{Cov<:AbstractPDMat}(df::Float64, Σ::Cov)
d = dim(Σ)
GenericMvTDist{Cov}(df, d, true, zeros(d), Σ)
GenericMvTDist{Cov}(df, d, true, zeros(d), Σ)
end

## Construction of multivariate normal with specific covariance type
Expand Down Expand Up @@ -80,6 +80,8 @@ invscale(d::GenericMvTDist) = full(inv(d.Σ))
invcov(d::GenericMvTDist) = d.df>2 ? ((d.df-2)/d.df)*full(inv(d.Σ)) : NaN*ones(d.dim, d.dim)
logdet_cov(d::GenericMvTDist) = d.df>2 ? logdet((d.df/(d.df-2))*d.Σ) : NaN

params(d::GenericMvTDist) = (d.df, d.μ, d.Σ)

# For entropy calculations see "Multivariate t Distributions and their Applications", S. Kotz & S. Nadarajah
function entropy(d::GenericMvTDist)
hdf, hdim = 0.5*d.df, 0.5*d.dim
Expand All @@ -89,12 +91,12 @@ end

# evaluation (for GenericMvTDist)

insupport{T<:Real}(d::AbstractMvTDist, x::AbstractVector{T}) =
insupport{T<:Real}(d::AbstractMvTDist, x::AbstractVector{T}) =
length(d) == length(x) && allfinite(x)

function sqmahal{T<:Real}(d::GenericMvTDist, x::DenseVector{T})
function sqmahal{T<:Real}(d::GenericMvTDist, x::DenseVector{T})
z::Vector{Float64} = d.zeromean ? x : x - d.μ
invquad(d.Σ, z)
invquad(d.Σ, z)
end

function sqmahal!{T<:Real}(r::DenseArray, d::GenericMvTDist, x::DenseMatrix{T})
Expand Down
10 changes: 4 additions & 6 deletions src/multivariate/vonmisesfisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ meandir(d::VonMisesFisher) = d.μ
concentration(d::VonMisesFisher) = d.κ

insupport{T<:Real}(d::VonMisesFisher, x::DenseVector{T}) = isunitvec(x)

params(d::VonMisesFisher) = (d.μ, d.κ)

### Evaluation

Expand All @@ -51,13 +51,13 @@ function _vmflck(p, κ)
q = hp - 1.0
q * log(κ) - hp * log(2π) - log(besseli(q, κ))
end
_vmflck3(κ) = log(κ) - log2π - κ - log1mexp(-2.0 * κ)
_vmflck3(κ) = log(κ) - log2π - κ - log1mexp(-2.0 * κ)
vmflck(p, κ) = (p == 3 ? _vmflck3(κ) : _vmflck(p, κ))::Float64

_logpdf{T<:Real}(d::VonMisesFisher, x::DenseVector{T}) = d.logCκ + d.κ * dot(d.μ, x)


### Sampling
### Sampling

sampler(d::VonMisesFisher) = VonMisesFisherSampler(d.μ, d.κ)

Expand Down Expand Up @@ -106,6 +106,4 @@ function _vmf_estkappa(p::Int, ρ::Float64)
return κ
end

_vmfA(half_p::Float64, κ::Float64) = besseli(half_p, κ) / besseli(half_p - 1.0, κ)


_vmfA(half_p::Float64, κ::Float64) = besseli(half_p, κ) / besseli(half_p - 1.0, κ)
21 changes: 21 additions & 0 deletions src/testutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ function test_distr(distr::DiscreteUnivariateDistribution, n::Int)

test_stats(distr, vs)
test_samples(distr, n)
test_params(distr)
end


Expand All @@ -50,6 +51,7 @@ function test_distr(distr::ContinuousUnivariateDistribution, n::Int)

xs = test_samples(distr, n)
allow_test_stats(distr) && test_stats(distr, xs)
test_params(distr)
end


Expand Down Expand Up @@ -505,3 +507,22 @@ function test_stats(d::ContinuousUnivariateDistribution, xs::AbstractVector{Floa
end
end
end

function test_params(d::Distribution)
# simply test that params returns something sufficient to
# reconstruct d
D = typeof(d)
pars = params(d)
d_new = D(pars...)
@test d_new == d
end

function test_params(d::Truncated)
# simply test that params returns something sufficient to
# reconstruct d
d_unt = d.untruncated
D = typeof(d_unt)
pars = params(d_unt)
d_new = Truncated(D(pars...), d.lower, d.upper)
@test d_new == d
end
2 changes: 2 additions & 0 deletions src/univariate/discrete/noncentralhypergeometric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ function quantile(d::NoncentralHypergeometric, q::Float64)
end
end

params(d::NoncentralHypergeometric) = (d.ns, d.nf, d.n, d.ω)

## Fisher's noncentral hypergeometric distribution

immutable FisherNoncentralHypergeometric <: NoncentralHypergeometric
Expand Down
1 change: 1 addition & 0 deletions test/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ d = Dirichlet(v)
@test length(d) == length(v)
@test d.alpha == v
@test d.alpha0 == sum(v)
@test d == typeof(d)(params(d)...)

@test_approx_eq mean(d) v / sum(v)
@test_approx_eq cov(d) [8 -2 -6; -2 5 -3; -6 -3 9] / (36 * 7)
Expand Down
1 change: 1 addition & 0 deletions test/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ IW = InverseWishart(v,S)
for d in [W,IW]
@test size(d) == size(rand(d))
@test length(d) == length(rand(d))
@test typeof(d)(params(d)...) == d
end

@test_approx_eq_eps mean(rand(W,100000)) mean(W) 0.1
Expand Down
15 changes: 15 additions & 0 deletions test/mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,19 @@ function test_mixture(g::MultivariateMixture, n::Int, ns::Int)
@test_approx_eq_eps cov(Xs, vardim=2) cov(g) 0.01
end

function test_params(g::AbstractMixtureModel)
C = eltype(g.components)
pars = params(g)
mm = MixtureModel(C, pars...)
@test g.prior == mm.prior
@test g.components == mm.components
end

function test_params(g::UnivariateGMM)
pars = params(g)
mm = UnivariateGMM(pars...)
@test g == mm
end

# Tests

Expand All @@ -131,11 +143,13 @@ g_u = MixtureModel(Normal, [(0.0, 1.0), (2.0, 1.0), (-4.0, 1.5)], [0.2, 0.5, 0.3
@test isa(g_u, MixtureModel{Univariate, Continuous, Normal})
@test ncomponents(g_u) == 3
test_mixture(g_u, 1000, 10^6)
test_params(g_u)

g_u = UnivariateGMM([0.0, 2.0, -4.0], [1.0, 1.2, 1.5], Categorical([0.2, 0.5, 0.3]))
@test isa(g_u, UnivariateGMM)
@test ncomponents(g_u) == 3
test_mixture(g_u, 1000, 10^6)
test_params(g_u)

println(" testing MultivariateMixture")
g_m = MixtureModel(
Expand All @@ -147,3 +161,4 @@ g_m = MixtureModel(
@test length(components(g_m)) == 3
@test length(g_m) == 2
test_mixture(g_m, 1000, 10^6)
test_params(g_m)
2 changes: 1 addition & 1 deletion test/multinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ x = rand(d)
@test insupport(d, x)
@test size(x) == size(d)
@test length(x) == length(d)
@test d == typeof(d)(params(d)...)

x = rand(d, 100)
@test isa(x, Matrix{Int})
Expand Down Expand Up @@ -86,4 +87,3 @@ r = fit_mle(Multinomial, x, fill(2.0, size(x,2)))
@test r.n == nt
@test length(r) == length(p)
@test_approx_eq_eps probs(r) p 0.02

27 changes: 13 additions & 14 deletions test/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Distributions: distrname

function test_mvnormal(g::AbstractMvNormal, n_tsamples::Int=10^6)
d = length(g)
μ = mean(g)
μ = mean(g)
Σ = cov(g)
@test isa(μ, Vector{Float64})
@test isa(Σ, Matrix{Float64})
Expand All @@ -22,6 +22,7 @@ function test_mvnormal(g::AbstractMvNormal, n_tsamples::Int=10^6)
ldcov = logdetcov(g)
@test_approx_eq ldcov logdet(Σ)
vs = diag(Σ)
@test g == typeof(g)(params(g)...)

# sampling
X = rand(g, n_tsamples)
Expand Down Expand Up @@ -61,35 +62,35 @@ h = [1., 2., 3.]
dv = [1.2, 3.4, 2.6]
J = [4. -2. -1.; -2. 5. -1.; -1. -1. 6.]

for (T, g, μ, Σ) in [
(IsoNormal, MvNormal(mu, sqrt(2.0)), mu, 2.0 * eye(3)),
(ZeroMeanIsoNormal, MvNormal(3, sqrt(2.0)), zeros(3), 2.0 * eye(3)),
(DiagNormal, MvNormal(mu, sqrt(va)), mu, diagm(va)),
(ZeroMeanDiagNormal, MvNormal(sqrt(va)), zeros(3), diagm(va)),
(FullNormal, MvNormal(mu, C), mu, C),
for (T, g, μ, Σ) in [
(IsoNormal, MvNormal(mu, sqrt(2.0)), mu, 2.0 * eye(3)),
(ZeroMeanIsoNormal, MvNormal(3, sqrt(2.0)), zeros(3), 2.0 * eye(3)),
(DiagNormal, MvNormal(mu, sqrt(va)), mu, diagm(va)),
(ZeroMeanDiagNormal, MvNormal(sqrt(va)), zeros(3), diagm(va)),
(FullNormal, MvNormal(mu, C), mu, C),
(ZeroMeanFullNormal, MvNormal(C), zeros(3), C),
(IsoNormalCanon, MvNormalCanon(h, 2.0), h / 2.0, 0.5 * eye(3)),
(ZeroMeanIsoNormalCanon, MvNormalCanon(3, 2.0), zeros(3), 0.5 * eye(3)),
(DiagNormalCanon, MvNormalCanon(h, dv), h ./ dv, diagm(1.0 ./ dv)),
(ZeroMeanDiagNormalCanon, MvNormalCanon(dv), zeros(3), diagm(1.0 ./ dv)),
(FullNormalCanon, MvNormalCanon(h, J), J \ h, inv(J)),
(FullNormalCanon, MvNormalCanon(h, J), J \ h, inv(J)),
(ZeroMeanFullNormalCanon, MvNormalCanon(J), zeros(3), inv(J)) ]

println(" testing $(distrname(g))")

@test isa(g, T)
@test_approx_eq mean(g) μ
@test_approx_eq cov(g) Σ
test_mvnormal(g)
test_mvnormal(g)

# conversion between mean form and canonical form
if isa(g, MvNormal)
gc = canonform(g)
gc = canonform(g)
@test isa(gc, MvNormalCanon)
@test length(gc) == length(g)
@test_approx_eq mean(gc) mean(g)
@test_approx_eq cov(gc) cov(g)
else
else
@assert isa(g, MvNormalCanon)
gc = meanform(g)
@test isa(gc, MvNormal)
Expand All @@ -116,7 +117,7 @@ function _gauss_mle(x::Matrix{Float64}, w::Vector{Float64})
mu = (x * w) * (1/sw)
z = x .- mu
C = (z * scale(w, z')) * (1/sw)
Base.LinAlg.copytri!(C, 'U')
Base.LinAlg.copytri!(C, 'U')
return mu, C
end

Expand Down Expand Up @@ -161,5 +162,3 @@ g = fit_mle(DiagNormal, x, w)
@test isa(g, DiagNormal)
@test_approx_eq g.μ uw
@test_approx_eq g.Σ.diag diag(Cw)


1 change: 1 addition & 0 deletions test/mvtdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ df = [1., 2, 3, 5, 10]
for i = 1:length(df)
d = MvTDist(df[i], mu, Sigma)
@test_approx_eq_eps logpdf(d, [-2., 3]) rvalues[i] 1.0e-8
@test d == typeof(d)(params(d)...)
end
4 changes: 3 additions & 1 deletion test/noncentralhypergeometric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ n = 100
# http://en.wikipedia.org/wiki/Fisher's_noncentral_hypergeometric_distribution
ω = 10.0
d = FisherNoncentralHypergeometric(ns, nf, n, ω)
@test d == typeof(d)(params(d)...)

@test_approx_eq_eps mean(d) 71.95759 1e-5
@test mode(d) == 72
Expand Down Expand Up @@ -50,6 +51,7 @@ n = 100
# http://en.wikipedia.org/wiki/Fisher's_noncentral_hypergeometric_distribution
ω = 10.0
d = WalleniusNoncentralHypergeometric(ns, nf, n, ω)
@test d == typeof(d)(params(d)...)

@test_approx_eq_eps mean(d) 78.82945 1e-5
@test mode(d) == 80
Expand Down Expand Up @@ -81,4 +83,4 @@ ref = Hypergeometric(ns,nf,n)
@test_approx_eq_eps cdf(d, 51) cdf(ref, 51) 1e-7
@test_approx_eq_eps quantile(d, 0.05) quantile(ref, 0.05) 1e-7
@test_approx_eq_eps quantile(d, 0.95) quantile(ref, 0.95) 1e-7
@test mode(d) == mode(ref)
@test mode(d) == mode(ref)
2 changes: 1 addition & 1 deletion test/vonmisesfisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ function test_vonmisesfisher(p::Int, κ::Float64, n::Int, ns::Int)
@test length(d) == p
@test meandir(d) == μ
@test concentration(d) == κ
@test d == typeof(d)(params(d)...)
# println(d)

θ = κ * μ
Expand Down Expand Up @@ -72,4 +73,3 @@ for (p, κ) in [(2, 1.0),

test_vonmisesfisher(p, κ, n, ns)
end

0 comments on commit 8ff3fe2

Please sign in to comment.