Skip to content

Commit

Permalink
specify GibbsIntegrator
Browse files Browse the repository at this point in the history
  • Loading branch information
joannajzou committed Aug 5, 2024
1 parent 1e9c9e5 commit 4628b7f
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 21 deletions.
8 changes: 4 additions & 4 deletions src/Distributions/mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,27 +79,27 @@ function rand(d::MixtureModel{Union{Univariate,Multivariate}, Continuous, Gibbs}
end

# 2 - pdf
function pdf(d::MixtureModel{Union{Univariate,Multivariate}, Continuous, Gibbs}, x::Float64, normint::Integrator)
function pdf(d::MixtureModel{Union{Univariate,Multivariate}, Continuous, Gibbs}, x::Float64, normint::GibbsIntegrator)
p = probs(d)
return sum(p_i * pdf(component(d, i), x, normint) for (i, p_i) in enumerate(p) if !iszero(p_i))
end

# multiple samples of x
function pdf(d::MixtureModel{Union{Univariate,Multivariate}, Continuous, Gibbs}, xsamp::Vector{Float64}, normint::Integrator)
function pdf(d::MixtureModel{Union{Univariate,Multivariate}, Continuous, Gibbs}, xsamp::Vector{Float64}, normint::GibbsIntegrator)
p = probs(d)
return sum(p_i * updf.((component(d, i),), xsamp) ./ normconst(component(d, i), normint) for (i, p_i) in enumerate(p) if !iszero(p_i))
end


# 3 - log unnormalized pdf
function logpdf(d::MixtureModel{Union{Univariate,Multivariate}, Continuous, Gibbs}, x::Float64, normint::Integrator)
function logpdf(d::MixtureModel{Union{Univariate,Multivariate}, Continuous, Gibbs}, x::Float64, normint::GibbsIntegrator)
p = probs(d)
lp = logsumexp(log(p_i) + logpdf(component(d, i), x, normint) for (i, p_i) in enumerate(p) if !iszero(p_i))
return lp
end

# multiple samples of x
function logpdf(d::MixtureModel{Union{Univariate,Multivariate}, Continuous, Gibbs}, xsamp::Vector{Float64}, normint::Integrator)
function logpdf(d::MixtureModel{Union{Univariate,Multivariate}, Continuous, Gibbs}, xsamp::Vector{Float64}, normint::GibbsIntegrator)
p = probs(d)
data = reduce(hcat, [log(p_i) .+ logupdf.((component(d, i),), xsamp) .- log(normconst(component(d, i), normint)) for (i, p_i) in enumerate(p) if !iszero(p_i)])
lp = logsumexp.([data[i,:] for i = 1:size(data,1)])
Expand Down
16 changes: 8 additions & 8 deletions src/Integrators/importance_sampling.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
abstract type ISIntegrator <: Integrator end
abstract type ISIntegrator <: GibbsIntegrator end

"""
struct ISMC <: ISIntegrator
Expand Down Expand Up @@ -44,15 +44,15 @@ This method is implemented with the user providing samples from the biasing dist
# Arguments
- `g :: Distribution` : biasing distribution
- `xsamp :: Vector` : fixed set of samples
- `normint :: Union{Integrator, Nothing}` : integrator for computing normalizing constant of biasing distribution (required for mixture models)
- `normint :: Union{GibbsIntegrator, Nothing}` : integrator for computing normalizing constant of biasing distribution (required for mixture models)
"""
mutable struct ISSamples <: ISIntegrator
g :: Distribution
xsamp :: Vector
normint :: Union{Integrator, Nothing}
normint :: Union{GibbsIntegrator, Nothing}

function ISSamples(g::MixtureModel, xsamp::Vector, normint::Integrator)
function ISSamples(g::MixtureModel, xsamp::Vector, normint::GibbsIntegrator)
return new(g, xsamp, normint)
end

Expand All @@ -76,7 +76,7 @@ This method is implemented with the user providing samples from each component m
- `n :: Int` : number of samples
- `knl :: Kernel` : kernel function to compute mixture weights
- `xsamp :: Vector` : Vector of sample sets from each component distribution
- `normint :: Integrator` : Integrator for the approximating the normalizing constant of each component distribution
- `normint :: GibbsIntegrator` : Integrator for the approximating the normalizing constant of each component distribution
"""
mutable struct ISMixSamples <: ISIntegrator
Expand All @@ -85,14 +85,14 @@ mutable struct ISMixSamples <: ISIntegrator
n :: Int
knl :: Kernel
xsamp :: Vector
normint :: Integrator
normint :: GibbsIntegrator


function ISMixSamples(g::MixtureModel, refs::Vector, n::Int, knl::Kernel, xsamp::Vector, normint::Integrator)
function ISMixSamples(g::MixtureModel, refs::Vector, n::Int, knl::Kernel, xsamp::Vector, normint::GibbsIntegrator)
return new(g, refs, n, knl, xsamp, normint)
end

function ISMixSamples(g::MixtureModel, n::Int, knl::Kernel, xsamp::Vector, normint::Integrator)
function ISMixSamples(g::MixtureModel, n::Int, knl::Kernel, xsamp::Vector, normint::GibbsIntegrator)
d = new()
d.g = g
d.refs = [πg.θ for πg in g.components]
Expand Down
25 changes: 20 additions & 5 deletions src/Integrators/integrators.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
abstract type Integrator end
"""
Integrator
GibbsIntegrator
A struct of abstract type Integrator computes the expectation of a function h(x, θ) with respect to an invariant measure p(x, θ).
A struct of abstract type GibbsIntegrator computes the expectation of a function h(x, θ) with respect to the invariant measure p(x, θ).
"""
abstract type Integrator end
abstract type GibbsIntegrator <: Integrator end

"""
PathIntegrator
A struct of abstract type PathIntegrator computes the expectation of a functional h with respect to the path measure P.
"""
abstract type PathIntegrator <: Integrator end

include("quadrature.jl")
include("monte_carlo.jl")
include("importance_sampling.jl")
include("path_integrators.jl")


export
Integrator,
GibbsIntegrator,
MCIntegrator,
MonteCarlo,
MCMC,
MCSamples,
MCPaths,
ISIntegrator,
ISMC,
ISMCMC,
Expand All @@ -23,4 +34,8 @@ export
QuadIntegrator,
GaussQuadrature,
gaussquad,
gaussquad_2D
gaussquad_2D,
PathIntegrator,
RiemannIntegrator,
ItoIntegrator,
compute_integral
6 changes: 3 additions & 3 deletions src/Integrators/monte_carlo.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
abstract type MCIntegrator <: Integrator end
abstract type MCIntegrator <: GibbsIntegrator end

"""
struct MonteCarlo <: MCIntegrator
Expand All @@ -23,12 +23,12 @@ This method is implemented when the distribution cannot be analytically sampled.
# Arguments
- `n :: Int` : number of samples
- `sampler :: Sampler` : type of sampler (see `Sampler`)
- `ρ0 :: Distribution` : prior distribution of the state
- `ρ0 :: Union{Distribution, Real, Vector{<:Real}}` : initial state or prior distribution of the state
"""
struct MCMC <: MCIntegrator
n :: Int
sampler :: Sampler
ρ0 :: Distribution
ρ0 :: Union{Distribution, Real, Vector{<:Real}}
end


Expand Down
2 changes: 1 addition & 1 deletion src/Integrators/quadrature.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
abstract type QuadIntegrator <: Integrator end
abstract type QuadIntegrator <: GibbsIntegrator end


"""
Expand Down

0 comments on commit 4628b7f

Please sign in to comment.