diff --git a/docs/src/index.md b/docs/src/index.md index a08d9ec1..47a233ba 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -26,5 +26,4 @@ trace_input The following utilities can be used to extract input indices from [`Tracer`](@ref)s: ```@docs inputs -sortedinputs ``` diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 5a08f71c..8e3f3806 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -9,7 +9,7 @@ include("connectivity.jl") export Tracer export tracer, trace_input -export inputs, sortedinputs +export inputs export connectivity end # module diff --git a/src/tracer.jl b/src/tracer.jl index 1c8f1a8c..cee0c7d7 100644 --- a/src/tracer.jl +++ b/src/tracer.jl @@ -74,10 +74,10 @@ julia> M * [x, y] ``` """ struct Tracer <: Number - inputs::Set{UInt64} # indices of connected, enumerated inputs + inputs::BitSet # indices of connected, enumerated inputs end -const EMPTY_TRACER = Tracer(Set{UInt64}()) +const EMPTY_TRACER = Tracer(BitSet()) # We have to be careful when defining constructors: # Generic code expecting "regular" numbers `x` will sometimes convert them @@ -94,8 +94,8 @@ uniontracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs)) Convenience constructor for [`Tracer`](@ref) from input indices. """ -tracer(index::Integer) = Tracer(Set{UInt64}(index)) -tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(Set{UInt64}(inds)) +tracer(index::Integer) = Tracer(BitSet(index)) +tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(BitSet(inds)) tracer(inds...) = tracer(inds) # Utilities for accessing input indices @@ -103,7 +103,6 @@ tracer(inds...) = tracer(inds) inputs(tracer) Return raw `UInt64` input indices of a [`Tracer`](@ref). -See also [`sortedinputs`](@ref). ## Example ```jldoctest @@ -111,42 +110,14 @@ julia> t = tracer(1, 2, 4) Tracer(1, 2, 4) julia> inputs(t) -3-element Vector{UInt64}: - 0x0000000000000004 - 0x0000000000000002 - 0x0000000000000001 -``` -""" -inputs(t::Tracer) = collect(keys(t.inputs.dict)) - -""" - sortedinputs(tracer) - sortedinputs([T=Int], tracer) - -Return sorted input indices of a [`Tracer`](@ref). -See also [`inputs`](@ref). - -## Example -```jldoctest -julia> t = tracer(1, 2, 4) -Tracer(1, 2, 4) - -julia> sortedinputs(t) 3-element Vector{Int64}: 1 2 4 - -julia> sortedinputs(UInt8, t) -3-element Vector{UInt8}: - 0x01 - 0x02 - 0x04 ``` """ -sortedinputs(t::Tracer) = sortedinputs(Int, t) -sortedinputs(T::Type, t::Tracer) = convert.(T, sort!(inputs(t))) +inputs(t::Tracer) = collect(t.inputs) function Base.show(io::IO, t::Tracer) - return Base.show_delim_array(io, sortedinputs(Int, t), "Tracer(", ',', ')', true) + return Base.show_delim_array(io, inputs(t), "Tracer(", ',', ')', true) end diff --git a/test/benchmark.jl b/test/benchmark.jl index 3b5ddaa5..5b614e16 100644 --- a/test/benchmark.jl +++ b/test/benchmark.jl @@ -1,6 +1,7 @@ using BenchmarkTools using SparseConnectivityTracer using Symbolics: Symbolics +using NNlib: conv include("brusselator.jl") @@ -24,8 +25,29 @@ function benchmark_brusselator(N::Integer, method=:tracer) end end -benchmark_brusselator(6, :tracer) -benchmark_brusselator(6, :symbolics) +function benchmark_conv(method=:tracer) + x = rand(28, 28, 3, 1) # WHCN image + w = rand(5, 5, 3, 16) # corresponds to Conv((5, 5), 3 => 16) + f(x) = conv(x, w) -benchmark_brusselator(24, :tracer) -benchmark_brusselator(24, :symbolics) + if method == :tracer + return @benchmark connectivity($f, $x) + elseif method == :symbolics + return @benchmark Symbolics.jacobian_sparsity($f, $x) + end +end + +## Run Brusselator benchmarks +for N in (6, 24) + for method in (:tracer, :symbolics) + @info "Benchmarking Brusselator of size $N with $method..." + b = benchmark_brusselator(N, method) + display(b) + end +end + +## Run conv benchmarks +@info "Benchmarking NNlib.conv with tracer..." +# Symbolics fails on this example +b = benchmark_conv(:tracer) +display(b) diff --git a/test/runtests.jl b/test/runtests.jl index 2084e43d..048ea94d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,16 +44,16 @@ DocMeta.setdocmeta!( # Matrix multiplication A = rand(1, 3) yt = only(A * xt) - @test sortedinputs(yt) == [1, 2, 3] + @test inputs(yt) == [1, 2, 3] @test connectivity(x -> only(A * x), x) ≈ [1 1 1] # Custom functions f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] yt = f(xt) - @test sortedinputs(yt[1]) == [1] - @test sortedinputs(yt[2]) == [1, 2] - @test sortedinputs(yt[3]) == [3] + @test inputs(yt[1]) == [1] + @test inputs(yt[2]) == [1, 2] + @test inputs(yt[3]) == [3] @test connectivity(f, x) ≈ [1 0 0; 1 1 0; 0 0 1]