Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start forward mode AD #389

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft

Start forward mode AD #389

wants to merge 13 commits into from

Conversation

gdalle
Copy link
Collaborator

@gdalle gdalle commented Nov 24, 2024

This is a very rough backbone of forward mode AD, based on #386 and the existing reverse mode implementation. Right now it only recurses into primitives so that we don't change the number of statements (cause I don't know what else we need to adapt in the IRCode when we add statements).

ping @willtebbutt

Copy link

codecov bot commented Nov 24, 2024

Codecov Report

Attention: Patch coverage is 47.90419% with 87 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/interpreter/diffractor_compiler_utils.jl 10.60% 59 Missing ⚠️
src/interpreter/s2s_forward_mode_ad.jl 72.83% 22 Missing ⚠️
src/dual.jl 69.23% 4 Missing ⚠️
src/debug_mode.jl 0.00% 1 Missing ⚠️
src/rrules/low_level_maths.jl 75.00% 1 Missing ⚠️
Files with missing lines Coverage Δ
src/Mooncake.jl 100.00% <100.00%> (ø)
src/test_utils.jl 80.62% <100.00%> (-12.33%) ⬇️
src/debug_mode.jl 97.22% <0.00%> (-2.78%) ⬇️
src/rrules/low_level_maths.jl 29.68% <75.00%> (-70.32%) ⬇️
src/dual.jl 69.23% <69.23%> (ø)
src/interpreter/s2s_forward_mode_ad.jl 72.83% <72.83%> (ø)
src/interpreter/diffractor_compiler_utils.jl 10.60% <10.60%> (ø)

... and 30 files with indirect coverage changes

Copy link
Member

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great. I've left a few comments, but if you're planning to do a bunch of additional stuff, then maybe they're redundant. Either way, don't feel the need to respond to them.

