Skip to content

Commit

Permalink
refactor: separate out resid_prototype calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 2, 2024
1 parent 4d4ff85 commit b063e6b
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,16 @@ function hessian_sparsity(sys::NonlinearSystem)
unknowns(sys)) for eq in equations(sys)]
end

function calculate_resid_prototype(N, u0, p)
u0ElType = u0 === nothing ? Float64 : eltype(u0)
if SciMLStructures.isscimlstructure(p)
u0ElType = promote_type(
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
u0ElType)
end
return zeros(u0ElType, N)
end

"""
```julia
SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
Expand Down Expand Up @@ -337,13 +347,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
if length(dvs) == length(equations(sys))
resid_prototype = nothing
else
u0ElType = u0 === nothing ? Float64 : eltype(u0)
if SciMLStructures.isscimlstructure(p)
u0ElType = promote_type(
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
u0ElType)
end
resid_prototype = zeros(u0ElType, length(equations(sys)))
resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p)
end

NonlinearFunction{iip}(f,
Expand Down

0 comments on commit b063e6b

Please sign in to comment.