Skip to content

Commit

Permalink
constraints and rollout accept context
Browse files Browse the repository at this point in the history
  • Loading branch information
lassepe committed Feb 11, 2024
1 parent f537aea commit 9b4d0a9
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 22 deletions.
4 changes: 2 additions & 2 deletions src/costs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ struct TimeSeparableTrajectoryGameCost{T1,T2,T3}
discount_factor::Float64
end

function (c::TimeSeparableTrajectoryGameCost)(xs, us, context_state)
function (c::TimeSeparableTrajectoryGameCost)(xs, us, context)
ts = Iterators.eachindex(xs)
Iterators.map(xs, us, ts) do x, u, t
c.discount_factor^(t - 1) .* c.stage_cost(x, u, t, context_state)
c.discount_factor^(t - 1) .* c.stage_cost(x, u, t, context)
end |> c.reducer
end

Expand Down
15 changes: 8 additions & 7 deletions src/dynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ time varying or not.
function temporal_structure_trait end

"""
dynamics(state, actions[, t]).
dynamics(state, controls[, t, context]).
Computes the next state for a collection of player `actions` applied to the multi-player \
Computes the next state for a BlockVector of player `controls` (one block per player) applied to the multi-player \
`dynamics`.
"""
abstract type AbstractDynamics end
Expand Down Expand Up @@ -62,7 +62,7 @@ applying the inputs dictated by the strategy.
Kwargs:
- `get_info` is a callback `(γ, x, t) -> info` that can be passed to extract additional info from \
- `get_info` is a callback `(γ, x, t, context) -> info` that can be passed to extract additional info from \
the strategy `γ` for each rollout state `x` and time `t`.
- `skip_last_strategy_call = false`. Setting this to true avoids calling the strategy on the last \
time step because this input will never be applied anyway. In that case, us will have one element \
Expand All @@ -75,17 +75,18 @@ function rollout(
strategy,
x1,
T = horizon(dynamics);
get_info = (γ, x, t) -> nothing,
get_info = (γ, x, t, context) -> nothing,
skip_last_strategy_call = false,
context = nothing,
)
xs = sizehint!([x1], T)
us = sizehint!([strategy(x1, 1)], T)
infos = sizehint!([get_info(strategy, x1, 1)], T)
infos = sizehint!([get_info(strategy, x1, 1, context)], T)

time_steps = 1:(T - 1)

for tt in time_steps
xp = dynamics(xs[tt], us[tt], tt)
xp = dynamics(xs[tt], us[tt], tt, context)
push!(xs, xp)

