Merge pull request #43 from JuliaReach/schillic/print
Better printing of neural networks
schillic authored May 25, 2024
2 parents 55869a1 + b34e7f4 commit 1e98daa
Showing 5 changed files with 20 additions and 2 deletions.
10 changes: 10 additions & 0 deletions src/Architecture/ActivationFunction.jl
@@ -18,6 +18,8 @@ struct Id <: ActivationFunction end

(::Id)(x) = x

Base.show(io::IO, ::Id) = print(io, Id)

"""
ReLU
@@ -31,6 +33,8 @@ struct ReLU <: ActivationFunction end

(::ReLU)(x) = max.(x, zero(eltype(x)))

Base.show(io::IO, ::ReLU) = print(io, ReLU)

"""
Sigmoid
@@ -44,6 +48,8 @@ struct Sigmoid <: ActivationFunction end

(::Sigmoid)(x) = @. 1 / (1 + exp(-x))

Base.show(io::IO, ::Sigmoid) = print(io, Sigmoid)

"""
Tanh
@@ -57,6 +63,8 @@ struct Tanh <: ActivationFunction end

(::Tanh)(x) = tanh.(x)

Base.show(io::IO, ::Tanh) = print(io, Tanh)

"""
LeakyReLU{N<:Number}
@@ -78,6 +86,8 @@ end
(lr::LeakyReLU)(x::Number) = x >= zero(x) ? x : lr.slope * x
(lr::LeakyReLU)(x::AbstractVector) = lr.(x)

Base.show(io::IO, lr::LeakyReLU) = print(io, "$LeakyReLU($(lr.slope))")

# constant instances of each activation function
const _id = Id()
const _relu = ReLU()
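With these show methods, each activation function now prints as its plain type name (and LeakyReLU with its slope). A rough REPL sketch of the intended behavior; the exact output is an assumption (it may be module-qualified if the names are not in scope):

julia> println(ReLU())
ReLU

julia> println(LeakyReLU(0.1))
LeakyReLU(0.1)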
2 changes: 1 addition & 1 deletion src/Architecture/DenseLayerOp.jl
@@ -64,7 +64,7 @@ function Base.:isapprox(L1::DenseLayerOp, L2::DenseLayerOp; atol::Real=0,
end

function Base.show(io::IO, L::DenseLayerOp)
str = "$(string(DenseLayerOp)) with $(dim_in(L)) inputs, $(dim_out(L)) " *
str = "$DenseLayerOp with $(dim_in(L)) inputs, $(dim_out(L)) " *
"outputs, and $(L.activation) activation"
return print(io, str)
end
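As an illustration, a single layer now prints its activation via the new methods above. This is a hedged sketch: the DenseLayerOp(W, b, activation) constructor call and the exact output string are assumptions based on the code in this diff.

julia> L = DenseLayerOp([1.0 2.0; 3.0 4.0; 5.0 6.0], zeros(3), ReLU());

julia> println(L)
DenseLayerOp with 2 inputs, 3 outputs, and ReLU activation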
2 changes: 1 addition & 1 deletion src/Architecture/FeedforwardNetwork.jl
@@ -57,7 +57,7 @@ function load_Flux_convert_network()
end

function Base.show(io::IO, N::FeedforwardNetwork)
str = "$(string(FeedforwardNetwork)) with $(dim_in(N)) inputs, " *
str = "$FeedforwardNetwork with $(dim_in(N)) inputs, " *
"$(dim_out(N)) outputs, and $(length(N)) layers:"
for l in layers(N)
str *= "\n- $l"
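Putting the pieces together, a small network should list one line per layer. Again a sketch under assumptions: the FeedforwardNetwork(layers) constructor and the literal output (including "1 outputs", which follows from the format string) are inferred from this diff, not verified.

julia> N = FeedforwardNetwork([DenseLayerOp([1.0 2.0; 3.0 4.0], zeros(2), ReLU()),
                               DenseLayerOp([1.0 1.0], zeros(1), Id())]);

julia> println(N)
FeedforwardNetwork with 2 inputs, 1 outputs, and 2 layers:
- DenseLayerOp with 2 inputs, 2 outputs, and ReLU activation
- DenseLayerOp with 2 inputs, 1 outputs, and Id activation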
5 changes: 5 additions & 0 deletions test/Architecture/ActivationFunction.jl
@@ -0,0 +1,5 @@
# printing
io = IOBuffer()
for act in (Id(), ReLU(), Sigmoid(), Tanh(), LeakyReLU(0.1))
    println(io, act)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -7,6 +7,9 @@ import Flux, MAT, ONNX, YAML
struct TestActivation <: ActivationFunction end

@testset "Architecture" begin
@testset "ActivationFunction" begin
include("Architecture/ActivationFunction.jl")
end
@testset "AbstractLayerOp" begin
include("Architecture/AbstractLayerOp.jl")
end
