Skip to content

Commit

Permalink
Initial draft
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Mar 28, 2024
1 parent bd6b9ec commit 9a1b46c
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 16 deletions.
6 changes: 6 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
style = "blue"
align_assignment = true
align_struct_field = true
align_conditional = true
align_pair_arrow = true
align_matrix = true
22 changes: 11 additions & 11 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
using SparseConnectivityTracer
using Documenter

DocMeta.setdocmeta!(SparseConnectivityTracer, :DocTestSetup, :(using SparseConnectivityTracer); recursive=true)
DocMeta.setdocmeta!(
SparseConnectivityTracer,
:DocTestSetup,
:(using SparseConnectivityTracer);
recursive=true,
)

makedocs(;
modules=[SparseConnectivityTracer],
authors="Adrian Hill <[email protected]>",
sitename="SparseConnectivityTracer.jl",
format=Documenter.HTML(;
canonical="https://adrhill.github.io/SparseConnectivityTracer.jl",
edit_link="main",
assets=String[],
canonical = "https://adrhill.github.io/SparseConnectivityTracer.jl",
edit_link = "main",
assets = String[],
),
pages=[
"Home" => "index.md",
],
pages=["Home" => "index.md"],
)

deploydocs(;
repo="github.com/adrhill/SparseConnectivityTracer.jl",
devbranch="main",
)
deploydocs(; repo="github.com/adrhill/SparseConnectivityTracer.jl", devbranch="main")
65 changes: 64 additions & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,68 @@
module SparseConnectivityTracer

# Write your package code here.
# Input connectivity tracer
struct Tracer <: Number
inputs::Set{UInt64} # indices of connected, enumerated inputs
end

Tracer(i::Integer) = Tracer(Set{UInt64}(i))
Tracer(a::Tracer, b::Tracer) = Tracer(a.inputs b.inputs)

# Enumerate inputs
inputtrace(x) = inputtrace(x, 1)
inputtrace(::Number, i) = Tracer(i)
function inputtrace(x::AbstractArray, i)
indices = (i - 1) .+ reshape(1:length(x), size(x))
return Tracer.(indices)
end

include("ops.jl")

# Extent core operators
for fn in (:+, :-, :*, :/, :^)
@eval Base.$fn(a::Tracer, b::Tracer) = Tracer(a, b)
for T in (:Number,)
@eval Base.$fn(t::Tracer, ::$T) = t
@eval Base.$fn(::$T, t::Tracer) = t
end
end

Base.:^(a::Tracer, b::Tracer) = Tracer(a, b)
for T in (:Number, :Integer, :Rational)
@eval Base.:^(t::Tracer, ::$T) = t
@eval Base.:^(::$T, t::Tracer) = t
end
Base.:^(t::Tracer, ::Irrational{:ℯ}) = t
Base.:^(::Irrational{:ℯ}, t::Tracer) = t

# Two-argument functions
for fn in (:div, :fld, :cld)
@eval Base.$fn(a::Tracer, b::Tracer) = Tracer(a, b)
@eval Base.$fn(t::Tracer, ::Number) = t
@eval Base.$fn(::Number, t::Tracer) = t
end

# Single-argument functions
for fn in scalar_operations
@eval Base.$fn(t::Tracer) = t
end

function connectivity(f, x)
xt = inputtrace(x)
yt = f(xt)
n, m = length(xt), length(yt)

# Construct connectivity matrix of size (ouput_dim, input_dim)
C = BitArray(undef, m, n)
for i in axes(C, 1)
tracer = yt[i]
for j in axes(C, 2)
C[i, j] = j tracer.inputs
end
end
return C
end

export connectivity

end
7 changes: 7 additions & 0 deletions src/enumerate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

enumerate_tracers(x) = enumerate_tracers(x, 1)
enumerate_tracers(::Number, i) = Tracer(i)
function enumerate_tracers(x::AbstractArray, i)
indices = (i - 1) .+ reshape(1:length(x), size(x))
return Tracer.(indices)
end
20 changes: 20 additions & 0 deletions src/ops.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#! format: off
scalar_operations = (
:exp2, :deg2rad, :rad2deg,
:sincos, :sincospi,
:cos, :cosd, :cosh, :cospi, :cosc,
:sin, :sind, :sinh, :sinpi, :sinc,
:tan, :tand, :tanh,
:csc, :cscd, :csch,
:sec, :secd, :sech,
:cot, :cotd, :coth,
:acos, :acosd, :acosh,
:asin, :asind, :asinh,
:atan, :atand, :atanh,
:asec, :asech,
:acsc, :acsch,
:acot, :acoth,
:exp, :expm1, :exp10,
:frexp, :ldexp,
)
#! format: on
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23 changes: 19 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
using SparseConnectivityTracer
using Test
using JuliaFormatter
using Aqua
using JET

using LinearAlgebra
using Random

@testset "SparseConnectivityTracer.jl" begin
@testset "Code quality (Aqua.jl)" begin
@testset "Code formatting" begin
@test JuliaFormatter.format(
SparseConnectivityTracer; verbose=false, overwrite=false
)
end
@testset "Aqua.jl tests" begin
Aqua.test_all(SparseConnectivityTracer)
end
@testset "Code linting (JET.jl)" begin
JET.test_package(SparseConnectivityTracer; target_defined_modules = true)
@testset "JET tests" begin
JET.test_package(SparseConnectivityTracer; target_defined_modules=true)
end

@testset "Connectivity" begin
f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]
@test connectivity(f, rand(3)) == BitMatrix([1 0 0; 1 1 0; 0 0 1])

@test connectivity(identity, rand()) == BitMatrix([1;;])
end
# Write your tests here.
end

0 comments on commit 9a1b46c

Please sign in to comment.