WIP: adds Lagrangian NN and simple example #537

Open · wants to merge 3 commits into master
2 changes: 2 additions & 0 deletions Project.toml
@@ -26,6 +26,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"

[compat]
Adapt = "3"
@@ -48,6 +49,7 @@ TerminalLoggers = "0.1"
Zygote = "0.5, 0.6"
ZygoteRules = "0.2"
julia = "1.5"
GenericLinearAlgebra = "0.2.5"

[extras]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
24 changes: 24 additions & 0 deletions docs/src/examples/lagrangian_nn.jl
@@ -0,0 +1,24 @@
# One point test: fit the acceleration of a spring-mass system (q̈ = -k*q/m)
using Flux, ReverseDiff, DiffEqFlux
Member comment: This should be moved to the test folder and included in runtests.jl
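A minimal sketch of what that inclusion could look like, assuming the example is moved to test/lagrangian_nn.jl and that the suite keeps using SafeTestsets as the rest of the repository does (both are assumptions, not part of this diff):

# Hypothetical addition to test/runtests.jl
using SafeTestsets

@safetestset "Lagrangian NN" begin
    include("lagrangian_nn.jl")
end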


m, k, b = 1, 1, 1 # mass, spring constant; b is unused here

X = rand(2, 1) # state [q; q̇]
Y = -k .* X[1] / m # target acceleration for the coordinate q

g = Chain(Dense(2, 10, σ), Dense(10,1))
model = LagrangianNN(g)
params = model.p # parameter vector stored by the LagrangianNN
re = model.re

loss(x, y, p) = sum(abs2, y .- model(x, p))

opt = ADAM(0.01)
epochs = 100

for epoch in 1:epochs
    x, y = X, Y
    gs = ReverseDiff.gradient(p -> loss(x, y, p), params)
    Flux.Optimise.update!(opt, params, gs)
    @show loss(x, y, params)
end
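After the loop, a quick check (not part of the diff) can compare the learned acceleration with the analytic target; model, X, Y, and params are the names defined above:

# Hypothetical post-training check
pred = model(X, params)
@show only(pred) Y # should approach -k * X[1] / m as training progresses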
4 changes: 3 additions & 1 deletion src/DiffEqFlux.jl
@@ -2,7 +2,7 @@ module DiffEqFlux

using GalacticOptim, DataInterpolations, DiffEqBase, DiffResults, DiffEqSensitivity,
Distributions, ForwardDiff, Flux, Requires, Adapt, LinearAlgebra, RecursiveArrayTools,
StaticArrays, Base.Iterators, Printf, Zygote
StaticArrays, Base.Iterators, Printf, Zygote, GenericLinearAlgebra

using DistributionsAD
import ProgressLogging, ZygoteRules
@@ -82,11 +82,13 @@ include("tensor_product_basis.jl")
include("tensor_product_layer.jl")
include("collocation.jl")
include("hnn.jl")
include("lnn.jl")
include("multiple_shooting.jl")

export diffeq_fd, diffeq_rd, diffeq_adjoint
export DeterministicCNF, FFJORD, NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, NeuralODEMM, TensorLayer, AugmentedNDELayer, SplineLayer, NeuralHamiltonianDE
export HamiltonianNN
export LagrangianNN
export ChebyshevBasis, SinBasis, CosBasis, FourierBasis, LegendreBasis, PolynomialBasis
export neural_ode, neural_ode_rd
export neural_dmsde
42 changes: 42 additions & 0 deletions src/lnn.jl
@@ -0,0 +1,42 @@
"""
Constructs a Lagrangian Neural Network [1].

References:
[1] Miles Cranmer, Sam Greydanus, Stephan Hoyer, Peter Battaglia, David Spergel, and Shirley Ho.Lagrangian Neural Networks.
In ICLR 2020 Workshop on Integration of Deep Neural Models and Differential Equations, 2020.
"""

struct LagrangianNN{M, R, P}
    model::M
    re::R
    p::P

    # Define inner constructor method
    function LagrangianNN(model; p = nothing)
        _p, re = Flux.destructure(model)
        if p === nothing
            p = _p
        end
        return new{typeof(model), typeof(re), typeof(p)}(model, re, p)
    end
end

function (lnn::LagrangianNN)(x, p = lnn.p)
    @assert size(x,1) % 2 === 0 # the state must stack coordinates q on top of velocities q̇
    M = div(size(x,1), 2) # number of degrees of freedom
    re = lnn.re
    hess = x -> Zygote.hessian_reverse(x -> sum(re(p)(x)), x) # full Hessian of L w.r.t. the state
    hess = hess(x)[M+1:end, M+1:end] # keep only the velocity-velocity block
    inv_hess = GenericLinearAlgebra.pinv(hess)
Member comment: why pinv?


    # First term: gradient of L w.r.t. the coordinates q
    _grad_q = x -> Zygote.gradient(x -> sum(re(p)(x)), x)[end]
    _grad_q = _grad_q(x)[1:M, :] # keep only the coordinate derivatives
    out1 = _grad_q

    # Second term: mixed derivative block ∂(∇_q L)/∂q̇ contracted with the velocities
    _grad_qv = x -> Zygote.gradient(x -> sum(re(p)(x)), x)[end]
    _jac_qv = x -> Zygote.jacobian(x -> _grad_qv(x), x)[end]
    out2 = _jac_qv(x)[1:M, M+1:end] * x[M+1:end]

    return inv_hess * (out1 .+ out2)
end
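As a usage illustration (not part of this PR), the acceleration returned by the forward pass can be fed to an ODE solver to roll out a trajectory. The sketch below assumes OrdinaryDiffEq is available, that DiffEqFlux provides LagrangianNN once this PR is included, and uses an illustrative network g and initial state:

# Sketch: integrate the q̈ predicted by a LagrangianNN
using Flux, OrdinaryDiffEq
using DiffEqFlux

g = Chain(Dense(2, 16, softplus), Dense(16, 1))
lnn = LagrangianNN(g)

# State u = [q; q̇] stored as a 2×1 matrix; the LNN returns q̈, so du = [q̇; q̈]
function dynamics!(du, u, p, t)
    M = size(u, 1) ÷ 2
    du[1:M] .= u[M+1:end]
    du[M+1:end] .= vec(lnn(u, p))
end

u0 = reshape([1.0, 0.0], 2, 1) # q = 1, q̇ = 0
prob = ODEProblem(dynamics!, u0, (0.0, 1.0), lnn.p)
sol = solve(prob, Tsit5(), saveat = 0.1)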