diff --git a/Project.toml b/Project.toml index 9101c516c..9760925c5 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1" ForwardDiff = "0.10" GPUArrays = "8.4.2, 9" GPUArraysCore = "0.1.1" -IRTools = "0.4.11" +IRTools = "0.4.12" LogExpFunctions = "0.3.1" MacroTools = "0.5" NaNMath = "0.3, 1" diff --git a/docs/src/limitations.md b/docs/src/limitations.md index f27f74305..4e9012ced 100644 --- a/docs/src/limitations.md +++ b/docs/src/limitations.md @@ -82,27 +82,30 @@ julia> gradient(rand(3)) do y ## Try-catch statements -Any expressions involving `try`/`catch` statements is not supported. -```julia -function tryme(x) - try - 2 * x - catch e - throw(e) - end -end +Code containting try-catch blocks can be differentiated as long as no exception is actually thrown. -julia> gradient(rand(3)) do x - sum(tryme(x)) +```julia +julia> function safe_sqrt(x) + try + sqrt(x) + catch + 0. + end end -ERROR: Compiling Tuple{typeof(tryme), Vector{Float64}}: try/catch is not supported. -Refer to the Zygote documentation for fixes. -https://fluxml.ai/Zygote.jl/latest/limitations +safe_sqrt (generic function with 1 method) + +julia> gradient(safe_sqrt, 4.) +(0.25,) + +julia> val, pull = pullback(safe_sqrt, -1.) +(0.0, Zygote.var"#76#77"{Zygote.Pullback{Tuple{typeof(safe_sqrt), Float64}, Any}}(∂(safe_sqrt))) +julia> pull(1.) +ERROR: Can't differentiate function execution in catch block at #= REPL[2]:3 =#. Stacktrace: - ... ``` -Here `tryme` uses a `try`/`catch` statement, and Zygote throws an error when trying to differentiate it as expected. `try`/`catch` expressions are used for error handling, but they are less common in Julia compared to some other languages. + +Here, the `safe_sqrt` function catches DomainError from the sqrt call when the input is out of domain and safely returns 0. Zygote is able to differentiate the function when no error is thrown by the sqrt call, but fails to differentiate when the control flow goes through the catch block. ## Foreign call expressions diff --git a/src/compiler/emit.jl b/src/compiler/emit.jl index ca79f11ce..75bfc0e58 100644 --- a/src/compiler/emit.jl +++ b/src/compiler/emit.jl @@ -36,22 +36,23 @@ concrete(T::DataType) = T concrete(::Type{Type{T}}) where T = typeof(T) concrete(T) = Any -runonce(b) = b.id in (1, length(b.ir.blocks)) +runonce(b) = b.id in (1, length(b.ir.blocks)) && + !any(((_,stmt),) -> isexpr(stmt.expr, :catch), b) function forward_stacks!(adj, F) stks, recs = [], [] pr = adj.primal for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id)) - if runonce(b) + not_stack = runonce(b) + if not_stack push!(recs, Variable(α)) else stk = pushfirst!(pr, xstack(Any)) push!(recs, stk) push!(b, xcall(Zygote, :_push!, stk, Variable(α))) end - push!(stks, (b.id, alpha(α))) + push!(stks, (b.id, alpha(α), not_stack)) end - args = arguments(pr)[3:end] rec = push!(pr, xtuple(recs...)) P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any} # P = Pullback{F,Any} # reduce specialisation @@ -68,11 +69,10 @@ function reverse_stacks!(adj, stks) self = argument!(entry, at = 1) t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t))) repl = Dict() - runonce(b) = b.id in (1, length(ir.blocks)) for b in blocks(ir) - for (i, (b′, α)) in enumerate(stks) + for (i, (b′, α, not_stack)) in enumerate(stks) b.id == b′ || continue - if runonce(b) + if not_stack val = insertafter!(ir, t, xcall(:getindex, t, i)) else stk = push!(entry, xcall(:getindex, t, i)) diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 0583b3da6..5ed79ea81 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -124,11 +124,6 @@ function instrument(ir::IR) ex = st.expr if isexpr(ex, :foreigncall, :isdefined) continue - elseif isexpr(ex, :enter, :leave) - error("""try/catch is not supported. - Refer to the Zygote documentation for fixes. - https://fluxml.ai/Zygote.jl/latest/limitations - """) elseif isexpr(ex, :(=)) @assert ex.args[1] isa GlobalRef pr[v] = xcall(Zygote, :global_set, QuoteNode(ex.args[1]), ex.args[2]) @@ -258,7 +253,7 @@ function adjointcfg(pr::Primal) end if isempty(preds) || (!isempty(branches(b)) && branches(b)[end] == IRTools.unreachable) # If `b` is unreachable, then no context produced by the primal should end up branching to `rb` - push!(rb, xcall(Core, :throw, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable` + push!(rb, xcall(Base, :error, "unreachable")) # `throw` is necessary for inference not to hit the `unreachable` branch!(rb, 0) end end @@ -279,7 +274,7 @@ xaccum(ir, xs...) = push!(ir, xcall(Zygote, :accum, xs...)) function passthrough_expr(ex::Expr) # Metadata we want to preserve - isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo) && return true + isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo, :enter, :leave, :catch) && return true # ccalls and more that are safe to preserve/required for proper operation: # - jl_set_task_threadpoolid: added in 1.9 for @spawn isexpr(ex, :foreigncall) && unwrapquote(ex.args[1]) in (:jl_set_task_threadpoolid,) && return true @@ -297,9 +292,14 @@ function adjoint(pr::Primal) for i = 1:length(sigs[b.id]) grad(sigs[b.id][i], arguments(rb)[i]) end + + has_leave = false + # Backprop through statements for v in reverse(keys(b)) ex = b[v].expr + has_leave |= isexpr(ex, :leave) + if haskey(pr.pullbacks, v) g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)), line = b[v].line)) @@ -321,6 +321,17 @@ function adjoint(pr::Primal) continue end end + + # This is corresponds to a catch blocks which technically + # has predecessors but they are not modelled in the IRTools CFG. + # We put an error message at the beginning of said block. + if has_leave && isempty(predecessors(b)) && b.id != 1 + _, f_stmt = first(b) + li = pr.ir.lines[f_stmt.line] + pushfirst!(rb, stmt(xcall(Base, :error, + "Can't differentiate function execution in catch block at $(li.file):$(li.line)."))) + end + if b.id > 1 # Backprop through (predecessor) branch arguments gs = grad.(arguments(b)) for br in branches(rb) diff --git a/test/compiler.jl b/test/compiler.jl index 4f8776c90..af93ae4f3 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -245,3 +245,100 @@ end @test_nowarn g = back(1.) @test only(g) ∈ (1., 2.) end + +function throws_and_catches_if_x_negative(x,y) + z = x + y + try + if x < 0. + throw(DomainError("x is negative")) + end + z = 2z + x + y + catch err + @error "something went wrong" exception=(err,catch_backtrace()) + end + return 3z +end + +function try_catch_finally(cond, x) + + try + x = 2x + cond && throw(DomainError()) + catch + x = 2x + finally + x = 3x + end + + x +end + +if VERSION >= v"1.8" + # try/catch/else is invalid syntax prior to v1.8 + eval(Meta.parse(""" + function try_catch_else(cond, x) + x = 2x + + try + x = 2x + cond && throw(nothing) + catch + x = 3x + else + x = 2x + end + + x + end + """)) +end + +@testset "try/catch" begin + @testset "happy path (nothrow)" begin + res, (dx,dy) = withgradient(throws_and_catches_if_x_negative, 1., 2.) + @test res == 3 * (2 * (1. + 2.) + 1. + 2.) + @test dx == 3. * (2. + 1.) + @test dy == 3. * (2. + 1.) + end + + @testset "try/catch/finally" begin + res, (_, dx,) = withgradient(try_catch_finally, false, 1.) + @test res == 6. + @test dx == 6. + + res, pull = pullback(try_catch_finally, true, 1.) + @test res == 12. + @test_throws ErrorException pull(1.) + err = try pull(1.) catch ex; ex end + @test occursin("Can't differentiate function execution in catch block", + string(err)) + end + + if VERSION >= v"1.8" + @testset "try/catch/else" begin + @test Zygote.gradient(try_catch_else, false, 1.0) == (nothing, 8.0) + @test_throws "Can't differentiate function execution in catch block" Zygote.gradient(try_catch_else, true, 1.0) + end + end + + function foo_try(f) + y = 1 + try + y = f() + catch + y + end + y + end + + g, = gradient(x -> foo_try(() -> x), 1) # 1 + @test g == 1. + + vy, pull = pullback(foo_try, () -> 0//0) # bypass because of expr + @test vy === 1 + @test_throws ErrorException pull(1.) + + err = try pull(1.) catch ex; ex end + @test occursin("Can't differentiate function execution in catch block", + string(err)) +end diff --git a/test/features.jl b/test/features.jl index 908ae5815..78dba0484 100644 --- a/test/features.jl +++ b/test/features.jl @@ -416,8 +416,7 @@ function pow_try(x) end end -@test_broken gradient(pow_try, 1) == (2,) -@test_throws Zygote.CompileError gradient(pow_try, 1) +@test gradient(pow_try, 1) == (2,) function pow_simd(x, n) r = 1