diff --git a/src/prune.jl b/src/prune.jl index 01c1b8e..667fa4b 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -56,16 +56,17 @@ end d = dot_b / sqrt(dot_diff) return d end + function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) Γ = tree.Γ - B_valid = tree.b[map(!,tree.b_pruned)] - + B_valid = tree.b[map(!, tree.b_pruned)] + n_Γ = length(Γ) n_B = length(B_valid) - + dominant_indices_bools = falses(n_Γ) dominant_vector_indices = Vector{Int}(undef, n_B) - + # First, identify dominant alpha vectors for b_idx in 1:n_B max_value = -Inf @@ -84,7 +85,7 @@ function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) non_dominant_indices = findall(!, dominant_indices_bools) n_non_dom = length(non_dominant_indices) keep_non_dom = falses(n_non_dom) - + for b_idx in 1:n_B dom_vec_idx = dominant_vector_indices[b_idx] for j in 1:n_non_dom @@ -98,7 +99,7 @@ function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) end end end - + non_dominant_indices = non_dominant_indices[.!keep_non_dom] deleteat!(Γ, non_dominant_indices) tree.prune_data.last_Γ_size = length(Γ) @@ -116,7 +117,7 @@ end function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10) Γ = tree.Γ Γ_new_idxs = [] - + for (α_try_idx, α_try) in enumerate(Γ) marked_for_deletion = falses(length(Γ_new_idxs)) dominated = false @@ -134,7 +135,7 @@ function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10) push!(Γ_new_idxs, α_try_idx) end end - + Γ_idxs_to_delete = setdiff(1:length(Γ), Γ_new_idxs) deleteat!(Γ, Γ_idxs_to_delete) end