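```julia
using ComponentArrays, OptimizationOptimisers

function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...)
    # Optimization.jl callback: record the current parameters (`p.u`) and loss;
    # returning `false` tells the optimizer to keep going.
    log_results(ps, losses) = (p, loss) -> begin
        push!(ps, copy(p.u))
        push!(losses, loss)
        false
    end

    ps, losses = ComponentArray[], Float32[]
    for k in obs_grid
        node, p_new, state_new = neural_ode(t, size(y, 1))
        # Initialize `p` and `state` only if they were not passed in or set by a previous round.
        p === nothing && (p = p_new)
        state === nothing && (state = state_new)

        p, state = train_one_round(
            node, p, state, y, OptimizationOptimisers.AdamW(lr), maxiters, rng;
            callback = log_results(ps, losses), kwargs...)
    end
    ps, state, losses
end
```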
The `train` function above calls `neural_ode` once for each `k` in `obs_grid`, where `k` is meant to be the number of observations used in that round. However, `k` is never used in the loop body: the full time vector `t` and the whole dataset `y` are passed to `neural_ode` and `train_one_round` on every iteration, which defeats the purpose of batching the inputs.
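For illustration only, here is a minimal, dependency-free sketch of what "using `k`" could look like: each round keeps just the first `k` time points and the first `k` columns of the data. The helper name `observation_windows` and the toy data are hypothetical, and the layout (one observation per column of `y`, features along rows, consistent with `size(y, 1)` being passed as the feature count) is an assumption, not taken from the tutorial.

```julia
# Hypothetical helper (not part of the tutorial): build one (t_k, y_k) pair per
# round, keeping only the first k observations.
# Assumes y is features × time, i.e. one observation per column.
function observation_windows(t::AbstractVector, y::AbstractMatrix, obs_grid)
    length(t) == size(y, 2) ||
        throw(DimensionMismatch("t has $(length(t)) points but y has $(size(y, 2)) columns"))
    # Slice the time vector and the data matrix to the first k observations
    return [(t[1:k], y[:, 1:k]) for k in obs_grid]
end

# Toy data: 3 features observed at 10 time points
t = collect(Float32, 0:9)
y = rand(Float32, 3, 10)

# Rounds use 4, 7, then all 10 observations
for (t_k, y_k) in observation_windows(t, y, 4:3:10)
    println(length(t_k), " ", size(y_k))
end
```

In the `train` loop, `t_k` and `y_k` would presumably take the place of the full `t` and `y` passed to `neural_ode` and `train_one_round`.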
The code should be:
and: