diff --git a/test/layers.jl b/test/layers.jl index 17cc3a57..282c6ba4 100644 --- a/test/layers.jl +++ b/test/layers.jl @@ -346,6 +346,24 @@ function test_summary() return n == 15 end +# test summary for abstract layer: +# +mutable struct AL <: AbstractLayer + w + function AL(i,o) + w = param(o,i) + return new(w) + end +end + +function test_abstract_layer_summary() + + al = AL(100,100) + n = summary(al) + return n == 1 +end + + function test_print() ch = Chain(Conv(3, 3, 3, 100), diff --git a/test/runtests.jl b/test/runtests.jl index 2e527e71..3e316b3d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -107,6 +107,7 @@ using Statistics: mean @test test_get_set_rnn() @test test_summary() +@test test_abstract_layer_summary() @test test_print()