Skip to content

Commit

Permalink
Use BitSet for Tracer input set (#11)
Browse files Browse the repository at this point in the history
* Use `BitSet` for Tracer input set

* Add NNlib convolution to benchmarks

* Removes `sortedinputs` due to BitSets always being sorted
  • Loading branch information
adrhill authored Apr 10, 2024
1 parent 95d432c commit e450dd7
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 45 deletions.
1 change: 0 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,4 @@ trace_input
The following utilities can be used to extract input indices from [`Tracer`](@ref)s:
```@docs
inputs
sortedinputs
```
2 changes: 1 addition & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ include("connectivity.jl")

export Tracer
export tracer, trace_input
export inputs, sortedinputs
export inputs
export connectivity

end # module
41 changes: 6 additions & 35 deletions src/tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ julia> M * [x, y]
```
"""
struct Tracer <: Number
inputs::Set{UInt64} # indices of connected, enumerated inputs
inputs::BitSet # indices of connected, enumerated inputs
end

const EMPTY_TRACER = Tracer(Set{UInt64}())
const EMPTY_TRACER = Tracer(BitSet())

# We have to be careful when defining constructors:
# Generic code expecting "regular" numbers `x` will sometimes convert them
Expand All @@ -94,59 +94,30 @@ uniontracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs))
Convenience constructor for [`Tracer`](@ref) from input indices.
"""
tracer(index::Integer) = Tracer(Set{UInt64}(index))
tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(Set{UInt64}(inds))
tracer(index::Integer) = Tracer(BitSet(index))
tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(BitSet(inds))
tracer(inds...) = tracer(inds)

# Utilities for accessing input indices
"""
inputs(tracer)
Return raw `UInt64` input indices of a [`Tracer`](@ref).
See also [`sortedinputs`](@ref).
## Example
```jldoctest
julia> t = tracer(1, 2, 4)
Tracer(1, 2, 4)
julia> inputs(t)
3-element Vector{UInt64}:
0x0000000000000004
0x0000000000000002
0x0000000000000001
```
"""
inputs(t::Tracer) = collect(keys(t.inputs.dict))

"""
sortedinputs(tracer)
sortedinputs([T=Int], tracer)
Return sorted input indices of a [`Tracer`](@ref).
See also [`inputs`](@ref).
## Example
```jldoctest
julia> t = tracer(1, 2, 4)
Tracer(1, 2, 4)
julia> sortedinputs(t)
3-element Vector{Int64}:
1
2
4
julia> sortedinputs(UInt8, t)
3-element Vector{UInt8}:
0x01
0x02
0x04
```
"""
sortedinputs(t::Tracer) = sortedinputs(Int, t)
sortedinputs(T::Type, t::Tracer) = convert.(T, sort!(inputs(t)))
inputs(t::Tracer) = collect(t.inputs)

function Base.show(io::IO, t::Tracer)
return Base.show_delim_array(io, sortedinputs(Int, t), "Tracer(", ',', ')', true)
return Base.show_delim_array(io, inputs(t), "Tracer(", ',', ')', true)
end
30 changes: 26 additions & 4 deletions test/benchmark.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using BenchmarkTools
using SparseConnectivityTracer
using Symbolics: Symbolics
using NNlib: conv

include("brusselator.jl")

Expand All @@ -24,8 +25,29 @@ function benchmark_brusselator(N::Integer, method=:tracer)
end
end

benchmark_brusselator(6, :tracer)
benchmark_brusselator(6, :symbolics)
function benchmark_conv(method=:tracer)
x = rand(28, 28, 3, 1) # WHCN image
w = rand(5, 5, 3, 16) # corresponds to Conv((5, 5), 3 => 16)
f(x) = conv(x, w)

benchmark_brusselator(24, :tracer)
benchmark_brusselator(24, :symbolics)
if method == :tracer
return @benchmark connectivity($f, $x)
elseif method == :symbolics
return @benchmark Symbolics.jacobian_sparsity($f, $x)
end
end

## Run Brusselator benchmarks
for N in (6, 24)
for method in (:tracer, :symbolics)
@info "Benchmarking Brusselator of size $N with $method..."
b = benchmark_brusselator(N, method)
display(b)
end
end

## Run conv benchmarks
@info "Benchmarking NNlib.conv with tracer..."
# Symbolics fails on this example
b = benchmark_conv(:tracer)
display(b)
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ DocMeta.setdocmeta!(
# Matrix multiplication
A = rand(1, 3)
yt = only(A * xt)
@test sortedinputs(yt) == [1, 2, 3]
@test inputs(yt) == [1, 2, 3]

@test connectivity(x -> only(A * x), x) [1 1 1]

# Custom functions
f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]
yt = f(xt)
@test sortedinputs(yt[1]) == [1]
@test sortedinputs(yt[2]) == [1, 2]
@test sortedinputs(yt[3]) == [3]
@test inputs(yt[1]) == [1]
@test inputs(yt[2]) == [1, 2]
@test inputs(yt[3]) == [3]

@test connectivity(f, x) [1 0 0; 1 1 0; 0 0 1]

Expand Down

0 comments on commit e450dd7

Please sign in to comment.