Skip to content

Commit

Permalink
matmul alpha backup
Browse files Browse the repository at this point in the history
  • Loading branch information
WhiffleFish committed Feb 25, 2023
1 parent 29ed7e9 commit 6baf19a
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 48 deletions.
119 changes: 78 additions & 41 deletions src/backup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,63 +11,46 @@ function max_alpha_val(Γ, b)
return max_α.alpha
end

"""
    backup_a!(α, pomdp, a, β, Γv)

Overwrite `α` in place with `R[:, a] + γ * (β * Γv)` and return it.

# Arguments
- `α`: output buffer; its previous contents are discarded.
- `pomdp`: model providing the reward matrix `pomdp.R` and `discount(pomdp)`.
- `a`: action used to select the reward column `pomdp.R[:, a]`.
- `β`: linear map applied to `Γv` (see `alpha_backup_lmap`).
- `Γv`: flat vector of next-step alpha values, as filled by `fill_alpha!`.
"""
@inline function backup_a!(α, pomdp, a, β::AbstractArray{<:Number}, Γv)
    R = @view pomdp.R[:, a]
    γ = discount(pomdp)
    mul!(α, β, Γv)            # α ← β * Γv, written into the existing buffer
    return @. α = R + γ * α   # fused in-place update
end

"""
    fill_alpha!(tree, b_idx, a)

Fill the flat buffer `tree.cache.Γ` for belief node `b_idx` and action `a`, and
return that buffer.

For every observation `o`, the slice `n*(o-1)+1 : n*o` (with `n = n_states(pomdp)`)
receives `max_alpha_val(tree.Γ, bp)`, where `bp` is the child belief reached from
`b_idx` through action `a` and observation `o`.
"""
function fill_alpha!(tree, b_idx, a)
    (; pomdp, Γ) = tree
    n = n_states(pomdp)
    m = n_observations(pomdp)
    Γa = tree.cache.Γ
    # Internal invariant: the cache holds one length-n slice per observation.
    @assert length(Γa) == m * n

    ba_idx = tree.b_children[b_idx][a]
    # NOTE(review): `o` is used both as an index into `ba_children[ba_idx]` and in
    # the slice arithmetic, so observations are presumably the integers 1:m — confirm.
    for o in observations(pomdp)
        flat_idxs = (n * (o - 1) + 1):(n * o)
        bp_idx = tree.ba_children[ba_idx][o]
        bp = tree.b[bp_idx]
        Γa[flat_idxs] .= max_alpha_val(Γ, bp)
    end
    return Γa
end

function backup!(tree, b_idx)
Γ = tree.Γ
(;Γ,β,pomdp,cache) = tree
Γv = cache.Γ

b = tree.b[b_idx]
pomdp = tree.pomdp
γ = discount(tree)
S = states(tree)
A = actions(tree)
O = observations(tree)

Γao = tree.cache.Γ

for a A
ba_idx = tree.b_children[b_idx][a]
for o O
bp_idx = tree.ba_children[ba_idx][o]
bp = tree.b[bp_idx]
Γao[:,o,a] .= max_alpha_val(Γ, bp)
end
end

V = -Inf
α_a = tree.cache.alpha # zeros(Float64, length(S))
α_a = cache.alpha # zeros(Float64, length(S))
best_α = zeros(Float64, length(S))
best_action = first(A)

for a A
α_a = backup_a!(α_a, pomdp, tree.cache, a, Γao)
fill_alpha!(tree, b_idx, a)
α_a = backup_a!(α_a, pomdp, a, β[a], Γv)
Qba = dot(α_a, b)
tree.Qa_lower[b_idx][a] = Qba
if Qba > V
Expand All @@ -87,3 +70,57 @@ function backup!(tree)
backup!(tree, tree.sampled[i])
end
end

"""
    alpha_backup_lmap(T::Matrix, Zᵀ::Matrix)

Build the dense `n × (n*m)` matrix `β` with

    β[s, n*(o-1) + sp] = T[sp, s] * Zᵀ[o, sp]

where `n = size(T, 1)` and `m = size(Zᵀ, 1)`.

`T` is indexed `[sp, s]` and `Zᵀ` is indexed `[o, sp]`. `T` must be square and
`Zᵀ` must have as many columns as `T`.
"""
function alpha_backup_lmap(T::Matrix, Zᵀ::Matrix)
    @assert size(T, 1) == size(T, 2) == size(Zᵀ, 2)
    n = size(T, 1)  # T[sp, s]
    m = size(Zᵀ, 1) # Zᵀ[o, sp]

    β = zeros(n, n * m)
    for o in 1:m, sp in 1:n
        col = n * (o - 1) + sp   # flat (o, sp) position along the second axis of β
        # Innermost loop runs down a column of β (column-major friendly).
        for s in 1:n
            β[s, col] = T[sp, s] * Zᵀ[o, sp]
        end
    end
    return β
