Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use BitSet for Tracer input set #11

Merged
merged 3 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,4 @@ trace_input
The following utilities can be used to extract input indices from [`Tracer`](@ref)s:
```@docs
inputs
sortedinputs
```
2 changes: 1 addition & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ include("connectivity.jl")

export Tracer
export tracer, trace_input
export inputs, sortedinputs
export inputs
export connectivity

end # module
41 changes: 6 additions & 35 deletions src/tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ julia> M * [x, y]
```
"""
struct Tracer <: Number
inputs::Set{UInt64} # indices of connected, enumerated inputs
inputs::BitSet # indices of connected, enumerated inputs
end

const EMPTY_TRACER = Tracer(Set{UInt64}())
const EMPTY_TRACER = Tracer(BitSet())

# We have to be careful when defining constructors:
# Generic code expecting "regular" numbers `x` will sometimes convert them
Expand All @@ -94,59 +94,30 @@ uniontracer(a::Tracer, b::Tracer) = Tracer(union(a.inputs, b.inputs))

Convenience constructor for [`Tracer`](@ref) from input indices.
"""
tracer(index::Integer) = Tracer(Set{UInt64}(index))
tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(Set{UInt64}(inds))
tracer(index::Integer) = Tracer(BitSet(index))
tracer(inds::NTuple{N,<:Integer}) where {N} = Tracer(BitSet(inds))
tracer(inds...) = tracer(inds)

# Utilities for accessing input indices
"""
inputs(tracer)

Return raw `UInt64` input indices of a [`Tracer`](@ref).
See also [`sortedinputs`](@ref).

## Example
```jldoctest
julia> t = tracer(1, 2, 4)
Tracer(1, 2, 4)

julia> inputs(t)
3-element Vector{UInt64}:
0x0000000000000004
0x0000000000000002
0x0000000000000001
```
"""
inputs(t::Tracer) = collect(keys(t.inputs.dict))

"""
sortedinputs(tracer)
sortedinputs([T=Int], tracer)

Return sorted input indices of a [`Tracer`](@ref).
See also [`inputs`](@ref).

## Example
```jldoctest
julia> t = tracer(1, 2, 4)
Tracer(1, 2, 4)

julia> sortedinputs(t)
3-element Vector{Int64}:
1
2
4

julia> sortedinputs(UInt8, t)
3-element Vector{UInt8}:
0x01
0x02
0x04
```
"""
sortedinputs(t::Tracer) = sortedinputs(Int, t)
sortedinputs(T::Type, t::Tracer) = convert.(T, sort!(inputs(t)))
inputs(t::Tracer) = collect(t.inputs)

function Base.show(io::IO, t::Tracer)
return Base.show_delim_array(io, sortedinputs(Int, t), "Tracer(", ',', ')', true)
return Base.show_delim_array(io, inputs(t), "Tracer(", ',', ')', true)
end
30 changes: 26 additions & 4 deletions test/benchmark.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using BenchmarkTools
using SparseConnectivityTracer
using Symbolics: Symbolics
using NNlib: conv

include("brusselator.jl")

Expand All @@ -24,8 +25,29 @@ function benchmark_brusselator(N::Integer, method=:tracer)
end
end

benchmark_brusselator(6, :tracer)
benchmark_brusselator(6, :symbolics)
function benchmark_conv(method=:tracer)
x = rand(28, 28, 3, 1) # WHCN image
w = rand(5, 5, 3, 16) # corresponds to Conv((5, 5), 3 => 16)
f(x) = conv(x, w)

benchmark_brusselator(24, :tracer)
benchmark_brusselator(24, :symbolics)
if method == :tracer
return @benchmark connectivity($f, $x)
elseif method == :symbolics
return @benchmark Symbolics.jacobian_sparsity($f, $x)
end
end

## Run Brusselator benchmarks
for N in (6, 24)
for method in (:tracer, :symbolics)
@info "Benchmarking Brusselator of size $N with $method..."
b = benchmark_brusselator(N, method)
display(b)
end
end

## Run conv benchmarks
@info "Benchmarking NNlib.conv with tracer..."
# Symbolics fails on this example
b = benchmark_conv(:tracer)
display(b)
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ DocMeta.setdocmeta!(
# Matrix multiplication
A = rand(1, 3)
yt = only(A * xt)
@test sortedinputs(yt) == [1, 2, 3]
@test inputs(yt) == [1, 2, 3]

@test connectivity(x -> only(A * x), x) ≈ [1 1 1]

# Custom functions
f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]
yt = f(xt)
@test sortedinputs(yt[1]) == [1]
@test sortedinputs(yt[2]) == [1, 2]
@test sortedinputs(yt[3]) == [3]
@test inputs(yt[1]) == [1]
@test inputs(yt[2]) == [1, 2]
@test inputs(yt[3]) == [3]

@test connectivity(f, x) ≈ [1 0 0; 1 1 0; 0 0 1]

Expand Down
Loading