From b550487602506e35d6d86249b4f79990aac3b5ca Mon Sep 17 00:00:00 2001 From: Guillaume Date: Mon, 5 Jul 2021 22:13:41 -0400 Subject: [PATCH] Added util function to generate Poisson Spike Train based on value vector --- src/WaspNet.jl | 5 ++++- src/utilities/poisson.jl | 14 ++++++++++++++ src/utils.jl | 6 ++++-- test/utility_tests.jl | 6 ++++++ 4 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 src/utilities/poisson.jl diff --git a/src/WaspNet.jl b/src/WaspNet.jl index 120818c..3ff0cd2 100644 --- a/src/WaspNet.jl +++ b/src/WaspNet.jl @@ -8,10 +8,11 @@ using Parameters using Random include("types.jl") + include("defs.jl") -include("neurons.jl") include("layer.jl") include("network.jl") +include("neurons.jl") include("simulate.jl") include("utils.jl") @@ -22,4 +23,6 @@ export update export batch_layer_construction, network_constructor, layer_constructor, feed_forward_network +export poissonST + end # module diff --git a/src/utilities/poisson.jl b/src/utilities/poisson.jl new file mode 100644 index 0000000..ca3dfbe --- /dev/null +++ b/src/utilities/poisson.jl @@ -0,0 +1,14 @@ +""" + poissonST(l::AbstractVector) + +Generates Poisson Spike Trains based on the normalized vector. Each +pseudo-neuron (probability p in vector), fires with probability p at each +timestep of simulation. + +# Inputs +- `l`: array of values +""" + +function poissonST(l::AbstractVector) + return poissonST(t) = [Float64(rand(Bernoulli(p), 1)[1]) for p in normalize(l)] +end diff --git a/src/utils.jl b/src/utils.jl index ede8405..757c1b6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,8 @@ -# Various useful utility functions +# Various useful utility functions include("utilities/utils.jl") # A pruning framework to remove neurons from Networks and Layers -include("utilities/pruning.jl") \ No newline at end of file +include("utilities/pruning.jl") + +include("utilities/poisson.jl") diff --git a/test/utility_tests.jl b/test/utility_tests.jl index 1e81cf3..b98d359 100644 --- a/test/utility_tests.jl +++ b/test/utility_tests.jl @@ -21,4 +21,10 @@ using WaspNet include("pruning_tests.jl") + @test begin + inputs = [1,1] + p = poissonST(inputs) + spikes = p(1) + spikes == [0,0] || spikes == [0,1] || spikes == [1,0] || spikes == [1,1] + end end