end

# α[a]' = R[:,a] + γ*β[a]*Γ[a]
"""
    alpha_backup_lmap(T::SparseMatrixCSC, Zᵀ::SparseMatrixCSC)

Sparse counterpart of the dense method: return the `n × (n*m)` sparse matrix `β`
with

    β[s, n*(o-1) + sp] = T[sp, s] * Zᵀ[o, sp]

for every stored pair of entries, where `n = size(T, 1)` and `m = size(Zᵀ, 1)`.

`T` is indexed `[sp, s]` and `Zᵀ` is indexed `[o, sp]`. The result is assembled
from coordinate triplets, so no dense `n × (n*m)` intermediate is allocated.
"""
function alpha_backup_lmap(T::SparseMatrixCSC, Zᵀ::SparseMatrixCSC)
    @assert size(T, 1) == size(T, 2) == size(Zᵀ, 2)
    n = size(T, 1)  # T[sp, s]
    m = size(Zᵀ, 1) # Zᵀ[o, sp]

    Tnz = nonzeros(T)
    Trv = rowvals(T)
    Znz = nonzeros(Zᵀ)
    Zrv = rowvals(Zᵀ)

    rows = Int[]
    cols = Int[]
    vals = Float64[]
    for s in 1:n
        for sp_idx in nzrange(T, s)
            sp = Trv[sp_idx]
            pT = Tnz[sp_idx]
            for o_idx in nzrange(Zᵀ, sp)
                o = Zrv[o_idx]
                v = pT * Znz[o_idx]
                # Skip zero products so the stored pattern matches what
                # `sparse(::Matrix)` would have produced.
                iszero(v) && continue
                push!(rows, s)
                push!(cols, n * (o - 1) + sp)
                push!(vals, v)
            end
        end
    end
    return sparse(rows, cols, vals, n, n * m)
end

"""
    alpha_backup_lmap(pomdp::ModifiedSparseTabular)

Return a vector `B` with one sparse backup map per action:
`B[a] = alpha_backup_lmap(pomdp.T[a], sparse(transpose(pomdp.O[a])))`.
"""
function alpha_backup_lmap(pomdp::ModifiedSparseTabular)
    A = actions(pomdp)
    B = Vector{SparseMatrixCSC{Float64, Int}}(undef, length(A))
    # NOTE(review): `a` indexes `B`, `pomdp.T` and `pomdp.O` directly, so actions
    # are presumably the integers 1:length(A) — confirm.
    for a in A
        # Materialize the transpose so the two-matrix method receives a CSC matrix.
        Oᵀ = sparse(transpose(pomdp.O[a]))
        B[a] = alpha_backup_lmap(pomdp.T[a], Oᵀ)
    end
    return B
end
10 changes: 4 additions & 6 deletions src/cache.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
# Scratch buffers reused across backups.
# - `pred`:  length-`n_states` buffer (its use is not visible in this file section).
# - `alpha`: length-`n_states` buffer.
# - `Γ`:     flat buffer of length `n_states * n_observations`.
struct TreeCache
    pred::SparseVector{Float64,Int}
    alpha::Vector{Float64}
    Γ::Vector{Float64}
end

# Allocate uninitialized buffers sized from the state and observation counts of `pomdp`.
function TreeCache(pomdp::ModifiedSparseTabular)
    Ns = n_states(pomdp)
    No = n_observations(pomdp)

    # NOTE(review): `pred` is created as a dense `undef` Vector although the field is
    # a SparseVector, so it relies on implicit conversion of uninitialized data —
    # confirm this is intended.
    pred = Vector{Float64}(undef, Ns)
    alpha = Vector{Float64}(undef, Ns)
    Γ = Vector{Float64}(undef, Ns * No)
    return TreeCache(pred, alpha, Γ)
end
4 changes: 3 additions & 1 deletion src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct SARSOPTree
prune_data::PruneData

Γ::Vector{AlphaVec{Int}}
β::Vector{SparseMatrixCSC{Float64, Int}}
end


Expand Down Expand Up @@ -66,7 +67,8 @@ function SARSOPTree(solver, pomdp::POMDP)
BitVector(),
cache,
PruneData(0,0,solver.prunethresh),
AlphaVec{Int}[]
AlphaVec{Int}[],
alpha_backup_lmap(sparse_pomdp)
)
return insert_root!(solver, tree, _initialize_belief(pomdp, initialstate(pomdp)))
end
Expand Down

0 comments on commit 6baf19a

Please sign in to comment.