Start forward mode AD #389
This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.
src/frules/basic.jl
Outdated
@@ -0,0 +1,11 @@
frule!!(f::F, args::Vararg{Dual,N}) where {F,N} = frule!!(zero_dual(f), args...)
Is this just a temporary hack in order to ensure we don't have to insert any statements into the IR for now?
Yes, but I'm wondering if it's a bad thing to leave it there?
I would much rather aim for complete uniformity here -- it's what we do in reverse-mode, and I think it will ensure that we avoid any "perturbation confusion" problems when doing forwards-mode over forwards-mode. If you know that the second time you differentiate a programme, everything is going to be a Dual inside a Dual, you avoid issues of the form "okay, so this is a Dual, but is it a Dual associated to the first time I differentiated this programme, or the second?" Off the top of my head, I don't have a concrete example to put forward, but we should definitely try and concoct one -- there might be one somewhere in ForwardDiff.jl, but if not there must be plenty of examples out there in the literature.
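For readers unfamiliar with the term, here is a minimal sketch of perturbation confusion using the textbook expression d/dx [x * d/dy (x + y)], whose value is 1. It is not taken from Mooncake or ForwardDiff.jl: `ToyDual`, `derivative` and `lift` are hypothetical names invented for this illustration, and the explicit `lift` calls only stand in for what a uniform "everything becomes a Dual" transformation might do automatically.

```julia
# Hypothetical toy dual number (NOT Mooncake's Dual type).
struct ToyDual{T}
    primal::T
    tangent::T
end

Base.:+(a::ToyDual, b::ToyDual) = ToyDual(a.primal + b.primal, a.tangent + b.tangent)
Base.:*(a::ToyDual, b::ToyDual) =
    ToyDual(a.primal * b.primal, a.primal * b.tangent + a.tangent * b.primal)
# Mixed method: needed only by the non-uniform version below.
Base.:*(a::ToyDual, b::Real) = ToyDual(a.primal * b, a.tangent * b)
Base.zero(a::ToyDual) = ToyDual(zero(a.primal), zero(a.tangent))
Base.one(a::ToyDual) = ToyDual(one(a.primal), zero(a.tangent))

# Seed the input with a unit tangent and read the tangent of the output.
derivative(f, x) = f(ToyDual(x, one(x))).tangent

# Lift a value that is constant with respect to the current differentiation.
lift(c) = ToyDual(c, zero(c))

# Non-uniform: the captured outer `x` is reused as-is inside the inner
# derivative, so the outer perturbation is mistaken for the inner one.
confused = derivative(x -> x * derivative(y -> x + y, 1.0), 1.0)

# Uniform: inside the inner derivative every value is a ToyDual wrapping the
# outer-level value, so the two perturbations live at different nesting depths.
uniform = derivative(x -> x * derivative(y -> lift(x) + y, lift(1.0)), 1.0)

println((confused, uniform))
```

Under these assumptions the non-uniform version yields 2.0 rather than the correct 1.0, while the uniformly nested version yields 1.0.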
See #389 (comment), I replaced this with _frule!!_funcnotdual to avoid confusion
Co-authored-by: Will Tebbutt <[email protected]> Signed-off-by: Guillaume Dalle <[email protected]>
@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think?
I think this could work. You could just replace the
@inline function call_frule!!(rule::R, fargs::Vararg{Any, N}) where {R, N}
    # Lift every argument that is not already a Dual to a Dual with zero tangent.
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end
The optimisation pass will lower this to what we were thinking about writing out in the IR anyway. I think the other important kinds of nodes would be largely straightforward to handle.
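To make the lifting idea concrete, here is a self-contained sketch of the same pattern outside Mooncake. All names (`LiftDual`, `lift_zero_dual`, `toy_call_frule`, `mul_rule`) are hypothetical stand-ins for illustration only; Mooncake's actual `Dual` and `zero_dual` may behave differently.

```julia
# Hypothetical stand-in for a dual number (NOT Mooncake's Dual type).
struct LiftDual{P,T}
    primal::P
    tangent::T
end

# Hypothetical stand-in for zero_dual: pair a value with a zero tangent.
lift_zero_dual(x::Float64) = LiftDual(x, 0.0)

# Same shape as the proposed call_frule!!: lift any non-dual argument,
# then call the rule with duals only.
@inline function toy_call_frule(rule::R, fargs::Vararg{Any,N}) where {R,N}
    return rule(map(x -> x isa LiftDual ? x : lift_zero_dual(x), fargs)...)
end

# A toy forward rule for multiplication, written only for dual inputs.
function mul_rule(a::LiftDual, b::LiftDual)
    return LiftDual(a.primal * b.primal, a.primal * b.tangent + a.tangent * b.primal)
end

# The literal 2.0 is lifted to LiftDual(2.0, 0.0) before the rule runs.
out = toy_call_frule(mul_rule, LiftDual(3.0, 1.0), 2.0)
println(out)
```

The rule itself only ever sees duals, so no extra statements are needed at the call site to wrap constants.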
I think we might need to be slightly more subtle. If an argument to the |
Yes. I think my proposed code handles this though, or am I missing something?
In the spirit of higher-order AD, we may encounter |
Very good point.
Agreed. Specifically, I think we need to distinguish between literals / |
I still need to dig into the different node types we might encounter (and I still don't understand |
I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for |
I think that's okay, the main trouble is adding new lines which insert new variables because it requires manual renumbering. A GoTo should be much simpler.
Were the difficulties around renumbering etc not resolved by not |
No they weren't. I experimented with |
Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot(%5, #3) i.e. jump to block 3 if not
%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)
Does this not cause the same kind of problems?
Oh yes, you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look
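The renumbering concern can be illustrated with a toy model that does not touch `Core.Compiler` at all. Everything below (`ToyStmt`, `ToyCall`, `ToyGotoIfNot`, `ToyReturn`, `insert_primal_calls`) is hypothetical and assumes that every SSA value is defined before it is used (so no phi nodes): inserting one statement before a conditional jump shifts every later SSA id, so all later references have to be remapped.

```julia
# Toy statements; integer fields are SSA ids, except `dest`, which is a block index.
abstract type ToyStmt end
struct ToyCall <: ToyStmt
    name::Symbol
    args::Vector{Int}   # SSA ids used as arguments (function arguments are omitted)
end
struct ToyGotoIfNot <: ToyStmt
    cond::Int
    dest::Int
end
struct ToyReturn <: ToyStmt
    val::Int
end

# Insert `primal(cond)` before every conditional jump and renumber SSA uses.
function insert_primal_calls(stmts::Vector{ToyStmt})
    new_stmts = ToyStmt[]
    new_id = Dict{Int,Int}()   # old SSA id => new SSA id
    for (old, stmt) in enumerate(stmts)
        if stmt isa ToyGotoIfNot
            push!(new_stmts, ToyCall(:primal, [new_id[stmt.cond]]))
            # The jump now tests the statement that was just inserted.
            push!(new_stmts, ToyGotoIfNot(length(new_stmts), stmt.dest))
        elseif stmt isa ToyCall
            push!(new_stmts, ToyCall(stmt.name, Int[new_id[a] for a in stmt.args]))
        else
            push!(new_stmts, ToyReturn(new_id[stmt.val]))
        end
        new_id[old] = length(new_stmts)
    end
    return new_stmts
end

# Mirrors the shape of f(x) = x > 1 ? 2x : 3 + x after lowering.
stmts = ToyStmt[
    ToyCall(:lt_float, Int[]), ToyCall(:or_int, [1]), ToyGotoIfNot(2, 3),
    ToyCall(:mul_float, Int[]), ToyReturn(4),
    ToyCall(:add_float, Int[]), ToyReturn(6),
]
new_stmts = insert_primal_calls(stmts)
foreach(println, new_stmts)
```

Block destinations are left unchanged because they index basic blocks rather than statements; only SSA references move.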
Do you know what I should do about expressions of type |
Yup -- I just strip them out of the IR entirely in reverse-mode. See
The way to remove an instruction from an |
I think this works for
MWE (requires this branch of Mooncake):
const CC = Core.Compiler
using Mooncake
using MistyClosures

f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
# Insert a placeholder call before the GotoIfNot at %3.
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any) # placeholder for get_primal
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
# Point each GotoIfNot at the statement immediately before it.
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir, CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k-1), inst.dest))
    end
end
ir
julia> initial_ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ %2 = Base.or_int(%1, false)::Bool ││╻ <
└── goto #3 if not %2 │
2 ─ %4 = Base.mul_float(2.0, _2)::Float64 ││╻ *
└── return %4 │
3 ─ %6 = Base.add_float(3.0, _2)::Float64 ││╻ +
└── return %6 │
julia> ir
5 1 ─ %1 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ Base.or_int(%1, false)::Bool ││╻ <
│ %3 = (+)(1, 2)::Any │
└── goto #3 if not %3 │
2 ─ %5 = Base.mul_float(2.0, _2)::Float64 ││╻ *
└── return %5 │
3 ─ %7 = Base.add_float(3.0, _2)::Float64 ││╻ +
└── return %7
This is cool. I was also prototyping insertions with the IRCode while you were doing this, but was running into trouble when you have to make insertions at multiple locations. Could you try this on an example with multiple On a separate note: if it does turn out (for some reason) that it's really awkward to do this insertion stuff directly on |
Eyeballing it, this insert-then-shift approach seems to work?
const CC = Core.Compiler
using Mooncake

primal(x) = x

function replace_gotoifnot(ir::CC.IRCode)
    ir = copy(ir)
    # insertion loop
    for k in 1:length(ir.stmts)
        inst = ir[CC.SSAValue(k)][:stmt]
        if inst isa Core.GotoIfNot
            get_primal_inst = CC.NewInstruction(Expr(:call, primal, inst.cond), Any)
            CC.insert_node!(ir, CC.SSAValue(k), get_primal_inst, false)
        end
    end
    ir = CC.compact!(ir)
    # shift loop
    for k in 1:length(ir.stmts)
        inst = ir[CC.SSAValue(k)][:stmt]
        if inst isa Core.GotoIfNot
            Mooncake.replace_call!(
                ir, CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k - 1), inst.dest)
            )
        end
    end
    return ir
end

function f(x)
    if x > 0
        if x > 1
            return sin(x)
        else
            return cos(x)
        end
    else
        if x < -1
            return exp(x)
        else
            return log(x)
        end
    end
end

ir = Base.code_ircode(f, (Float64,))[1][1]
new_ir = replace_gotoifnot(ir)
julia> ir = Base.code_ircode(f, (Float64,))[1][1]
30 1 ─ %1 = Base.lt_float(0.0, _2)::Bool │╻╷╷ >
│ %2 = Base.or_int(%1, false)::Bool ││╻ <
└── goto #5 if not %2 │
31 2 ─ %4 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ %5 = Base.or_int(%4, false)::Bool ││╻ <
└── goto #4 if not %5 │
32 3 ─ %7 = invoke Main.sin(_2::Float64)::Float64 │
└── return %7 │
34 4 ─ %9 = invoke Main.cos(_2::Float64)::Float64 │
└── return %9 │
37 5 ─ %11 = Base.lt_float(_2, -1.0)::Bool │╻╷ <
│ %12 = Base.or_int(%11, false)::Bool ││╻ |
└── goto #7 if not %12 │
38 6 ─ %14 = invoke Main.exp(_2::Float64)::Float64 │
└── return %14 │
40 7 ─ %16 = invoke Main.log(_2::Float64)::Float64 │
└── return %16 │
julia> new_ir = replace_gotoifnot(ir)
30 1 ─ %1 = Base.lt_float(0.0, _2)::Bool │╻╷╷ >
│ %2 = Base.or_int(%1, false)::Bool ││╻ <
│ %3 = (primal)(%2)::Any │
└── goto #5 if not %3 │
31 2 ─ %5 = Base.lt_float(1.0, _2)::Bool │╻╷╷ >
│ %6 = Base.or_int(%5, false)::Bool ││╻ <
│ %7 = (primal)(%6)::Any │
└── goto #4 if not %7 │
32 3 ─ %9 = invoke Main.sin(_2::Float64)::Float64 │
└── return %9 │
34 4 ─ %11 = invoke Main.cos(_2::Float64)::Float64 │
└── return %11 │
37 5 ─ %13 = Base.lt_float(_2, -1.0)::Bool │╻╷ <
│ %14 = Base.or_int(%13, false)::Bool ││╻ |
│ %15 = (primal)(%14)::Any │
└── goto #7 if not %15 │
38 6 ─ %17 = invoke Main.exp(_2::Float64)::Float64 │
└── return %17 │
40 7 ─ %19 = invoke Main.log(_2::Float64)::Float64 │
└── return %19
Can you spot any issue in the generated IR?
This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation. Right now it only recurses into primitives so that we don't change the number of statements (because I don't know what else we need to adapt in the IRCode when we add statements).
ping @willtebbutt