if skip_last_strategy_call && tt == lastindex(time_steps)
Expand All @@ -94,7 +95,7 @@ function rollout(

up = strategy(xp, tt + 1)
push!(us, up)
infop = get_info(strategy, xs[tt], tt + 1)
infop = get_info(strategy, xs[tt], tt + 1, context)
push!(infos, infop)
end

Expand Down
2 changes: 1 addition & 1 deletion src/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ end

function get_constraints(environment::PolygonEnvironment, player_index = nothing)
constraints = LazySets.constraints_list(environment.set)
function (state)
function (state, ::Any = nothing)
positions = (substate[1:2] for substate in blocks(state))
mapreduce(vcat, Iterators.product(constraints, positions)) do (constraint, position)
-constraint.a' * position + constraint.b
Expand Down
4 changes: 2 additions & 2 deletions src/game.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
Base.@kwdef struct TrajectoryGame{TD<:AbstractDynamics,TC,TE,TS}
"An object that describes the dynamics of this trajectory game"
dynamics::TD
"A cost function taking (xs, us, [context]) with states `xs` and inputs `us` in Blocks and an
"A cost function taking (xs, us, context) with states `xs` and inputs `us` in Blocks and an
optional `context` information. Returns a collection of cost values; one per player."
cost::TC
"The environment object that characterizes static constraints of the problem and can be used for
visualization."
env::TE
"An object which encodes the constraints between different players. It must be callable as
`con(xs, us) -> gs`: returning a collection of scalar constraints `gs` each of which is negative
`con(xs, us, context) -> gs`: returning a collection of scalar constraints `gs` each of which is negative
if the corresponding constraint is active."
coupling_constraints::TS = nothing
end
Expand Down
4 changes: 2 additions & 2 deletions src/linear_dynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ function time_invariant_linear_dynamics(; A, B, horizon = ∞, bounds...)
LinearDynamics(; A = Fill(A, horizon), B = Fill(B, horizon), bounds...)
end

function (sys::LinearDynamics)(x, u, t::Int)
# Evaluate the linear dynamics at integer time step `t`: returns
# `sys.A[t] * x + sys.B[t] * u`. The optional fourth positional argument
# is accepted and ignored, so callers can pass a `context` value uniformly
# to every dynamics object.
function (sys::LinearDynamics)(x, u, t::Int, ::Any = nothing)
    state_term = sys.A[t] * x
    input_term = sys.B[t] * u
    return state_term + input_term
end

function (sys::LinearDynamics)(x, u, ::Nothing = nothing)
function (sys::LinearDynamics)(x, u, ::Nothing = nothing, ::Any = nothing)

Check warning on line 58 in src/linear_dynamics.jl

View check run for this annotation

Codecov / codecov/patch

src/linear_dynamics.jl#L58

Added line #L58 was not covered by tests
temporal_structure_trait(sys) isa TimeInvariant ||
error("Only time-invariant systems can ommit the `t` argument.")
sys.A.value * x + sys.B.value * u
Expand Down
11 changes: 9 additions & 2 deletions src/product_dynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@ Base.@kwdef struct ProductDynamics{T} <: AbstractDynamics
end
end

function (dynamics::ProductDynamics)(x::AbstractBlockArray, u::AbstractBlockArray, t = nothing)
mortar([sub(x̂, u, t) for (sub, x̂, u) in zip(dynamics.subsystems, blocks(x), blocks(u))])
# Advance every subsystem of a `ProductDynamics` by one step.
#
# The blocks of `x` and `u` are paired, in order, with `dynamics.subsystems`;
# each subsystem is called as `subsystem(x_block, u_block, t, context)` and the
# per-subsystem results are joined back into a single blocked array via `mortar`.
# Both `t` and `context` default to `nothing` and are forwarded unchanged.
function (dynamics::ProductDynamics)(
    x::AbstractBlockArray,
    u::AbstractBlockArray,
    t = nothing,
    context = nothing,
)
    per_subsystem = zip(dynamics.subsystems, blocks(x), blocks(u))
    next_blocks = [
        subsystem(x_block, u_block, t, context) for
        (subsystem, x_block, u_block) in per_subsystem
    ]
    return mortar(next_blocks)
end

function state_dim(dynamics::ProductDynamics)
Expand Down
6 changes: 3 additions & 3 deletions src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
solve_trajectory_game!(solver, game, initial_state)
solve_trajectory_game!(solver, game, initial_state; context, kwargs...)
Computes a joint strategy `γ` for all players for the given `game` for game-play starting at
`initial_state`. This call may modify the `solver` itself; e.g., due to learning of parameters or
updates of initial guesses for repeated calls. That is, a subsequent call of
`initial_state`. This call may modify the `solver` itself; e.g., due to learning of parameters,
updates of initial guesses, or advancing RNG states. That is, a subsequent call of
`solve_trajectory_game!` on the *same* `initial_state` may result in a *different* strategy.
"""
function solve_trajectory_game! end
7 changes: 4 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ function TrajectoryGamesBase.solve_trajectory_game!(
solver,
game,
initial_state;
initial_guess=nothing
context=nothing,
initial_guess=nothing,
)
@test !isnothing(initial_guess)
TrajectoryGamesBase.JointStrategy([trivial_strategy, trivial_strategy])
Expand Down Expand Up @@ -133,9 +134,9 @@ end # Mock module
@testset "environment" begin
constraints = get_constraints(game.env)
# probe a feasible state
@test all(constraints(zeros(4)) .> 0)
@test all(constraints(zeros(4), context) .> 0)
# probe an infeasible state
@test any(constraints(fill(10, 4)) .< 0)
@test any(constraints(fill(10, 4), context) .< 0)
end

@testset "receding horizon utils" begin
Expand Down

0 comments on commit 9b4d0a9

Please sign in to comment.