-
Notifications
You must be signed in to change notification settings - Fork 1
/
sampler.jl
93 lines (70 loc) · 2.12 KB
/
sampler.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
using HDPHMM
using Test
using Distributions
import ConjugatePriors: NormalInverseChisq
import HDPHMM: DPMMObservationModelStats
# TODO: Add tests with missing values:
# - Float64 observations only
# - Mixed Float64/Missing observations
# - Missing observations only
# TODO: Basic inference test with simple HMM,
# test that we find the same state sequence,
# and similar distributions (?)
"""
    fake(::Type{DPMMObservationModelStats}; L = 10, LP = 5, N = 10)

Build a `DPMMObservationModelStats` filled with random data, for use as
smoke-test input.

# Keywords
- `L`: number of rows of every generated matrix; `n` is `L × L`.
- `LP`: number of columns of `np` and `Y`.
- `N`: length of each random observation vector stored in `Y`.

# Returns
`DPMMObservationModelStats(n, np, Y)` where
- `n` is an `L × L` matrix of random integers drawn from `1:100`,
- `np` is an `L × LP` matrix of random integers drawn from `1:100`,
- `Y` is an `L × LP` `Matrix{Vector{Float64}}` whose entries are independent
  `rand(N)` vectors.

The generated values are not tied to each other (e.g. the length of each `Y`
vector does not depend on the counts in `n` or `np`).
"""
function fake(::Type{DPMMObservationModelStats}; L = 10, LP = 5, N = 10)
    n = rand(1:100, L, L)
    np = rand(1:100, L, LP)
    # A 2-D comprehension yields a Matrix{Vector{Float64}} of size (L, LP),
    # filled in column-major order (same RNG draw order as an eachindex loop).
    Y = [rand(N) for _ in 1:L, _ in 1:LP]
    return DPMMObservationModelStats(n, np, Y)
end
@testset "Initial Distribution" begin
    # Smoke test: resampling a freshly constructed initial-state distribution
    # must complete without emitting any warnings.
    initial = InitialStateDistribution(10, 1)
    @test_nowarn resample(initial, 1)
end
@testset "Transition Distribution" begin
    nstates = 10
    # Distributions.jl's Gamma is parameterized as (shape, scale); both
    # Gamma hyperpriors below are the same value, so one instance is reused.
    gamma_hyper = Gamma(1, 1 / 0.001)
    tprior = TransitionDistributionPrior(gamma_hyper, gamma_hyper, Beta(50, 1))
    tdist = TransitionDistribution(nstates, tprior)

    # Case 1: an all-zero (Float64) count matrix.
    zero_counts = zeros(nstates, nstates)
    @test_nowarn resample(tdist, tprior, zero_counts)

    # Case 2: a random integer count matrix with entries in 0:100.
    random_counts = rand(0:100, nstates, nstates)
    @test_nowarn resample(tdist, tprior, random_counts)
end
@testset "DPMM - MixtureModel" begin
    # A 10-component mixture of standard normals, resampled against a
    # NormalInverseChisq prior using ten random observation vectors.
    components = [Normal(0, 1) for _ in 1:10]
    mixture = MixtureModel(components)
    base_prior = NormalInverseChisq(10, 2, 1, 1)
    observations = [rand(100) for _ in 1:10]
    @test_nowarn resample(mixture, base_prior, 1.0, observations)
end
@testset "DPMM - Stats" begin
    L, LP = 10, 5
    prior = DPMMObservationModelPrior{Normal}(
        NormalInverseChisq(10, 2, 1, 1),
        Gamma(1, 0.5)
    )
    m = DPMMObservationModel(L, LP, prior)
    # Pass the model's dimensions explicitly so the fake statistics always
    # match `m`, instead of relying on `fake`'s defaults happening to agree.
    stats = fake(DPMMObservationModelStats; L = L, LP = LP)
    @test_nowarn resample(m, prior, stats.n, stats.np, stats.Y)
end
@testset "Sampler" begin
    L, LP = 10, 5
    transition_prior = TransitionDistributionPrior(
        Gamma(1, 1 / 0.001),
        Gamma(1, 1 / 0.001),
        Beta(500, 1),
    )
    observation_prior = DPMMObservationModelPrior{Normal}(
        NormalInverseChisq(1, 1, 1, 1),
        Gamma(1, 0.5),
    )
    sampler = BlockedSampler(L, LP)
    prior = BlockedSamplerPrior(1.0, transition_prior, observation_prior)
    state = BlockedSamplerState(sampler, prior)
    # TODO: Multivariate observations
    # TODO: Missing observations
    # Univariate Float64 observations only, for now.
    observations = rand(1000)
    @test_nowarn resample(sampler, state, prior, observations)
end