Skip to content

Commit

Permalink
fix tests and dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Dec 31, 2024
1 parent 5718a1a commit 1a09c6b
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 9 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[weakdeps]
Expand All @@ -34,22 +34,24 @@ LIBSVM = "0.8"
LinearAlgebra = "1.10"
MLJLinearModels = "0.9.2, 0.10"
NNlib = "0.8.4, 0.9"
Optim = "1"
PartialFunctions = "1.2"
Random = "1.10"
Reexport = "1.2.2"
SafeTestsets = "0.1"
Statistics = "1.10"
StatsBase = "0.34.4"
Test = "1"
WeightInitializers = "1.0.4"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations"]
test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations", "MLJLinearModels", "LIBSVM"]
4 changes: 2 additions & 2 deletions src/ReservoirComputing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ using CellularAutomata
using Distances
using LinearAlgebra
using NNlib
using Optim
using PartialFunctions
using Random
using Reexport: Reexport, @reexport
using Statistics
using StatsBase: sample
using WeightInitializers: DeviceAgnostic, PartialFunction, Utils
@reexport using WeightInitializers

Expand Down Expand Up @@ -160,7 +160,7 @@ export train
export ESN
export HybridESN, KnowledgeModel
export DeepESN
export RECA
export RECA, sample
export RandomMapping, RandomMaps
export Generative, Predictive, OutputLayer

Expand Down
2 changes: 1 addition & 1 deletion src/esn/deepesn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function DeepESN(train_data,
nla_type = NLADefault(),
states_type = StandardStates(),
washout::Int = 0,
rng = WeightInitializers._default_rng(),
rng = Utils.default_rng(),
T = Float64,
matrix_type = typeof(train_data))
if states_type isa AbstractPaddedStates
Expand Down
2 changes: 1 addition & 1 deletion src/esn/esn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function ESN(train_data,
nla_type = NLADefault(),
states_type = StandardStates(),
washout = 0,
rng = WeightInitializers._default_rng(),
rng = Utils.default_rng(),
T = Float32,
matrix_type = typeof(train_data))
if states_type isa AbstractPaddedStates
Expand Down
2 changes: 1 addition & 1 deletion src/esn/hybridesn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ function HybridESN(model,
nla_type = NLADefault(),
states_type = StandardStates(),
washout = 0,
rng = WeightInitializers._default_rng(),
rng = Utils.default_rng(),
T = Float32,
matrix_type = typeof(train_data))
train_data = vcat(train_data, model.model_data[:, 1:(end - 1)])
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Test
end

@testset "Echo State Networks" begin
@safetestset "ESN Input Layers" include("esn/test_inits.jl")
@safetestset "ESN Initializers" include("esn/test_inits.jl")
@safetestset "ESN Train and Predict" include("esn/test_train.jl")
@safetestset "ESN Drivers" include("esn/test_drivers.jl")
@safetestset "Hybrid ESN" include("esn/test_hybrid.jl")
Expand Down

0 comments on commit 1a09c6b

Please sign in to comment.