From c578da182393e8ba8efac181b796e9838d093680 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 13:18:33 +0530 Subject: [PATCH] fix: recalculate `resid_prototype` in `remake_initialization_data` --- src/systems/nonlinear/initializesystem.jl | 11 ++++++++++- test/initializationsystem.jl | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 4abb345822..726d171bd0 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -260,7 +260,16 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newp = remake_buffer( oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals) end - initprob = remake(oldinitprob; u0 = newu0, p = newp) + if oldinitprob.f.resid_prototype === nothing + newf = oldinitprob.f + else + newf = NonlinearFunction{ + SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}( + oldinitprob.f; + resid_prototype = calculate_resid_prototype( + length(oldinitprob.f.resid_prototype), newu0, newp)) + end + initprob = remake(oldinitprob; f = newf, u0 = newu0, p = newp) return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!, odefn.initializeprobmap, odefn.initializeprobpmap) end diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 0b0dc42c1e..f3015f7db0 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1032,3 +1032,20 @@ end @test prob3.f.initialization_data !== nothing @test init(prob3)[x] ≈ 0.5 end + +@testset "Issue#3246: type promotion with parameter dependent initialization_eqs" begin + @variables x(t)=1 y(t)=1 + @parameters a = 1 + @named sys = ODESystem([D(x) ~ 0, D(y) ~ x + a], t; initialization_eqs = [y ~ a]) + + ssys = structural_simplify(sys) + prob = ODEProblem(ssys, [], (0, 1), []) + + @test SciMLBase.successful_retcode(solve(prob)) + + seta = setsym_oop(prob, [a]) + (newu0, newp) = seta(prob, ForwardDiff.Dual{ForwardDiff.Tag{:tag, Float64}}.([1.0], 1)) + newprob = remake(prob, u0 = newu0, p = newp) + + @test SciMLBase.successful_retcode(solve(newprob)) +end