diff --git a/src/fallbacks.jl b/src/fallbacks.jl index a8dc9ce42..be7e6538d 100644 --- a/src/fallbacks.jl +++ b/src/fallbacks.jl @@ -187,6 +187,7 @@ pmf(d::DiscreteDistribution, args::Any...) = pdf(d, args...) ### Gradient (derivative of logpdf) gradloglik(d::UnivariateDistribution, x::Real) = gradloglik(d, float64(x)) +gradloglik(d::MultivariateDistribution, x::Vector{Real}) = gradloglik(d, float64(x)) #### Sampling: rand & rand! #### diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index fdc8d5ef8..d9d0eb20d 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -115,6 +115,10 @@ function logpdf!(r::Array{Float64}, d::AbstractMvNormal, x::Matrix{Float64}) r end +function gradloglik(d::GenericMvNormal, x::Vector{Float64}) + z::Vector{Float64} = d.zeromean ? x : x - d.μ + -invcov(d)*z +end # Sampling (for GenericMvNormal) diff --git a/src/multivariate/mvtdist.jl b/src/multivariate/mvtdist.jl index 49dc3d88e..a9a3e2ee3 100644 --- a/src/multivariate/mvtdist.jl +++ b/src/multivariate/mvtdist.jl @@ -74,7 +74,9 @@ mode(d::GenericMvTDist) = d.μ modes(d::GenericMvTDist) = [mode(d)] var(d::GenericMvTDist) = d.df>2 ? (d.df/(d.df-2))*diag(d.Σ) : Float64[NaN for i = 1:d.dim] +scale(d::GenericMvTDist) = full(d.Σ) cov(d::GenericMvTDist) = d.df>2 ? (d.df/(d.df-2))*full(d.Σ) : NaN*ones(d.dim, d.dim) +invscale(d::GenericMvTDist) = full(inv(d.Σ)) invcov(d::GenericMvTDist) = d.df>2 ? ((d.df-2)/d.df)*full(inv(d.Σ)) : NaN*ones(d.dim, d.dim) logdet_cov(d::GenericMvTDist) = d.df>2 ? logdet((d.df/(d.df-2))*d.Σ) : NaN @@ -125,6 +127,12 @@ function logpdf!(r::Array{Float64}, d::AbstractMvTDist, x::Matrix{Float64}) r end +function gradloglik(d::GenericMvTDist, x::Vector{Float64}) + z::Vector{Float64} = d.zeromean ? x : x - d.μ + prz = invscale(d)*z + -((d.df + d.dim) / (d.df + dot(z, prz))) * prz +end + # Sampling (for GenericMvTDist) function rand!(d::GenericMvTDist, x::Vector{Float64}) diff --git a/test/gradloglik.jl b/test/gradloglik.jl index 7c3641795..6412d7a3e 100644 --- a/test/gradloglik.jl +++ b/test/gradloglik.jl @@ -1,6 +1,8 @@ using Distributions using Base.Test +# Test for gradloglik on univariate distributions + @test_approx_eq_eps gradloglik(Beta(1.5, 3.0), 0.7) -5.9523809523809526 1.0e-8 @test_approx_eq_eps gradloglik(Chi(5.0), 5.5) -4.7727272727272725 1.0e-8 @test_approx_eq_eps gradloglik(Chisq(7.0), 12.0) -0.29166666666666663 1.0e-8 @@ -13,3 +15,8 @@ using Base.Test @test_approx_eq_eps gradloglik(Normal(-4.5, 2.0), 1.6) -1.525 1.0e-8 @test_approx_eq_eps gradloglik(TDist(8.0), 9.1) -0.9018830525272548 1.0e-8 @test_approx_eq_eps gradloglik(Weibull(2.0), 3.5) -6.714285714285714 1.0e-8 + +# Test for gradloglik on multivariate distributions + +@test_approx_eq_eps gradloglik(MvNormal([1., 2.], [1. 0.1; 0.1 1.]), [0.7, 0.9]) [0.191919191919192, 1.080808080808081] 1.0e-8 +@test_approx_eq_eps gradloglik(MvTDist(5., [1., 2.], [1. 0.1; 0.1 1.]), [0.7, 0.9]) [0.2150711513583442, 1.2111901681759383] 1.0e-8