src/interpreter/s2s_forward_mode_ad.jl Outdated Show resolved Hide resolved
test/forward.jl Outdated Show resolved Hide resolved
@@ -0,0 +1,11 @@
frule!!(f::F, args::Vararg{Dual,N}) where {F,N} = frule!!(zero_dual(f), args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just a temporary hack in order to ensure we don't have to insert any statements into the IR for now?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes but I'm wondering if it's a bad thing to leave it there?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would much rather aim for complete uniformity here -- it's what we do in reverse-mode, and I think it will ensure that we avoid and "perturbation confusion" problems when doing forwards-mode over forwards-mode (if you know that the second time you differentiate a programme, everything is going to be a Dual inside a Dual, you avoid issues of the form "okay, so this is a Dual, but is it a Dual associated to the first time I differentiated this programme, or the second? Off the top of my head, I don't have a concrete example to put forward, but we should definitely try and concoct one -- there might be one somewhere in ForwardDiff.jl, but if not there must be plenty of examples out there in the literature)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #389 (comment), I replaced this with _frule!!_funcnotdual to avoid confusion

src/frules/basic.jl Outdated Show resolved Hide resolved
src/interpreter/s2s_forward_mode_ad.jl Outdated Show resolved Hide resolved
src/interpreter/s2s_forward_mode_ad.jl Outdated Show resolved Hide resolved
src/interpreter/s2s_forward_mode_ad.jl Outdated Show resolved Hide resolved
src/interpreter/s2s_forward_mode_ad.jl Outdated Show resolved Hide resolved
@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

@willtebbutt following our discussion yesterday I scratched my head some more, and I decided that it would be infinitely simpler to enforce the invariant that one line of primal IR maps to one line of dual IR. While this may require additional fallbacks in the Julia code itself, I hope it will make our lives much easier on the IR side. What do you think?

@willtebbutt
Copy link
Member

I think this could work.

You could just replace the frule!! calls with a call to a function call_frule!! which would be something like

@inline function call_frule!!(rule::R, fargs::Vararg{Any, N}) where {N}
    return rule(map(x -> x isa Dual ? x : zero_dual(x), fargs)...)
end

The optimisation pass will lower this to the what we were thinking about writing out in the IR anyway.

I think the other important kinds of nodes would be largely straightforward to handle.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

I think we might need to be slightly more subtle. If an argument to the :call or :invoke expression is a CC.Argument or a CC.SSAValue, we don't wrap it in a Dual because we assume it will already be one, right?

@willtebbutt
Copy link
Member

willtebbutt commented Nov 26, 2024

Yes. I think my propose code handles this though, or am I missing something?

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

In the spirit of higher-order AD, we may encounter Dual inputs that we want to wrap with a second Dual, and Dual inputs that we want to leave as-is. So I think this wrapping needs to be decided from the type of each argument in the IR?

@willtebbutt
Copy link
Member

Very good point.

So I think this wrapping needs to be decided from the type of each argument in the IR?

Agreed. Specifically, I think we need to distinguish between literals / QuoteNodes / GlobalRefs, and Argument / SSAValues?

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 26, 2024

I still need to dig into the different node types we might encounter (and I still don't understand QuoteNodes) but yeah, Argument and SSAValue don't need to be wrapped.

@gdalle gdalle mentioned this pull request Nov 27, 2024
@willtebbutt
Copy link
Member

I was reviewing the design docs and realised that, sadly, the "one line of primal IR maps to one line of dual IR" won't work for Core.GotoIfNot nodes. See https://compintell.github.io/Mooncake.jl/previews/PR386/developer_documentation/forwards_mode_design/#Statement-Transformation .

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

I think that's okay, the main trouble is adding new lines which insert new variables because it requires manual renumbering. A GoTo should be much simpler.

@willtebbutt
Copy link
Member

Were the difficulties around renumbering etc not resolved by not compact!ing until the end? I feel like I might be missing something.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

No they weren't. I experimented with compact! in various places and I was struggling a lot, so I asked Frames for advice. She agreed that insertion should usually be avoided.
If we have to insert something for GoTo, I think it will still be easier because we're not defining a new SSAValue so we don't have to adapt future statements that refer to it.

@willtebbutt
Copy link
Member

willtebbutt commented Nov 27, 2024

Ah, right, but we do need to insert a new SSAValue. Suppose that the GotoIfNot of interest is

GotoIfNot(%5, #3)

i.e. jump to block 3 if not %5. In the forwards-mode IR this would become

%new_ssa = Expr(:call, primal, %5)
GotoIfNot(%new_ssa, #3)

Does this not cause the same kind of problems?

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

Oh yes you're probably right. Although it might be slightly less of a hassle because the new SSA is only used in one spot, right after. I'll take a look

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

Do you know what I should do about expressions of type :code_coverage_effect? I assume they're inserted automatically and they're alone on their lines?

@willtebbutt
Copy link
Member

willtebbutt commented Nov 27, 2024

Yup -- I just strip them out of the IR entirely in reverse-mode. See

elseif Meta.isexpr(stmt, :code_coverage_effect)

The way to remove an instruction from an IRCode is just to replace the instruction with nothing.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

I think this works for GotoIfNot:

  1. make all the insertions necessary
  2. compact! once to make sure they applied
  3. shift the conditions of all GotoIfNot nodes to refer to the node right before them (where we get the primal value of the condition)

MWE (requires this branch of Mooncake):

const CC = Core.Compiler
using Mooncake
using MistyClosures

f(x) = x > 1 ? 2x : 3 + x
ir = Base.code_ircode(f, (Float64,))[1][1]
initial_ir = copy(ir)
get_primal_inst = CC.NewInstruction(Expr(:call, +, 1, 2), Any)  # placeholder for get_primal
CC.insert_node!(ir, CC.SSAValue(3), get_primal_inst, false)
ir = CC.compact!(ir)
for k in 1:length(ir.stmts)
    inst = ir[CC.SSAValue(k)][:stmt]
    if inst isa Core.GotoIfNot
        Mooncake.replace_call!(ir,CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k-1), inst.dest))
    end
end
ir
julia> initial_ir
5 1%1 = Base.lt_float(1.0, _2)::Bool                                                                                 │╻╷╷ >%2 = Base.or_int(%1, false)::Bool                                                                                 ││╻   <
  └──      goto #3 if not %2                                                                                            │   
  2%4 = Base.mul_float(2.0, _2)::Float64                                                                             ││╻   *
  └──      return %43%6 = Base.add_float(3.0, _2)::Float64                                                                             ││╻   +
  └──      return %6                                                                                                    │   
                                                                                                                            

julia> ir
5 1%1 = Base.lt_float(1.0, _2)::Bool                                                                                 │╻╷╷ >
  │        Base.or_int(%1, false)::Bool                                                                                 ││╻   <%3 = (+)(1, 2)::Any                                                                                               │   
  └──      goto #3 if not %3                                                                                            │   
  2%5 = Base.mul_float(2.0, _2)::Float64                                                                             ││╻   *
  └──      return %53%7 = Base.add_float(3.0, _2)::Float64                                                                             ││╻   +
  └──      return %7      

@willtebbutt
Copy link
Member

willtebbutt commented Nov 27, 2024

This is cool. I was also prototyping insertions with the IRCode while you were doing this, but was running into trouble when you have to make insertions at multiple locations. Could you try this on an example with multiple GotoIfNot nodes?

On a separate note: if it does turn out that (for some reason) that it's really awkward to do this insertion stuff directly on IRCode, I show in this gist how to use BBCode to manage IR transformations involving insertions, because this kind of thing is what BBCode excels at.

@gdalle
Copy link
Collaborator Author

gdalle commented Nov 27, 2024

Eyeballing it, this insert-then-shift approach seems to work?

const CC = Core.Compiler
using Mooncake

primal(x) = x

function replace_gotoifnot(ir::CC.IRCode)
    ir = copy(ir)
    # insertion loop
    for k in 1:length(ir.stmts)
        inst = ir[CC.SSAValue(k)][:stmt]
        if inst isa Core.GotoIfNot
            get_primal_inst = CC.NewInstruction(Expr(:call, primal, inst.cond), Any)
            CC.insert_node!(ir, CC.SSAValue(k), get_primal_inst, false)
        end
    end
    ir = CC.compact!(ir)
    # shift loop
    for k in 1:length(ir.stmts)
        inst = ir[CC.SSAValue(k)][:stmt]
        if inst isa Core.GotoIfNot
            Mooncake.replace_call!(
                ir, CC.SSAValue(k), Core.GotoIfNot(CC.SSAValue(k - 1), inst.dest)
            )
        end
    end
    return ir
end

function f(x)
    if x > 0
        if x > 1
            return sin(x)
        else
            return cos(x)
        end
    else
        if x < -1
            return exp(x)
        else
            return log(x)
        end
    end
end

ir = Base.code_ircode(f, (Float64,))[1][1]
new_ir = replace_gotoifnot(ir)
julia> ir = Base.code_ircode(f, (Float64,))[1][1]
30 1%1  = Base.lt_float(0.0, _2)::Bool                                                                               │╻╷╷ >%2  = Base.or_int(%1, false)::Bool                                                                               ││╻   <
   └──       goto #5 if not %2                                                                                          │   
31 2%4  = Base.lt_float(1.0, _2)::Bool                                                                               │╻╷╷ >%5  = Base.or_int(%4, false)::Bool                                                                               ││╻   <
   └──       goto #4 if not %5                                                                                          │   
32 3%7  = invoke Main.sin(_2::Float64)::Float64                                                                      │   
   └──       return %734 4%9  = invoke Main.cos(_2::Float64)::Float64                                                                      │   
   └──       return %937 5%11 = Base.lt_float(_2, -1.0)::Bool                                                                              │╻╷  <%12 = Base.or_int(%11, false)::Bool                                                                              ││╻   |
   └──       goto #7 if not %12                                                                                         │   
38 6%14 = invoke Main.exp(_2::Float64)::Float64                                                                      │   
   └──       return %1440 7%16 = invoke Main.log(_2::Float64)::Float64                                                                      │   
   └──       return %16                                                                                                 │   
                                                                                                                            

julia> new_ir = replace_gotoifnot(ir)
30 1%1  = Base.lt_float(0.0, _2)::Bool                                                                               │╻╷╷ >%2  = Base.or_int(%1, false)::Bool                                                                               ││╻   <%3  = (primal)(%2)::Any                                                                                          │   
   └──       goto #5 if not %3                                                                                          │   
31 2%5  = Base.lt_float(1.0, _2)::Bool                                                                               │╻╷╷ >%6  = Base.or_int(%5, false)::Bool                                                                               ││╻   <%7  = (primal)(%6)::Any                                                                                          │   
   └──       goto #4 if not %7                                                                                          │   
32 3%9  = invoke Main.sin(_2::Float64)::Float64                                                                      │   
   └──       return %934 4%11 = invoke Main.cos(_2::Float64)::Float64                                                                      │   
   └──       return %1137 5%13 = Base.lt_float(_2, -1.0)::Bool                                                                              │╻╷  <%14 = Base.or_int(%13, false)::Bool                                                                              ││╻   |%15 = (primal)(%14)::Any                                                                                         │   
   └──       goto #7 if not %15                                                                                         │   
38 6%17 = invoke Main.exp(_2::Float64)::Float64                                                                      │   
   └──       return %1740 7%19 = invoke Main.log(_2::Float64)::Float64                                                                      │   
   └──       return %19  

Can you spot any issue in the generated IR?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants