From 8e593d40c2d8837cd73c647405c4bf274329c509 Mon Sep 17 00:00:00 2001 From: andre_ramos Date: Sun, 4 Aug 2024 00:32:08 -0400 Subject: [PATCH] fix estimation bug and tests --- paper_tests/m4_test/evaluate_model.jl | 2 +- paper_tests/simulation_test/evaluate_models.jl | 2 +- src/estimation_procedure/default_estimation_procedure.jl | 2 +- test/StateSpaceLearning.jl | 6 ++++++ 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/paper_tests/m4_test/evaluate_model.jl b/paper_tests/m4_test/evaluate_model.jl index cc8523f..ba0295a 100644 --- a/paper_tests/m4_test/evaluate_model.jl +++ b/paper_tests/m4_test/evaluate_model.jl @@ -9,7 +9,7 @@ function evaluate_SSL(initialization_df::DataFrame, results_df::DataFrame, input T= length(normalized_y) normalized_y = normalized_y[max(1, T-sample_size+1):end] output = StateSpaceLearning.fit_model(normalized_y; - model_input = Dict("stochastic_level" => true, "trend" => true, + model_input = Dict("level" => true, "stochastic_level" => true, "trend" => true, "stochastic_trend" => true, "seasonal" => true, "stochastic_seasonal" => true, "freq_seasonal" => 12, "outlier" => outlier, "ζ_ω_threshold" => 12), diff --git a/paper_tests/simulation_test/evaluate_models.jl b/paper_tests/simulation_test/evaluate_models.jl index 27b3722..42e4f74 100644 --- a/paper_tests/simulation_test/evaluate_models.jl +++ b/paper_tests/simulation_test/evaluate_models.jl @@ -11,7 +11,7 @@ function get_SSL_results(y_train::Vector{Float64}, true_features::Vector{Int64}, series_result=nothing t = @elapsed output = StateSpaceLearning.fit_model(y_train; Exogenous_X=X_train, - model_input = Dict("stochastic_level" => true, "trend" => true, + model_input = Dict("level" => true, "stochastic_level" => true, "trend" => true, "stochastic_trend" => true, "seasonal" => true, "stochastic_seasonal" => true, "freq_seasonal" => 12, "outlier" => false, "ζ_ω_threshold" => 12), diff --git a/src/estimation_procedure/default_estimation_procedure.jl b/src/estimation_procedure/default_estimation_procedure.jl index 3f6e37f..c8c52a8 100644 --- a/src/estimation_procedure/default_estimation_procedure.jl +++ b/src/estimation_procedure/default_estimation_procedure.jl @@ -225,5 +225,5 @@ function default_estimation_procedure(Estimation_X::Matrix{Tl}, estimation_y::Ve !penalize_initial_states ? ts_penalty_factor[components_indexes["initial_states"][2:end]] .= 0 : nothing end - return fit_lasso(Estimation_X, estimation_y, α, information_criteria, penalize_exogenous, components_indexes, penalty_factor; rm_average = false) + return fit_lasso(Estimation_X, estimation_y, α, information_criteria, penalize_exogenous, components_indexes, ts_penalty_factor; rm_average = false) end diff --git a/test/StateSpaceLearning.jl b/test/StateSpaceLearning.jl index 5ffdd38..187b635 100644 --- a/test/StateSpaceLearning.jl +++ b/test/StateSpaceLearning.jl @@ -36,4 +36,10 @@ end @test_throws AssertionError StateSpaceLearning.forecast(output1, 10; Exogenous_Forecast = rand(5, 3)) @test_throws AssertionError StateSpaceLearning.forecast(output2, 10) @test_throws AssertionError StateSpaceLearning.forecast(output2, 10; Exogenous_Forecast = rand(5, 3)) + + y3 = [4.718, 4.77, 4.882, 4.859, 4.795, 4.905, 4.997, 4.997, 4.912, 4.779, 4.644, 4.77, 4.744, 4.836, 4.948, 4.905, 4.828, 5.003, 5.135, 5.135, 5.062, 4.89, 4.736, 4.941, 4.976, 5.01, 5.181, 5.093, 5.147, 5.181, 5.293, 5.293, 5.214, 5.087, 4.983, 5.111, 5.141, 5.192, 5.262, 5.198, 5.209, 5.384, 5.438, 5.488, 5.342, 5.252, 5.147, 5.267, 5.278, 5.278, 5.463, 5.459, 5.433, 5.493, 5.575, 5.605, 5.468, 5.351, 5.192, 5.303, 5.318, 5.236, 5.459, 5.424, 5.455, 5.575, 5.71, 5.68, 5.556, 5.433, 5.313, 5.433, 5.488, 5.451, 5.587, 5.594, 5.598, 5.752, 5.897, 5.849, 5.743, 5.613, 5.468, 5.627, 5.648, 5.624, 5.758, 5.746, 5.762, 5.924, 6.023, 6.003, 5.872, 5.723, 5.602, 5.723, 5.752, 5.707, 5.874, 5.852, 5.872, 6.045, 6.142, 6.146, 6.001, 5.849, 5.72, 5.817, 5.828, 5.762, 5.891, 5.852, 5.894, 6.075, 6.196, 6.224, 6.001, 5.883, 5.736, 5.82, 5.886, 5.834, 6.006, 5.981, 6.04, 6.156, 6.306, 6.326, 6.137, 6.008, 5.891, 6.003, 6.033, 5.968, 6.037, 6.133, 6.156, 6.282, 6.432, 6.406, 6.23, 6.133, 5.966, 6.068] + output3 = StateSpaceLearning.fit_model(y3) + forecast3 = trunc.(StateSpaceLearning.forecast(output3, 18); digits = 3) + @assert forecast3 == [6.11, 6.082, 6.221, 6.19, 6.197, 6.328, 6.447, 6.44, 6.285, 6.163, 6.026, 6.142, 6.166, 6.138, 6.278, 6.246, 6.253, 6.384] + end \ No newline at end of file