Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding cross validation for point forecast #24

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/TimeSeriesInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ using RecipesBase
using Statistics

include("timeseries.jl")
include("forecast.jl")
include("models.jl")
include("forecast.jl")
include("FileFormats/FileFormats.jl")

end
56 changes: 44 additions & 12 deletions src/forecast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ abstract type ProbabilisticForecast <: Forecast end

Define the results of point forecasts.
"""
mutable struct PointForecast{T<:Real} <: Forecast
mutable struct PointForecast{T <: Real} <: Forecast
name::String
timestamps::Vector{DateTime}
forecast::Vector{T}
Expand All @@ -21,7 +21,7 @@ mutable struct PointForecast{T<:Real} <: Forecast
name::String,
timestamps::Vector{DateTime},
forecast::Vector{T},
) where {T<:Real}
) where {T <: Real}

if length(timestamps) != length(forecast)
throw(DimensionMismatch("timestamps and forecast do not have the same length."))
Expand All @@ -44,13 +44,13 @@ function PointForecast(
name::String,
timestamps::Vector{Date},
forecast::Vector{T},
) where {T<:Real}
) where {T <: Real}

return PointForecast(name, DateTime.(timestamps), forecast)
end

## Evaluation Metrics for Point Forecast
"""
    PointForecastMetrics{T <: Real}

Immutable container holding the evaluation metrics of a point forecast.

# Fields
- `errors::Vector{T}`: forecast errors, one entry per evaluated point
  (presumably produced by `error(real, forecast)`; confirm in `forecast_metrics`).
- `absolute_percentage_errors::Vector{T}`: absolute percentage errors, one entry per
  evaluated point (presumably produced by `absolute_percentage_error`; confirm in
  `forecast_metrics`).
"""
struct PointForecastMetrics{T <: Real}
    errors::Vector{T}
    absolute_percentage_errors::Vector{T}
end
Expand Down Expand Up @@ -82,12 +82,44 @@ error(real::Vector{T}, forecast::Vector{T}) where {T} = real .- forecast
"""
    absolute_percentage_error(real, forecast)

Return the element-wise absolute percentage error of `forecast` against `real`.

The residuals come from this file's `error(real, forecast)` helper and are divided
element-wise by `real`; the absolute value of each ratio is returned. The result is a
fraction (it is not multiplied by 100). Both vectors must share the same element type.
"""
function absolute_percentage_error(real::Vector{T}, forecast::Vector{T}) where {T}
    residuals = error(real, forecast)
    return abs.(residuals ./ real)
end

"""
cross_validation(fit_input::FitInput{T}, fit_function::Function, predict_function::Function, metric_function::Function,
min_history::Int, horizon::Int)

Function that receives all available data in fit_input, a fit_function and a predict_function,
a validation metric (e.g. mape), a min_history as a starting point and a prediction horizon.
Returns the result of the cross validation separated by lead time.
"""
function cross_validation(fit_input::FitInput{T},
fit_function::Function,
predict_function::Function,
metric_function::Function,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should use this function forecast_metrics

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and then make another function to aggregate the vector of metrics, maybe then add this metric function

min_history::Int,
horizon::Int) where T
# Purpose: for every cutoff `i` in `min_history:(n - horizon)`, fit on the first `i`
# observations, forecast the next `horizon` steps, score them with `metric_function`,
# and finally average the scores over all cutoffs for each lead time.

# Number of observations, taken from the first dependent series only.
n = length(fit_input.dependent[1].timestamps)
# One row per cutoff `i`, one column per lead time 1..horizon.
# NOTE(review): if `min_history > n - horizon` the range is empty and the matrix has
# zero rows; the final `mean` would then likely yield NaN entries -- confirm whether an
# explicit argument check is wanted.
metric_matrix = Matrix{Float64}(undef, length(min_history:(n - horizon)), horizon)
for i = min_history:(n - horizon)
# Training data: every dependent and exogenous series truncated to its first `i` points.
# NOTE(review): assumes the exogenous series are index-aligned with `dependent[1]` and
# at least `n` long -- confirm against callers.
training_dependent = map(x -> TimeSeries(x.name, x.timestamps[1:i], x.vals[1:i]), fit_input.dependent)
training_exogenous = map(x -> TimeSeries(x.name, x.timestamps[1:i], x.vals[1:i]), fit_input.exogenous)
training_fit_input = FitInput(fit_input.parameters, training_dependent, training_exogenous)
validation_fit_result = fit_function(training_fit_input)
# Validation window: the `horizon` points immediately after the cutoff.
validation_timestamps_forecast = fit_input.dependent[1].timestamps[(i + 1):(i + horizon)]
# The exogenous values handed to the predictor are sliced from the full input, i.e. the
# values recorded for the validation window itself.
validation_exogenous_forecast = map(x -> TimeSeries(x.name, x.timestamps[(i + 1):(i + horizon)], x.vals[(i + 1):(i + horizon)]), fit_input.exogenous)
validation_simulate_input = SimulateInput(training_fit_input, validation_timestamps_forecast,
validation_exogenous_forecast, validation_fit_result)
# `predict_function` must return an object with a `forecast` field (read below).
point_forecast = predict_function(validation_simulate_input)
observed = map(x -> TimeSeries(x.name, x.timestamps[(i + 1):(i + horizon)], x.vals[(i + 1):(i + horizon)]), fit_input.dependent)
# Only the first dependent series is scored. The result of `metric_function` is
# broadcast into the row for this cutoff, so it must be a scalar or have `horizon` entries.
metric_matrix[(i - min_history + 1), :] .= metric_function(observed[1].vals, point_forecast.forecast)
end
# Average over cutoffs (rows); with `dims=1` the result is a 1 x horizon matrix.
return mean(metric_matrix, dims=1)
end

"""
ScenariosForecast

