From 0c06146d535321637008387fdabd4c085d8a0dfe Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Mon, 8 Jul 2024 22:45:18 -0400 Subject: [PATCH 1/2] Adding support for custom initial beliefs (from solver) Signed-off-by: John Muchovej <5000729+jmuchovej@users.noreply.github.com> --- src/solver.jl | 7 ++++--- src/sparse_tabular.jl | 7 +++---- src/tree.jl | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/solver.jl b/src/solver.jl index 71fbd42..f4e39bb 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -11,8 +11,8 @@ Base.@kwdef struct SARSOPSolver{LOW,UP} <: Solver prunethresh::Float64= 0.10 end -function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP) - tree = SARSOPTree(solver, pomdp) +function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP; b0=initialstate(pomdp)) + tree = SARSOPTree(solver, pomdp, b0) if solver.verbose initialize_verbose_output() @@ -48,7 +48,8 @@ function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP) ) end -POMDPs.solve(solver::SARSOPSolver, pomdp::POMDP) = first(solve_info(solver, pomdp)) +POMDPs.solve(solver::SARSOPSolver, pomdp::POMDP; b0=initialstate(pomdp)) = + first(solve_info(solver, pomdp; b0)) function initialize_verbose_output() dashed_line() diff --git a/src/sparse_tabular.jl b/src/sparse_tabular.jl index 125efec..f455c96 100644 --- a/src/sparse_tabular.jl +++ b/src/sparse_tabular.jl @@ -7,7 +7,7 @@ struct ModifiedSparseTabular <: POMDP{Int,Int,Int} discount::Float64 end -function ModifiedSparseTabular(pomdp::POMDP) +function ModifiedSparseTabular(pomdp::POMDP, b0) S = ordered_states(pomdp) A = ordered_actions(pomdp) O = ordered_observations(pomdp) @@ -16,7 +16,7 @@ function ModifiedSparseTabular(pomdp::POMDP) T = transition_matrix_a_sp_s(pomdp) R = _tabular_rewards(pomdp, S, A, terminal) O = POMDPTools.ModelTools.observation_matrix_a_sp_o(pomdp) - b0 = _vectorized_initialstate(pomdp, S) + b0 = _vectorized_initialstate(b0, S) return ModifiedSparseTabular(T,R,O,terminal,b0,discount(pomdp)) end @@ -76,8 +76,7 @@ function _vectorized_terminal(pomdp, S) return term end -function _vectorized_initialstate(pomdp, S) - b0 = initialstate(pomdp) +function _vectorized_initialstate(b0, S) b0_vec = Vector{Float64}(undef, length(S)) @inbounds for i ∈ eachindex(S, b0_vec) b0_vec[i] = pdf(b0, S[i]) diff --git a/src/tree.jl b/src/tree.jl index 5443eb2..671deaf 100644 --- a/src/tree.jl +++ b/src/tree.jl @@ -35,8 +35,8 @@ struct SARSOPTree end -function SARSOPTree(solver, pomdp::POMDP) - sparse_pomdp = ModifiedSparseTabular(pomdp) +function SARSOPTree(solver, pomdp::POMDP, b0) + sparse_pomdp = ModifiedSparseTabular(pomdp, b0) cache = TreeCache(sparse_pomdp) upper_policy = solve(solver.init_upper, sparse_pomdp) From c73bdc387e101a6c8ea215d98a41e4b4449ade1a Mon Sep 17 00:00:00 2001 From: John Muchovej <5000729+jmuchovej@users.noreply.github.com> Date: Mon, 15 Jul 2024 13:56:56 -0400 Subject: [PATCH 2/2] Update custom beliefs to `SARSOPSolver.root_belief` approach --- src/solver.jl | 77 +++++++++++++++++++++++++------------------ src/sparse_tabular.jl | 27 ++++++++------- src/tree.jl | 46 +++++++++++++------------- 3 files changed, 82 insertions(+), 68 deletions(-) diff --git a/src/solver.jl b/src/solver.jl index f4e39bb..a3e8d8a 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -1,26 +1,29 @@ -Base.@kwdef struct SARSOPSolver{LOW,UP} <: Solver - epsilon::Float64 = 0.5 - precision::Float64 = 1e-3 - kappa::Float64 = 0.5 - delta::Float64 = 1e-1 - max_time::Float64 = 1.0 - max_steps::Int = typemax(Int) - verbose::Bool = true - init_lower::LOW = BlindLowerBound(bel_res = 1e-2) - init_upper::UP = FastInformedBound(bel_res=1e-2) - prunethresh::Float64= 0.10 +_root_belief(pomdp::POMDP) = initialstate(pomdp) + +Base.@kwdef struct SARSOPSolver{LOW, UP, ROOT} <: Solver + epsilon::Float64 = 0.5 + precision::Float64 = 1e-3 + kappa::Float64 = 0.5 + delta::Float64 = 1e-1 + max_time::Float64 = 1.0 + max_steps::Int = typemax(Int) + verbose::Bool = true + init_lower::LOW = BlindLowerBound(; bel_res=1e-2) + init_upper::UP = FastInformedBound(; bel_res=1e-2) + prunethresh::Float64 = 0.10 + root_belief::ROOT = _root_belief end -function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP; b0=initialstate(pomdp)) - tree = SARSOPTree(solver, pomdp, b0) - +function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP) + tree = SARSOPTree(solver, pomdp) + if solver.verbose initialize_verbose_output() end - + t0 = time() iter = 0 - while time()-t0 < solver.max_time && root_diff(tree) > solver.precision + while time() - t0 < solver.max_time && root_diff(tree) > solver.precision sample!(solver, tree) backup!(tree) prune!(solver, tree) @@ -30,38 +33,48 @@ function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP; b0=initialsta iter += 1 end - if solver.verbose + if solver.verbose dashed_line() log_verbose_info(t0, iter, tree) dashed_line() end - + pol = AlphaVectorPolicy( pomdp, getproperty.(tree.Γ, :alpha), - ordered_actions(pomdp)[getproperty.(tree.Γ, :action)] - ) - return pol, (; - time = time()-t0, - tree, - iter + ordered_actions(pomdp)[getproperty.(tree.Γ, :action)], ) + return pol, (; time=time() - t0, tree, iter) end -POMDPs.solve(solver::SARSOPSolver, pomdp::POMDP; b0=initialstate(pomdp)) = - first(solve_info(solver, pomdp; b0)) +POMDPs.solve(solver::SARSOPSolver, pomdp::POMDP) = first(solve_info(solver, pomdp)) function initialize_verbose_output() dashed_line() - @printf(" %-10s %-10s %-12s %-12s %-15s %-10s %-10s\n", - "Time", "Iter", "LB", "UB", "Precision", "# Alphas", "# Beliefs") - dashed_line() + @printf( + " %-10s %-10s %-12s %-12s %-15s %-10s %-10s\n", + "Time", + "Iter", + "LB", + "UB", + "Precision", + "# Alphas", + "# Beliefs" + ) + return dashed_line() end function log_verbose_info(t0::Float64, iter::Int, tree::SARSOPTree) - @printf(" %-10.2f %-10d %-12.7f %-12.7f %-15.10f %-10d %-10d\n", - time()-t0, iter, tree.V_lower[1], tree.V_upper[1], root_diff(tree), - length(tree.Γ), length(tree.b_pruned) - sum(tree.b_pruned)) + @printf( + " %-10.2f %-10d %-12.7f %-12.7f %-15.10f %-10d %-10d\n", + time() - t0, + iter, + tree.V_lower[1], + tree.V_upper[1], + root_diff(tree), + length(tree.Γ), + length(tree.b_pruned) - sum(tree.b_pruned) + ) end function dashed_line(n=86) diff --git a/src/sparse_tabular.jl b/src/sparse_tabular.jl index f455c96..fcc554f 100644 --- a/src/sparse_tabular.jl +++ b/src/sparse_tabular.jl @@ -1,4 +1,4 @@ -struct ModifiedSparseTabular <: POMDP{Int,Int,Int} +struct ModifiedSparseTabular <: POMDP{Int, Int, Int} T::Vector{SparseMatrixCSC{Float64, Int64}} # T[a][sp, s] R::Array{Float64, 2} # R[s,a] O::Vector{SparseMatrixCSC{Float64, Int64}} # O[a][sp, o] @@ -17,7 +17,7 @@ function ModifiedSparseTabular(pomdp::POMDP, b0) R = _tabular_rewards(pomdp, S, A, terminal) O = POMDPTools.ModelTools.observation_matrix_a_sp_o(pomdp) b0 = _vectorized_initialstate(b0, S) - return ModifiedSparseTabular(T,R,O,terminal,b0,discount(pomdp)) + return ModifiedSparseTabular(T, R, O, terminal, b0, discount(pomdp)) end function transition_matrix_a_sp_s(mdp::Union{MDP, POMDP}) @@ -26,20 +26,20 @@ function transition_matrix_a_sp_s(mdp::Union{MDP, POMDP}) ns = length(S) na = length(A) - - transmat_row_A = [Int64[] for _ in 1:na] - transmat_col_A = [Int64[] for _ in 1:na] - transmat_data_A = [Float64[] for _ in 1:na] - for (si,s) in enumerate(S) - for (ai,a) in enumerate(A) + transmat_row_A = [Int64[] for _ ∈ 1:na] + transmat_col_A = [Int64[] for _ ∈ 1:na] + transmat_data_A = [Float64[] for _ ∈ 1:na] + + for (si, s) ∈ enumerate(S) + for (ai, a) ∈ enumerate(A) if isterminal(mdp, s) # if terminal, there is a probability of 1 of staying in that state push!(transmat_row_A[ai], si) push!(transmat_col_A[ai], si) push!(transmat_data_A[ai], 1.0) else td = transition(mdp, s, a) - for (sp, p) in weighted_iterator(td) + for (sp, p) ∈ weighted_iterator(td) if p > 0.0 spi = stateindex(mdp, sp) push!(transmat_row_A[ai], spi) @@ -50,7 +50,10 @@ function transition_matrix_a_sp_s(mdp::Union{MDP, POMDP}) end end end - transmats_A_SP_S = [sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], ns, ns) for a in 1:na] + transmats_A_SP_S = [ + sparse(transmat_row_A[a], transmat_col_A[a], transmat_data_A[a], ns, ns) for + a ∈ 1:na + ] return transmats_A_SP_S end @@ -65,12 +68,12 @@ function _tabular_rewards(pomdp, S, A, terminal) R[s_idx, a_idx] = reward(pomdp, s, a) end end - R + return R end function _vectorized_terminal(pomdp, S) term = BitVector(undef, length(S)) - @inbounds for i ∈ eachindex(term,S) + @inbounds for i ∈ eachindex(term, S) term[i] = isterminal(pomdp, S[i]) end return term diff --git a/src/tree.jl b/src/tree.jl index 671deaf..381695d 100644 --- a/src/tree.jl +++ b/src/tree.jl @@ -7,7 +7,7 @@ end struct SARSOPTree pomdp::ModifiedSparseTabular - b::Vector{SparseVector{Float64,Int}} # b_idx => belief vector + b::Vector{SparseVector{Float64, Int}} # b_idx => belief vector b_children::Vector{UnitRange{Int}} # [b_idx][a_idx] => ba_idx Vs_upper::Vector{Float64} V_upper::Vector{Float64} @@ -34,9 +34,8 @@ struct SARSOPTree Γ::Vector{AlphaVec{Int}} end - -function SARSOPTree(solver, pomdp::POMDP, b0) - sparse_pomdp = ModifiedSparseTabular(pomdp, b0) +function SARSOPTree(solver, pomdp::POMDP) + sparse_pomdp = ModifiedSparseTabular(pomdp, solver.root_belief(pomdp)) cache = TreeCache(sparse_pomdp) upper_policy = solve(solver.init_upper, sparse_pomdp) @@ -44,7 +43,6 @@ function SARSOPTree(solver, pomdp::POMDP, b0) tree = SARSOPTree( sparse_pomdp, - Vector{Float64}[], Vector{Int}[], corner_values, #upper_policy.util, @@ -63,8 +61,8 @@ function SARSOPTree(solver, pomdp::POMDP, b0) Vector{Int}(), BitVector(), cache, - PruneData(0,0,solver.prunethresh), - AlphaVec{Int}[] + PruneData(0, 0, solver.prunethresh), + AlphaVec{Int}[], ) return insert_root!(solver, tree, _initialize_belief(pomdp, initialstate(pomdp))) end @@ -82,7 +80,7 @@ POMDPs.discount(tree::SARSOPTree) = discount(tree.pomdp) function _initialize_belief(pomdp::POMDP, dist::Any=initialstate(pomdp)) ns = length(states(pomdp)) b = zeros(ns) - for s in support(dist) + for s ∈ support(dist) sidx = stateindex(pomdp, s) b[sidx] = pdf(dist, s) end @@ -93,7 +91,7 @@ function insert_root!(solver, tree::SARSOPTree, b) pomdp = tree.pomdp Γ_lower = solve(solver.init_lower, pomdp) - for (α,a) ∈ alphapairs(Γ_lower) + for (α, a) ∈ alphapairs(Γ_lower) new_val = dot(α, b) push!(tree.Γ, AlphaVec(α, a)) end @@ -118,7 +116,7 @@ function update(tree::SARSOPTree, b_idx::Int, a, o) ba_idx = tree.b_children[b_idx][a] bp_idx = tree.ba_children[ba_idx][o] V̲, V̄ = if tree.is_terminal[bp_idx] - 0.,0. + 0.0, 0.0 else lower_value(tree, tree.b[bp_idx]), upper_value(tree, tree.b[bp_idx]) end @@ -139,7 +137,7 @@ function add_belief!(tree::SARSOPTree, b, ba_idx::Int, o) push!(tree.is_terminal, terminal) V̲, V̄ = if terminal - 0., 0. + 0.0, 0.0 else lower_value(tree, b), upper_value(tree, b) end @@ -175,19 +173,19 @@ function fill_populated!(tree::SARSOPTree, b_idx::Int) b = tree.b[b_idx] Qa_upper = tree.Qa_upper[b_idx] Qa_lower = tree.Qa_lower[b_idx] - for a in ACT + for a ∈ ACT ba_idx = tree.b_children[b_idx][a] tree.ba_pruned[ba_idx] && continue Rba = belief_reward(tree, b, a) Q̄ = Rba Q̲ = Rba - for o in OBS + for o ∈ OBS bp_idx, V̲, V̄ = update(tree, b_idx, a, o) b′ = tree.b[bp_idx] po = tree.poba[ba_idx][o] - Q̄ += γ*po*V̄ - Q̲ += γ*po*V̲ + Q̄ += γ * po * V̄ + Q̲ += γ * po * V̲ end Qa_upper[a] = Q̄ @@ -195,7 +193,7 @@ function fill_populated!(tree::SARSOPTree, b_idx::Int) end tree.V_lower[b_idx] = lower_value(tree, tree.b[b_idx]) - tree.V_upper[b_idx] = maximum(tree.Qa_upper[b_idx]) + return tree.V_upper[b_idx] = maximum(tree.Qa_upper[b_idx]) end function fill_unpopulated!(tree::SARSOPTree, b_idx::Int) @@ -211,15 +209,15 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int) Qa_upper = Vector{Float64}(undef, N_ACT) Qa_lower = Vector{Float64}(undef, N_ACT) - b_children = (n_ba+1):(n_ba+N_ACT) + b_children = (n_ba + 1):(n_ba + N_ACT) - for a in A + for a ∈ A ba_idx = add_action!(tree, b_idx, a) - ba_children = (n_b+1):(n_b+N_OBS) + ba_children = (n_b + 1):(n_b + N_OBS) tree.ba_children[ba_idx] = ba_children n_b += N_OBS - pred = dropzeros!(mul!(tree.cache.pred, pomdp.T[a],b)) + pred = dropzeros!(mul!(tree.cache.pred, pomdp.T[a], b)) poba = zeros(Float64, N_OBS) Rba = belief_reward(tree, b, a) @@ -230,15 +228,15 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int) # belief update bp = corrector(pomdp, pred, a, o) po = sum(bp) - if po > 0. + if po > 0.0 bp.nzval ./= po poba[o] = po end bp_idx, V̲, V̄ = add_belief!(tree, bp, ba_idx, o) - Q̄ += γ*po*V̄ - Q̲ += γ*po*V̲ + Q̄ += γ * po * V̄ + Q̲ += γ * po * V̲ end Qa_upper[a] = Q̄ Qa_lower[a] = Q̲ @@ -247,5 +245,5 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int) tree.Qa_upper[b_idx] = Qa_upper tree.Qa_lower[b_idx] = Qa_lower tree.V_lower[b_idx] = lower_value(tree, tree.b[b_idx]) - tree.V_upper[b_idx] = maximum(tree.Qa_upper[b_idx]) + return tree.V_upper[b_idx] = maximum(tree.Qa_upper[b_idx]) end