From ba8c452d377a5815ee04d908408298965575e069 Mon Sep 17 00:00:00 2001 From: qhho Date: Sat, 27 Jul 2024 13:30:32 -0700 Subject: [PATCH 1/3] witness pruning with sets in alpha vectors --- src/alpha.jl | 1 + src/backup.jl | 2 +- src/prune.jl | 57 +++++++++++++++++++++++++-------------------------- src/tree.jl | 2 +- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/alpha.jl b/src/alpha.jl index 5364ac7..ce3bf0c 100644 --- a/src/alpha.jl +++ b/src/alpha.jl @@ -1,6 +1,7 @@ struct AlphaVec{A} <: AbstractVector{Float64} alpha::Vector{Float64} action::A + witnesses::Set{Int} end @inline Base.length(v::AlphaVec) = length(v.alpha) diff --git a/src/backup.jl b/src/backup.jl index ada1f62..b48b5d9 100644 --- a/src/backup.jl +++ b/src/backup.jl @@ -77,7 +77,7 @@ function backup!(tree, b_idx) end end - α = AlphaVec(best_α, best_action) + α = AlphaVec(best_α, best_action, Set(b_idx)) push!(Γ, α) tree.V_lower[b_idx] = V end diff --git a/src/prune.jl b/src/prune.jl index 7f19834..14023f9 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -48,16 +48,27 @@ function prune!(tree::SARSOPTree) end end -function belief_space_domination(α1, α2, B, δ) - a1_dominant = true - a2_dominant = true - for b ∈ B - !a1_dominant && !a2_dominant && return (false, false) - δV = intersection_distance(α1, α2, b) - δV ≤ δ && (a1_dominant = false) - δV ≥ -δ && (a2_dominant = false) +function recertify_witnesses!(tree, α1, α2, δ) + + if α1 == α2 + union!(α2.witnesses, α1.witnesses) + empty!(α1.witnesses) + return + end + + for b_idx in α1.witnesses + if tree.b_pruned[b_idx] + delete!(α1.witnesses, b_idx) + continue + end + + δV = intersection_distance(α2, α1, tree.b[b_idx]) + + if δV > δ + delete!(α1.witnesses, b_idx) + push!(α2.witnesses, b_idx) + end end - return a1_dominant, a2_dominant end @inline function intersection_distance(α1, α2, b) @@ -75,33 +86,21 @@ end function prune_alpha!(tree::SARSOPTree, δ) Γ = tree.Γ - B_valid = tree.b[map(!,tree.b_pruned)] pruned = falses(length(Γ)) - # checking if α_i dominates α_j - for (i,α_i) ∈ enumerate(Γ) + for (i, α_i) ∈ enumerate(Γ) pruned[i] && continue - for (j,α_j) ∈ enumerate(Γ) - (j ≤ i || pruned[j]) && continue - a1_dominant,a2_dominant = belief_space_domination(α_i, α_j, B_valid, δ) - #= - NOTE: α1 and α2 shouldn't technically be able to mutually dominate - i.e. a1_dominant and a2_dominant should never both be true. - But this does happen when α1 == α2 because intersection_distance returns NaN. - Current impl prunes α2 without doing an equality check, removing - the duplicate α. Could do equality check to short-circuit - belief_space_domination which would speed things up if we have - a lot of duplicates, but the equality check can slow things down - if α's are sufficiently diverse. - =# - if a1_dominant - pruned[j] = true - elseif a2_dominant + for (j, α_j) ∈ enumerate(Γ) + pruned[j] || j == i && continue + recertify_witnesses!(tree, α_i, α_j, δ) + if isempty(α_i.witnesses) pruned[i] = true break + elseif isempty(α_j.witnesses) + pruned[j] = true end end end deleteat!(Γ, pruned) tree.prune_data.last_Γ_size = length(Γ) -end +end \ No newline at end of file diff --git a/src/tree.jl b/src/tree.jl index 5443eb2..cef6fc7 100644 --- a/src/tree.jl +++ b/src/tree.jl @@ -95,7 +95,7 @@ function insert_root!(solver, tree::SARSOPTree, b) Γ_lower = solve(solver.init_lower, pomdp) for (α,a) ∈ alphapairs(Γ_lower) new_val = dot(α, b) - push!(tree.Γ, AlphaVec(α, a)) + push!(tree.Γ, AlphaVec(α, a, Set(1))) end tree.prune_data.last_Γ_size = length(tree.Γ) From 3e10a8c51424af1744e902ea1d63a1ad306dcaed Mon Sep 17 00:00:00 2001 From: qhho Date: Sun, 28 Jul 2024 20:31:21 -0700 Subject: [PATCH 2/3] duplicate check during backup --- src/backup.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/backup.jl b/src/backup.jl index b48b5d9..b54d8b7 100644 --- a/src/backup.jl +++ b/src/backup.jl @@ -76,9 +76,13 @@ function backup!(tree, b_idx) best_action = a end end - - α = AlphaVec(best_α, best_action, Set(b_idx)) - push!(Γ, α) + α_idx = findfirst(x->x == best_α, Γ) + if α_idx === nothing + α = AlphaVec(best_α, best_action, Set(b_idx)) + push!(Γ, α) + else + union!(Γ[α_idx].witnesses, b_idx) + end tree.V_lower[b_idx] = V end From f6a2fcc77d695be96ed9d6d225b8c55ead2f600e Mon Sep 17 00:00:00 2001 From: qhho Date: Sun, 28 Jul 2024 21:59:40 -0700 Subject: [PATCH 3/3] fix bracket in pruning check short-circuiting --- src/prune.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prune.jl b/src/prune.jl index 14023f9..2e1691c 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -91,7 +91,7 @@ function prune_alpha!(tree::SARSOPTree, δ) for (i, α_i) ∈ enumerate(Γ) pruned[i] && continue for (j, α_j) ∈ enumerate(Γ) - pruned[j] || j == i && continue + (pruned[j] || j == i) && continue recertify_witnesses!(tree, α_i, α_j, δ) if isempty(α_i.witnesses) pruned[i] = true