Skip to content

Commit

Permalink
refactor: remove setting st in NNODE
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Jun 28, 2024
1 parent c15eb96 commit bed52b7
Showing 1 changed file with 0 additions and 4 deletions.
4 changes: 0 additions & 4 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ function (f::ODEPhi{C, T, U})(t::Number,
θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata.depvar)), [t]), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 + (t - f.t0) * first(y)
end

Expand All @@ -129,14 +128,12 @@ function (f::ODEPhi{C, T, U})(t::AbstractVector,
# Batch via data as row vectors
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata.depvar)), t'), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y
end

function (f::ODEPhi{C, T, U})(t::Number, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata.depvar)), [t]), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t .- f.t0) .* y
end

Expand All @@ -145,7 +142,6 @@ function (f::ODEPhi{C, T, U})(t::AbstractVector,
# Batch via data as row vectors
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata.depvar)), t'), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y
end

Expand Down

0 comments on commit bed52b7

Please sign in to comment.