-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
130 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
style = "blue" | ||
align_assignment = true | ||
align_struct_field = true | ||
align_conditional = true | ||
align_pair_arrow = true | ||
align_matrix = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,23 @@ | ||
using SparseConnectivityTracer | ||
using Documenter | ||
|
||
DocMeta.setdocmeta!(SparseConnectivityTracer, :DocTestSetup, :(using SparseConnectivityTracer); recursive=true) | ||
DocMeta.setdocmeta!( | ||
SparseConnectivityTracer, | ||
:DocTestSetup, | ||
:(using SparseConnectivityTracer); | ||
recursive=true, | ||
) | ||
|
||
makedocs(; | ||
modules=[SparseConnectivityTracer], | ||
authors="Adrian Hill <[email protected]>", | ||
sitename="SparseConnectivityTracer.jl", | ||
format=Documenter.HTML(; | ||
canonical="https://adrhill.github.io/SparseConnectivityTracer.jl", | ||
edit_link="main", | ||
assets=String[], | ||
canonical = "https://adrhill.github.io/SparseConnectivityTracer.jl", | ||
edit_link = "main", | ||
assets = String[], | ||
), | ||
pages=[ | ||
"Home" => "index.md", | ||
], | ||
pages=["Home" => "index.md"], | ||
) | ||
|
||
deploydocs(; | ||
repo="github.com/adrhill/SparseConnectivityTracer.jl", | ||
devbranch="main", | ||
) | ||
deploydocs(; repo="github.com/adrhill/SparseConnectivityTracer.jl", devbranch="main") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,68 @@ | ||
module SparseConnectivityTracer | ||
|
||
# Write your package code here. | ||
# Input connectivity tracer | ||
struct Tracer <: Number | ||
inputs::Set{UInt64} # indices of connected, enumerated inputs | ||
end | ||
|
||
Tracer(i::Integer) = Tracer(Set{UInt64}(i)) | ||
Tracer(a::Tracer, b::Tracer) = Tracer(a.inputs ∪ b.inputs) | ||
|
||
# Enumerate inputs | ||
inputtrace(x) = inputtrace(x, 1) | ||
inputtrace(::Number, i) = Tracer(i) | ||
function inputtrace(x::AbstractArray, i) | ||
indices = (i - 1) .+ reshape(1:length(x), size(x)) | ||
return Tracer.(indices) | ||
end | ||
|
||
include("ops.jl") | ||
|
||
# Extent core operators | ||
for fn in (:+, :-, :*, :/, :^) | ||
@eval Base.$fn(a::Tracer, b::Tracer) = Tracer(a, b) | ||
for T in (:Number,) | ||
@eval Base.$fn(t::Tracer, ::$T) = t | ||
@eval Base.$fn(::$T, t::Tracer) = t | ||
end | ||
end | ||
|
||
Base.:^(a::Tracer, b::Tracer) = Tracer(a, b) | ||
for T in (:Number, :Integer, :Rational) | ||
@eval Base.:^(t::Tracer, ::$T) = t | ||
@eval Base.:^(::$T, t::Tracer) = t | ||
end | ||
Base.:^(t::Tracer, ::Irrational{:ℯ}) = t | ||
Base.:^(::Irrational{:ℯ}, t::Tracer) = t | ||
|
||
# Two-argument functions | ||
for fn in (:div, :fld, :cld) | ||
@eval Base.$fn(a::Tracer, b::Tracer) = Tracer(a, b) | ||
@eval Base.$fn(t::Tracer, ::Number) = t | ||
@eval Base.$fn(::Number, t::Tracer) = t | ||
end | ||
|
||
# Single-argument functions | ||
for fn in scalar_operations | ||
@eval Base.$fn(t::Tracer) = t | ||
end | ||
|
||
function connectivity(f, x) | ||
xt = inputtrace(x) | ||
yt = f(xt) | ||
n, m = length(xt), length(yt) | ||
|
||
# Construct connectivity matrix of size (ouput_dim, input_dim) | ||
C = BitArray(undef, m, n) | ||
for i in axes(C, 1) | ||
tracer = yt[i] | ||
for j in axes(C, 2) | ||
C[i, j] = j ∈ tracer.inputs | ||
end | ||
end | ||
return C | ||
end | ||
|
||
export connectivity | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
|
||
enumerate_tracers(x) = enumerate_tracers(x, 1) | ||
enumerate_tracers(::Number, i) = Tracer(i) | ||
function enumerate_tracers(x::AbstractArray, i) | ||
indices = (i - 1) .+ reshape(1:length(x), size(x)) | ||
return Tracer.(indices) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#! format: off | ||
scalar_operations = ( | ||
:exp2, :deg2rad, :rad2deg, | ||
:sincos, :sincospi, | ||
:cos, :cosd, :cosh, :cospi, :cosc, | ||
:sin, :sind, :sinh, :sinpi, :sinc, | ||
:tan, :tand, :tanh, | ||
:csc, :cscd, :csch, | ||
:sec, :secd, :sech, | ||
:cot, :cotd, :coth, | ||
:acos, :acosd, :acosh, | ||
:asin, :asind, :asinh, | ||
:atan, :atand, :atanh, | ||
:asec, :asech, | ||
:acsc, :acsch, | ||
:acot, :acoth, | ||
:exp, :expm1, :exp10, | ||
:frexp, :ldexp, | ||
) | ||
#! format: on |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
[deps] | ||
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" | ||
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" | ||
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,29 @@ | ||
using SparseConnectivityTracer | ||
using Test | ||
using JuliaFormatter | ||
using Aqua | ||
using JET | ||
|
||
using LinearAlgebra | ||
using Random | ||
|
||
@testset "SparseConnectivityTracer.jl" begin | ||
@testset "Code quality (Aqua.jl)" begin | ||
@testset "Code formatting" begin | ||
@test JuliaFormatter.format( | ||
SparseConnectivityTracer; verbose=false, overwrite=false | ||
) | ||
end | ||
@testset "Aqua.jl tests" begin | ||
Aqua.test_all(SparseConnectivityTracer) | ||
end | ||
@testset "Code linting (JET.jl)" begin | ||
JET.test_package(SparseConnectivityTracer; target_defined_modules = true) | ||
@testset "JET tests" begin | ||
JET.test_package(SparseConnectivityTracer; target_defined_modules=true) | ||
end | ||
|
||
@testset "Connectivity" begin | ||
f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] | ||
@test connectivity(f, rand(3)) == BitMatrix([1 0 0; 1 1 0; 0 0 1]) | ||
|
||
@test connectivity(identity, rand()) == BitMatrix([1;;]) | ||
end | ||
# Write your tests here. | ||
end |