diff --git a/src/pino_ode_solve.jl b/src/pino_ode_solve.jl index e20e1dec2f..fdfbeaf301 100644 --- a/src/pino_ode_solve.jl +++ b/src/pino_ode_solve.jl @@ -67,11 +67,8 @@ end function generate_pino_phi_θ(chain::Lux.AbstractExplicitLayer, init_params) θ, st = Lux.setup(Random.default_rng(), chain) - if init_params === nothing - init_params = ComponentArrays.ComponentArray(θ) - else - init_params = ComponentArrays.ComponentArray(init_params) - end + init_params = isnothing(init_params) ? θ : init_params + init_params = ComponentArrays.ComponentArray(init_params) PINOPhi(chain, st), init_params end