Define the probabilistic forecast results calculated from scenarios.
"""
mutable struct ScenariosForecast{T<:Real} <: ProbabilisticForecast
mutable struct ScenariosForecast{T <: Real} <: ProbabilisticForecast
name::String
timestamps::Vector{DateTime}
scenarios::Matrix{T}
Expand All @@ -100,7 +132,7 @@ mutable struct ScenariosForecast{T<:Real} <: ProbabilisticForecast
scenarios::Matrix{T},
quantiles_probabilities::Vector{T},
quantiles::Matrix{T},
) where {T<:Real}
) where {T <: Real}

if length(timestamps) != size(scenarios, 1)
throw(DimensionMismatch("timestamps and scenarios do not have the same length."))
Expand Down Expand Up @@ -133,7 +165,7 @@ function ScenariosForecast(
scenarios::Matrix{T},
quantiles_probabilities::Vector{T},
quantiles::Matrix{T},
) where {T<:Real}
) where {T <: Real}

return ScenariosForecast(
name,
Expand Down Expand Up @@ -179,7 +211,7 @@ function forecast_metrics(
end

"""
    get_quantiles(quantile_probs, scenarios)

Compute the empirical quantiles of each row of `scenarios`.

# Arguments
- `quantile_probs::Vector{T}`: probabilities at which the quantiles are evaluated.
- `scenarios::Matrix{T}`: matrix whose rows are summarised independently (presumably one
  row per timestamp and one column per scenario; confirm against `ScenariosForecast`).

# Returns
A matrix with `size(scenarios, 1)` rows and `length(quantile_probs)` columns, where entry
`(i, j)` is the `quantile_probs[j]` quantile of `scenarios[i, :]`.
"""
function get_quantiles(quantile_probs::Vector{T}, scenarios::Matrix{T}) where {T}
    # `dims=2` passes each row to `quantile`; the returned vectors are laid out along
    # the second dimension of the result.
    return mapslices(row -> quantile(row, quantile_probs), scenarios; dims=2)
end

Expand Down Expand Up @@ -298,7 +330,7 @@ function mean_crps(forecast::Vector{ScenariosForecastMetrics})
for (i, forec) in enumerate(forecast)
m_crps[:, i] = forec.crps
end
return vec(mean(m_crps, dims = 2))
return vec(mean(m_crps, dims=2))
end


Expand All @@ -307,7 +339,7 @@ end

Define the probabilistic forecast results calculated from distributions.
"""
mutable struct QuantilesForecast{T<:Real} <: ProbabilisticForecast
mutable struct QuantilesForecast{T <: Real} <: ProbabilisticForecast
name::String
timestamps::Vector{DateTime}
quantiles_probabilities::Vector{T}
Expand All @@ -318,7 +350,7 @@ mutable struct QuantilesForecast{T<:Real} <: ProbabilisticForecast
timestamps::Vector{DateTime},
quantiles_probabilities::Vector{T},
quantiles::Matrix{T},
) where {T<:Real}
) where {T <: Real}

if length(timestamps) != size(quantiles, 1)
throw(DimensionMismatch("timestamps and quantiles do not have the same length."))
Expand Down Expand Up @@ -346,7 +378,7 @@ function QuantilesForecast(
timestamps::Vector{Date},
quantiles_probabilities::Vector{T},
quantiles::Matrix{T},
) where {T<:Real}
) where {T <: Real}

return QuantilesForecast(
name,
Expand Down