From 1a09c6b696393c90781db4d4d41efbbe2dabb47e Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 31 Dec 2024 15:32:48 +0100 Subject: [PATCH] fix tests and dependencies --- Project.toml | 8 +++++--- src/ReservoirComputing.jl | 4 ++-- src/esn/deepesn.jl | 2 +- src/esn/esn.jl | 2 +- src/esn/hybridesn.jl | 2 +- test/runtests.jl | 2 +- 6 files changed, 11 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 4c036942..8f973f73 100755 --- a/Project.toml +++ b/Project.toml @@ -9,11 +9,11 @@ CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -Optim = "429524aa-4258-5aef-a3af-852621145aeb" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [weakdeps] @@ -34,12 +34,12 @@ LIBSVM = "0.8" LinearAlgebra = "1.10" MLJLinearModels = "0.9.2, 0.10" NNlib = "0.8.4, 0.9" -Optim = "1" PartialFunctions = "1.2" Random = "1.10" Reexport = "1.2.2" SafeTestsets = "0.1" Statistics = "1.10" +StatsBase = "0.34.4" Test = "1" WeightInitializers = "1.0.4" julia = "1.10" @@ -47,9 +47,11 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" +LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b" +MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations"] +test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations", "MLJLinearModels", "LIBSVM"] diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl index 24fc5a42..3ef8d25c 100755 --- a/src/ReservoirComputing.jl +++ b/src/ReservoirComputing.jl @@ -5,11 +5,11 @@ using CellularAutomata using Distances using LinearAlgebra using NNlib -using Optim using PartialFunctions using Random using Reexport: Reexport, @reexport using Statistics +using StatsBase: sample using WeightInitializers: DeviceAgnostic, PartialFunction, Utils @reexport using WeightInitializers @@ -160,7 +160,7 @@ export train export ESN export HybridESN, KnowledgeModel export DeepESN -export RECA +export RECA, sample export RandomMapping, RandomMaps export Generative, Predictive, OutputLayer diff --git a/src/esn/deepesn.jl b/src/esn/deepesn.jl index 636e0db1..cd8de8c0 100755 --- a/src/esn/deepesn.jl +++ b/src/esn/deepesn.jl @@ -80,7 +80,7 @@ function DeepESN(train_data, nla_type = NLADefault(), states_type = StandardStates(), washout::Int = 0, - rng = WeightInitializers._default_rng(), + rng = Utils.default_rng(), T = Float64, matrix_type = typeof(train_data)) if states_type isa AbstractPaddedStates diff --git a/src/esn/esn.jl b/src/esn/esn.jl index f53939a5..1c1a7a65 100755 --- a/src/esn/esn.jl +++ b/src/esn/esn.jl @@ -54,7 +54,7 @@ function ESN(train_data, nla_type = NLADefault(), states_type = StandardStates(), washout = 0, - rng = WeightInitializers._default_rng(), + rng = Utils.default_rng(), T = Float32, matrix_type = typeof(train_data)) if states_type isa AbstractPaddedStates diff --git a/src/esn/hybridesn.jl b/src/esn/hybridesn.jl index ad134a9e..b766b013 100755 --- a/src/esn/hybridesn.jl +++ b/src/esn/hybridesn.jl @@ -126,7 +126,7 @@ function HybridESN(model, nla_type = NLADefault(), states_type = StandardStates(), washout = 0, - rng = WeightInitializers._default_rng(), + rng = Utils.default_rng(), T = Float32, matrix_type = typeof(train_data)) train_data = vcat(train_data, model.model_data[:, 1:(end - 1)]) diff --git a/test/runtests.jl b/test/runtests.jl index 27a8ed2c..8f051129 100755 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,7 @@ using Test end @testset "Echo State Networks" begin - @safetestset "ESN Input Layers" include("esn/test_inits.jl") + @safetestset "ESN Initializers" include("esn/test_inits.jl") @safetestset "ESN Train and Predict" include("esn/test_train.jl") @safetestset "ESN Drivers" include("esn/test_drivers.jl") @safetestset "Hybrid ESN" include("esn/test_hybrid.jl")