
BitVector failure #1523

Closed
mhauru opened this issue Sep 9, 2024 · 3 comments
Labels
ChainRules adjoint -> rrule, and further integration


mhauru commented Sep 9, 2024

```julia
using Zygote: Zygote

struct VNV{TVal}
    vals::TVal
    bv::BitVector
end

f(x) = VNV(x, BitVector(undef, 1)).vals
Zygote.pullback(f, [1.0])
```

The above fails with

```
ERROR: LoadError: ArgumentError: new: too few arguments (expected 3)
Stacktrace:
  [1] __new__
    @ ~/.julia/packages/Zygote/nsBv0/src/tools/builtins.jl:9 [inlined]
  [2] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:296 [inlined]
  [3] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [4] BitArray
    @ ./bitarray.jl:39 [inlined]
  [5] _pullback(::Zygote.Context{false}, ::Type{BitVector}, ::UndefInitializer, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [6] f
    @ ~/projects/DynamicPPL.jl/tmp_zygote_bug.jl:10 [inlined]
  [7] _pullback(ctx::Zygote.Context{false}, f::typeof(Main.TmpZygoteBug.f), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [8] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
  [9] pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88
 [10] top-level scope
    @ ~/projects/DynamicPPL.jl/tmp_zygote_bug.jl:11
 [11] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [12] top-level scope
    @ REPL[3]:1
in expression starting at /Users/mhauru/projects/DynamicPPL.jl/tmp_zygote_bug.jl:1
```

This happens on Zygote v0.6.70.

Switching to, e.g., a Vector{Bool} instead of a BitVector works.
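A minimal sketch of that workaround, assuming Zygote is installed. The struct name VNVBool and the function name g are hypothetical, chosen only so they do not clash with the VNV and f defined above:

```julia
using Zygote: Zygote

# Same layout as VNV, but the flag field is a plain Vector{Bool}
# instead of a BitVector (hypothetical name, for illustration only).
struct VNVBool{TVal}
    vals::TVal
    bv::Vector{Bool}
end

g(x) = VNVBool(x, Vector{Bool}(undef, 1)).vals

# pullback returns the primal value and a closure mapping an output
# cotangent to a tuple of input gradients (one entry per argument of g).
y, back = Zygote.pullback(g, [1.0])
grad = back([1.0])
```

The trade-off is memory: a Vector{Bool} stores one byte per flag, whereas a BitVector packs them into bits.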

willtebbutt (Member) commented

It looks like the constructor for BitVector is fairly involved. You could just use ChainRules to mark it as @non_differentiable, e.g.

```julia
@non_differentiable BitVector(a, b)
```

Running that in a fresh session works for me locally. A BitVector only holds Bools, and its constructor arguments (undef and a length) carry no derivative information, so I don't think marking it non-differentiable can drop any gradient info. It should be safe.
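A self-contained sketch of this suggestion applied to the original example, assuming Zygote and ChainRulesCore are installed. The @non_differentiable macro is imported here from ChainRulesCore, where it is defined, and Zygote uses ChainRules rrules automatically. The rule is declared before the first call to Zygote.pullback, in line with the "fresh session" remark above:

```julia
using Zygote: Zygote
using ChainRulesCore: @non_differentiable

struct VNV{TVal}
    vals::TVal
    bv::BitVector
end

# Declare the BitVector constructor non-differentiable. The generated
# rrule returns NoTangent() for the constructor and every argument, so
# Zygote uses this rule instead of tracing into BitArray's internal `new`.
@non_differentiable BitVector(a, b)

f(x) = VNV(x, BitVector(undef, 1)).vals

y, back = Zygote.pullback(f, [1.0])
grad = back([1.0])  # tuple holding the gradient with respect to x
```

If a rule like this already ships with the installed ChainRules version, redefining it only triggers a method-overwrite warning.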

mcabbott added the ChainRules label (adjoint -> rrule, and further integration) on Sep 9, 2024
mcabbott (Member) commented Sep 9, 2024

If that works, you can make a one-line PR adding that line to https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/nondiff.jl to fix it permanently.

CarloLucibello (Member) commented

The OP's example now works.
