Skip to content

Commit

Permalink
GTPSA ext new feature compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsignorelli committed Jan 22, 2025
1 parent 0f8481c commit 9452a55
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ ForwardDiff = "0.10"
FunctionWrappers = "1.0"
FunctionWrappersWrappers = "0.1"
GeneralizedGenerated = "0.3"
GTPSA = "1.3"
GTPSA = "1.4"
LinearAlgebra = "1.9"
Logging = "1.9"
MPI = "0.20"
Expand Down
10 changes: 5 additions & 5 deletions ext/DiffEqBaseGTPSAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ else
end

value(x::TPS) = scalar(x)
value(::Type{TPS{T}}) where {T} = T
value(::Type{<:TPS{T}}) where {T} = T

ODE_DEFAULT_NORM(u::TPS, t) = normTPS(u)
ODE_DEFAULT_NORM(f::F, u::TPS, t) where {F} = normTPS(f(u))

function ODE_DEFAULT_NORM(u::AbstractArray{TPS{T}}, t) where {T}
x = zero(real(T))
function ODE_DEFAULT_NORM(u::AbstractArray{<:TPS}, t)
x = zero(real(GTPSA.numtype(eltype(u))))
@inbounds @fastmath for ui in u
x += normTPS(ui)^2
end
Base.FastMath.sqrt_fast(x / max(length(u), 1))
end

function ODE_DEFAULT_NORM(f::F, u::AbstractArray{TPS{T}}, t) where {F, T}
x = zero(real(T))
function ODE_DEFAULT_NORM(f::F, u::AbstractArray{<:TPS}, t) where {F}
x = zero(real(GTPSA.numtype(eltype(u))))
@inbounds @fastmath for ui in u
x += normTPS(f(ui))^2
end
Expand Down
6 changes: 3 additions & 3 deletions test/downstream/gtpsa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ sol = solve(prob, Tsit5(), reltol=1e-16, abstol=1e-16)

# Parametric GTPSA map
desc = Descriptor(3, 2, 3, 2) # 3 variables 3 parameters, both to 2nd order
dx = vars(desc)
dp = params(desc)
dx = @vars(desc)
dp = @params(desc)
prob_GTPSA = ODEProblem(f!, x .+ dx, (0.0, 1.0), p .+ dp)
sol_GTPSA = solve(prob_GTPSA, Tsit5(), reltol=1e-16, abstol=1e-16)

Expand Down Expand Up @@ -50,7 +50,7 @@ prob = DynamicalODEProblem(pdot!, qdot!, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], (0.0,
sol = solve(prob, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)

desc = Descriptor(6, 2) # 6 variables to 2nd order
dx = vars(desc) # identity map
dx = @vars(desc) # identity map
prob_GTPSA = DynamicalODEProblem(pdot!, qdot!, dx[1:3], dx[4:6], (0.0, 25.0))
sol_GTPSA = solve(prob_GTPSA, Yoshida6(), dt = 1.0, reltol=1e-16, abstol=1e-16)

Expand Down

0 comments on commit 9452a55

Please sign in to comment.