Error in train function in neural ode weather forecasting example #919

Closed
gebmecod opened this issue May 10, 2024 · 1 comment
Labels
bug Something isn't working


gebmecod commented May 10, 2024

function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...)
    log_results(ps, losses) = (p, loss) -> begin
        push!(ps, copy(p.u))
        push!(losses, loss)
        false
    end

    ps, losses = ComponentArray[], Float32[]
    for k in obs_grid
        node, p_new, state_new = neural_ode(t, size(y, 1))
        p === nothing && (p = p_new)
        state === nothing && (state = state_new)

        p, state = train_one_round(
            node, p, state, y, OptimizationOptimisers.AdamW(lr), maxiters, rng;
            callback = log_results(ps, losses), kwargs...)
    end
    ps, state, losses
end

In the code above, the loop iterates over k in obs_grid, so each round is meant to train on only the first k observations. However, k is never used inside the loop. neural_ode receives the full time vector t, and train_one_round receives the full data matrix y, so every round trains on the whole dataset. This defeats the purpose of batching the inputs according to obs_grid.
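To make the difference concrete, here is a minimal Base Julia sketch of what the two slices select. It assumes that y stores one feature per row and one time point per column, as size(y, 1) and y[:, 1:k] suggest. The toy data, the obs_grid values and the helper name window are made up for illustration and do not come from the tutorial.

```julia
# Toy data (hypothetical): 3 features (rows) observed at 10 time points (columns).
t = collect(range(0.0f0, 1.0f0; length = 10))
y = reshape(collect(Float32, 1:30), 3, 10)

# Hypothetical helper: the first k time points and the matching first k columns of y.
window(t, y, k) = (t[1:k], y[:, 1:k])

for k in 4:3:10  # hypothetical obs_grid: 4, 7, 10
    t_k, y_k = window(t, y, k)
    # Without slicing, every round would see all length(t) == 10 time points.
    println("k = $k: length(t_k) = $(length(t_k)), size(y_k) = $(size(y_k))")
end
```

Each round's window grows with k, while the unsliced t and y stay at 10 points in every round.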

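The fix is to pass only the first k time points to neural_ode:

node, p_new, state_new = neural_ode(t[1:k], size(y, 1))

and only the first k observations (columns of y) to train_one_round:

p, state = train_one_round(
    node, p, state, y[:, 1:k], OptimizationOptimisers.AdamW(lr), maxiters, rng;
    callback = log_results(ps, losses), kwargs...)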
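The corrected loop can be sketched end to end with stand-ins for the tutorial's functions. StubNode, stub_neural_ode, stub_train_one_round and train_windows are hypothetical names that do no actual training. They only check that the time span and the data window cover the same observations, and they record how many observations each round sees. In the real example, neural_ode and train_one_round would take their place.

```julia
# Hypothetical stand-ins for the tutorial's neural_ode and train_one_round.
# They do no training; they only check and record the window each round receives.
struct StubNode
    n_times::Int
end

# Returns (model, initial parameters, initial state), mirroring the call shape in the issue.
stub_neural_ode(t, n_features) = (StubNode(length(t)), zeros(Float32, n_features), NamedTuple())

function stub_train_one_round(node, p, state, y_window)
    # The time span and the data window must cover the same k observations.
    @assert node.n_times == size(y_window, 2)
    return p, state
end

function train_windows(t, y, obs_grid)
    window_sizes = Int[]
    p, state = nothing, nothing
    for k in obs_grid
        # The fix proposed in the issue: use only the first k time points and observations.
        node, p_new, state_new = stub_neural_ode(t[1:k], size(y, 1))
        p === nothing && (p = p_new)          # reuse parameters from earlier rounds
        state === nothing && (state = state_new)
        p, state = stub_train_one_round(node, p, state, y[:, 1:k])
        push!(window_sizes, node.n_times)
    end
    return window_sizes
end

t = collect(range(0.0f0, 1.0f0; length = 10))
y = rand(Float32, 3, 10)
train_windows(t, y, 4:3:10)  # returns [4, 7, 10]
```

With the slicing in place, each round trains on a growing window rather than the full series, and the parameters from one round are carried into the next through p.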
gebmecod added the bug label on May 10, 2024
ChrisRackauckas (Member) commented

New docs build is up.
