Skip to content

Commit

Permalink
Merge pull request #74 from adolgert/feature/copy-sampler
Browse files Browse the repository at this point in the history
Feature/copy sampler
  • Loading branch information
adolgert authored Jun 2, 2024
2 parents 9a2ed10 + 1b7fa30 commit f070409
Show file tree
Hide file tree
Showing 18 changed files with 242 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ enable!
disable!
next
reset!
copy!
```

## Query a Sampler
Expand Down
8 changes: 8 additions & 0 deletions src/prefixsearch/binarytreeprefixsearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ function Base.empty!(ps::BinaryTreePrefixSearch)
ps.cnt = 0
end

function Base.copy!(dst::BinaryTreePrefixSearch{T}, src::BinaryTreePrefixSearch{T}) where {T}
copy!(dst.array, src.array)
dst.depth = src.depth
dst.offset = src.offset
dst.cnt = src.cnt
dst.initial_allocation = src.initial_allocation
end


time_type(ps::BinaryTreePrefixSearch{T}) where {T} = T
time_type(::Type{BinaryTreePrefixSearch{T}}) where {T} = T
Expand Down
4 changes: 4 additions & 0 deletions src/prefixsearch/cumsumprefixsearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ function Base.empty!(ps::CumSumPrefixSearch)
empty!(ps.cumulant)
end

function Base.copy!(dst::CumSumPrefixSearch{T}, src::CumSumPrefixSearch{T}) where {T}
copy!(dst.array, src.array)
copy!(dst.cumulant, src.cumulant)
end

Base.length(ps::CumSumPrefixSearch) = length(ps.array)
time_type(ps::CumSumPrefixSearch{T}) where {T} = T
Expand Down
14 changes: 14 additions & 0 deletions src/prefixsearch/keyedprefixsearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ function Base.empty!(kp::KeyedKeepPrefixSearch)
empty!(kp.prefix)
end

function Base.copy!(dst::KeyedKeepPrefixSearch{T,P}, src::KeyedKeepPrefixSearch{T,P}) where {T,P}
copy!(dst.index, src.index)
copy!(dst.key, src.key)
copy!(dst.prefix, src.prefix)
dst
end


Base.length(kp::KeyedKeepPrefixSearch) = length(kp.index)
time_type(kp::KeyedKeepPrefixSearch{T,P}) where {T,P} = time_type(P)
Expand Down Expand Up @@ -93,6 +100,13 @@ function Base.empty!(kp::KeyedRemovalPrefixSearch)
empty!(kp.prefix)
end

function Base.copy!(dst::KeyedRemovalPrefixSearch{T,P}, src::KeyedRemovalPrefixSearch{T,P}) where {T,P}
copy!(dst.index, src.index)
copy!(dst.key, src.key)
copy!(dst.free, src.free)
copy!(dst.prefix, src.prefix)
dst
end

Base.length(kp::KeyedRemovalPrefixSearch) = length(kp.index)

Expand Down
8 changes: 7 additions & 1 deletion src/sample/combinednr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ sampling_space(::LinearGamma) = LinearSampling
If you want to test a distribution, look at `tests/nrmetric.jl` to see how
distributions are timed.
"""
struct CombinedNextReaction{K,T} <: SSA{K,T}
mutable struct CombinedNextReaction{K,T} <: SSA{K,T}
firing_queue::MutableBinaryMinHeap{OrderedSample{K,T}}
transition_entry::Dict{K,NRTransition{T}}
end
Expand All @@ -153,6 +153,12 @@ function reset!(nr::CombinedNextReaction)
nothing
end

function Base.copy!(dst::CombinedNextReaction{K,T}, src::CombinedNextReaction{K,T}) where {K,T}
dst.firing_queue = deepcopy(src.firing_queue)
copy!(dst.transition_entry, src.transition_entry)
end


@doc raw"""
For the first reaction sampler, you can call next() multiple times and get
different, valid, answers. That isn't the case here. When you call next()
Expand Down
2 changes: 2 additions & 0 deletions src/sample/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ end

reset!(dc::DirectCall) = (empty!(dc.prefix_tree); nothing)

Base.copy!(dst::DirectCall{K,T,P}, src::DirectCall{K,T,P}) where {K,T,P} = copy!(dst.prefix_tree, src.prefix_tree)


"""
enable!(dc::DirectCall, clock::T, distribution::Exponential, when, rng)
Expand Down
1 change: 1 addition & 0 deletions src/sample/firstreaction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ end


reset!(fr::FirstReaction) = reset!(fr.core_matrix)
Base.copy!(dst::FirstReaction{K,T}, src::FirstReaction{K,T}) where {K,T} = (copy!(dst.core_matrix, src.core_matrix); dst)


function enable!(fr::FirstReaction{K,T}, clock::K, distribution::UnivariateDistribution,
Expand Down
8 changes: 7 additions & 1 deletion src/sample/firsttofire.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fire and saves that time in a sorted heap of future times. Then it works
through the heap, one by one. When a clock is disabled, its future firing time
is removed from the list. There is no memory of previous firing times.
"""
struct FirstToFire{K,T} <: SSA{K,T}
mutable struct FirstToFire{K,T} <: SSA{K,T}
firing_queue::MutableBinaryMinHeap{OrderedSample{K,T}}
# This maps from transition to entry in the firing queue.
transition_entry::Dict{K,Int}
Expand All @@ -31,6 +31,12 @@ function reset!(propagator::FirstToFire{K,T}) where {K,T}
empty!(propagator.transition_entry)
end

function Base.copy!(dst::FirstToFire{K,T}, src::FirstToFire{K,T}) where {K,T}
dst.firing_queue = deepcopy(src.firing_queue)
copy!(dst.transition_entry, src.transition_entry)
dst
end


# Finds the next one without removing it from the queue.
function next(propagator::FirstToFire{K,T}, when::T, rng::AbstractRNG) where {K,T}
Expand Down
11 changes: 11 additions & 0 deletions src/sample/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ for another sample run.
function reset!(sampler::SSA{K,T}) where {K,T} end


"""
copy!(destination_sampler, source_sampler)
This copies the state of the source sampler to the destination sampler, replacing
the current state of the destination sampler. This is useful for splitting
techniques where you make copies of a simulation and restart it with different
random number generators.
"""
function Base.copy!(sampler::SSA{K,T}) where {K,T} end


"""
disable!(sampler, clock, when)
Expand Down
13 changes: 13 additions & 0 deletions src/sample/multiple_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ function reset!(md::MultipleDirect)
end


function Base.copy!(
dst::MultipleDirect{SamplerKey,K,Time,Chooser},
src::MultipleDirect{SamplerKey,K,Time,Chooser}
) where {SamplerKey,K,Time,Chooser}
copy!(dst.scan, src.scan)
copy!(dst.totals, src.totals)
dst.chooser = deepcopy(src.chooser)
copy!(dst.chosen, src.chosen)
copy!(dst.scanmap, src.scanmap)
dst
end


function Base.setindex!(
md::MultipleDirect{SamplerKey,K,Time,Chooser}, keyed_prefix_search, key
) where {SamplerKey,K,Time,Chooser}
Expand Down
16 changes: 16 additions & 0 deletions src/sample/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ function SingleSampler(propagator::SSA{Key,Time}) where {Key,Time}
SingleSampler{SSA{Key,Time},Time}(propagator, zero(Time))
end

function Base.copy!(dst::SingleSampler{Algorithm,Time}, src::SingleSampler{Algorithm,Time}) where {Algorithm,Time}
copy!(dst.propagator, src.propagator)
dst.when = src.when
dst
end

function sample!(sampler::SingleSampler, rng::AbstractRNG)
when, transition = next(sampler.propagator, sampler.when, rng)
Expand Down Expand Up @@ -143,6 +148,17 @@ function reset!(sampler::MultiSampler)
end


function Base.copy!(
dst::MultiSampler{SamplerKey,Key,Time,Chooser},
src::MultiSampler{SamplerKey,Key,Time,Chooser}
) where {SamplerKey,Key,Time,Chooser}

copy!(dst.propagator, src.propagator)
dst.when = src.when
dst
end


function Base.setindex!(
sampler::MultiSampler{SamplerKey,Key,Time}, algorithm::SSA{Key,Time}, sampler_key::SamplerKey
) where {SamplerKey,Key,Time}
Expand Down
8 changes: 8 additions & 0 deletions src/sample/track.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ end

reset!(ts::TrackWatcher) = (empty!(ts.enabled); nothing)

function Base.copy!(dst::TrackWatcher{K,T}, src::TrackWatcher{K,T}) where {K,T}
copy!(dst.enabled, src.enabled)
end

function Base.iterate(ts::TrackWatcher)
return iterate(values(ts.enabled))
end
Expand Down Expand Up @@ -99,6 +103,10 @@ end

reset!(ts::DebugWatcher) = (empty!(ts.enabled); empty!(ts.disabled); nothing)

function Base.copy!(dst::DebugWatcher{K,T}, src::DebugWatcher{K,T}) where {K,T}
copy!(dst.enabled, src.enabled)
copy!(dst.disabled, src.disabled)
end

function enable!(ts::DebugWatcher{K,T}, clock::K, dist::UnivariateDistribution, te, when, rng) where {K,T}
push!(ts.enabled, EnablingEntry(clock, dist, te, when))
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ end
include("test_keyedprefixsearch.jl")
end

@testset "test_track.jl" begin
include("test_track.jl")
end

@testset "test_combinednr.jl" begin
include("test_combinednr.jl")
Expand Down
30 changes: 29 additions & 1 deletion test/test_combinednr.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using SafeTestsets


@safetestset CombinedNextReactionSmoke = "combinednext reaction does basic things" begin
@safetestset CombinedNextReactionSmoke = "CombinedNextReaction reaction does basic things" begin
using Distributions
using Random
using CompetingClocks: CombinedNextReaction, next, enable!, disable!, reset!
Expand Down Expand Up @@ -49,3 +49,31 @@ end
@test sampler[2] == 12.3

end


@safetestset CombinedNextReaction_copy = "CombinedNextReaction copy" begin
using CompetingClocks
using Distributions
using Random: Xoshiro

src = CombinedNextReaction{Int64,Float64}()
dst = clone(src)
rng = Xoshiro(123)

enable!(src, 37, Exponential(), 0.0, 0.0, rng)
enable!(src, 38, Exponential(), 0.0, 0.0, rng)
enable!(dst, 29, Exponential(), 0.0, 0.0, rng)
@test length(src) == 2
@test length(dst) == 1
copy!(dst, src)
@test length(src) == 2
@test length(dst) == 2
# Changing src doesn't change dst.
enable!(src, 48, Exponential(), 0.0, 0.0, rng)
@test length(src) == 3
@test length(dst) == 2
# Changing dst doesn't change src.
enable!(dst, 49, Exponential(), 0.0, 0.0, rng)
@test length(src) == 3
@test length(dst) == 3
end
24 changes: 24 additions & 0 deletions test/test_direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,27 @@ end
md = DirectCall{Int,Float64}()
test_exponential_binomial(md, rng)
end


@safetestset direct_call_copy = "DirectCall copy" begin
using CompetingClocks: DirectCall, enable!, next
using Random: MersenneTwister
using Distributions: Exponential

src = DirectCall{Int,Float64}()
dst = DirectCall{Int,Float64}()
rng = MersenneTwister(90422342)
enable!(src, 1, Exponential(), 0.0, 0.0, rng)
enable!(src, 2, Exponential(), 0.0, 0.0, rng)
enable!(dst, 3, Exponential(), 0.0, 0.0, rng)
@test length(src) == 2
@test length(dst) == 1
copy!(dst, src)
@test length(dst) == 2
enable!(src, 5, Exponential(), 0.0, 0.0, rng)
@test length(src) == 3
@test length(dst) == 2
enable!(dst, 6, Exponential(), 0.0, 0.0, rng)
@test length(src) == 3
@test length(dst) == 3
end
25 changes: 25 additions & 0 deletions test/test_firstreaction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,28 @@ end
ks2_test = ExactOneSampleKSTest(shifted, dist)
@test pvalue(ks2_test; tail = :both) > 0.04
end



@safetestset first_reaction_copy = "FirstReaction copy" begin
using CompetingClocks: FirstReaction, enable!, next
using Random: MersenneTwister
using Distributions: Exponential

src = FirstReaction{Int,Float64}()
dst = FirstReaction{Int,Float64}()
rng = MersenneTwister(90422342)
enable!(src, 1, Exponential(), 0.0, 0.0, rng)
enable!(src, 2, Exponential(), 0.0, 0.0, rng)
enable!(dst, 3, Exponential(), 0.0, 0.0, rng)
@test length(src) == 2
@test length(dst) == 1
copy!(dst, src)
@test length(dst) == 2
enable!(src, 5, Exponential(), 0.0, 0.0, rng)
@test length(src) == 3
@test length(dst) == 2
enable!(dst, 6, Exponential(), 0.0, 0.0, rng)
@test length(src) == 3
@test length(dst) == 3
end
24 changes: 24 additions & 0 deletions test/test_firsttofire.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,27 @@ end
@test propagator[2] == 12.3

end


@safetestset FirstToFire_copy = "FirstToFire copy" begin
using CompetingClocks: FirstToFire, enable!, next
using Random: MersenneTwister
using Distributions: Exponential

src = FirstToFire{Int,Float64}()
dst = FirstToFire{Int,Float64}()
rng = MersenneTwister(90422342)
enable!(src, 1, Exponential(), 0.0, 0.0, rng)
enable!(src, 2, Exponential(), 0.0, 0.0, rng)
enable!(dst, 3, Exponential(), 0.0, 0.0, rng)
@test length(src) == 2
@test length(dst) == 1
copy!(dst, src)
@test length(dst) == 2
enable!(src, 5, Exponential(), 0.0, 0.0, rng)
@test length(src) == 3
@test length(dst) == 2
enable!(dst, 6, Exponential(), 0.0, 0.0, rng)
@test length(src) == 3
@test length(dst) == 3
end
45 changes: 45 additions & 0 deletions test/test_track.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using SafeTestsets


@safetestset track_trackwatcher_smoke = "TrackWatcher smoke" begin
using Distributions
using CompetingClocks
using Random
rng = Xoshiro(3242234)
tw = TrackWatcher{Int,Float64}()
enable!(tw, 3, Exponential(), 0.0, 0.0, rng)
@test length(tw.enabled) == 1 && 3 keys(tw.enabled)
enable!(tw, 4, Exponential(), 0.0, 3.0, rng)
@test length(tw.enabled) == 2 && 4 keys(tw.enabled)
enable!(tw, 7, Exponential(), 5.0, 5.0, rng)
@test length(tw.enabled) == 3 && 7 keys(tw.enabled)
disable!(tw, 4, 9.0)
@test length(tw.enabled) == 2 && 4 keys(tw.enabled)

dst = TrackWatcher{Int,Float64}()
enable!(dst, 11, Exponential(), 5.0, 5.0, rng)
copy!(dst, tw)
@test length(tw.enabled) == 2 && 11 keys(tw.enabled)
end


@safetestset track_debugwatcher_smoke = "DebugWatcher smoke" begin
using Distributions
using CompetingClocks
using Random
rng = Xoshiro(3242234)
dw = DebugWatcher{Int,Float64}()
enable!(dw, 3, Exponential(), 0.0, 0.0, rng)
@test dw.enabled[1].clock == 3
enable!(dw, 4, Exponential(), 0.0, 3.0, rng)
@test dw.enabled[2].clock == 4
enable!(dw, 7, Exponential(), 5.0, 5.0, rng)
@test dw.enabled[3].clock == 7
disable!(dw, 4, 9.0)
@test dw.disabled[1].clock == 4

dst = DebugWatcher{Int,Float64}()
enable!(dst, 11, Exponential(), 5.0, 5.0, rng)
copy!(dst, dw)
@test length(dw.enabled) == 3 && length(dw.disabled) == 1
end

0 comments on commit f070409

Please sign in to comment.